deepliif 1.2.0__py3-none-any.whl → 1.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,792 +0,0 @@
1
- """This package contains modules related to objective functions, optimizations, and network architectures.
2
-
3
- To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel.
4
- You need to implement the following five functions:
5
- -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
6
- -- <set_input>: unpack data from dataset and apply preprocessing.
7
- -- <forward>: produce intermediate results.
8
- -- <optimize_parameters>: calculate loss, gradients, and update network weights.
9
- -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
10
-
11
- In the function <__init__>, you need to define four lists:
12
- -- self.loss_names (str list): specify the training losses that you want to plot and save.
13
- -- self.model_names (str list): define networks used in our training.
14
- -- self.visual_names (str list): specify the images that you want to display and save.
15
- -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See CycleGAN_model.py for a usage example.
16
-
17
- Now you can use the model class by specifying flag '--model dummy'.
18
- See our template model class 'template_model.py' for more details.
19
- """
20
- import base64
21
- import os
22
- import itertools
23
- import importlib
24
- from functools import lru_cache
25
- from io import BytesIO
26
- import json
27
- import math
28
-
29
- import requests
30
- import torch
31
- from PIL import Image
32
- Image.MAX_IMAGE_PIXELS = None
33
-
34
- import numpy as np
35
- from dask import delayed, compute
36
-
37
- from deepliif.util import *
38
- from deepliif.util.util import tensor_to_pil
39
- from deepliif.data import transform
40
- from deepliif.postprocessing import compute_final_results, compute_cell_results
41
- from deepliif.postprocessing import encode_cell_data_v4, decode_cell_data_v4
42
- from deepliif.options import Options, print_options
43
-
44
- from .base_model import BaseModel
45
-
46
- # import for init purpose, not used in this script
47
- from .DeepLIIF_model import DeepLIIFModel
48
- from .DeepLIIFExt_model import DeepLIIFExtModel
49
-
50
-
51
- import time
52
# Accumulated wall-clock seconds; run_dask() adds to these counters and
# infer_cells_for_wsi() prints them after processing a slide.
time_g = time_gs = time_stack = 0
53
-
54
-
55
@lru_cache
def get_opt(model_dir, mode='test'):
    """
    Load the options stored next to a model, cached per (model_dir, mode).

    mode: 'test' or 'train'; currently only functions used for inference
          utilize get_opt so it defaults to 'test'.

    In test mode 'test_opt.txt' is tried first and 'train_opt.txt' is used as
    a fallback; data parallelism is then disabled and gpu_ids is set to every
    CUDA device reported by torch.

    NOTE: because of lru_cache the same object is returned on repeated calls,
    so any mutation by a caller is visible to later callers.

    Raises ValueError for any other mode.
    """
    if mode not in ('train', 'test'):
        # Previously an unknown mode surfaced as an UnboundLocalError on `opt`.
        raise ValueError(f"mode must be 'train' or 'test', got {mode!r}")

    train_opt_path = os.path.join(model_dir, 'train_opt.txt')
    if mode == 'train':
        return Options(path_file=train_opt_path, mode=mode)

    try:
        opt = Options(path_file=os.path.join(model_dir, 'test_opt.txt'), mode=mode)
    except Exception:
        # Best-effort fallback (e.g. no test_opt.txt was exported). A bare
        # `except:` would also have swallowed KeyboardInterrupt/SystemExit.
        opt = Options(path_file=train_opt_path, mode=mode)
    opt.use_dp = False
    opt.gpu_ids = list(range(torch.cuda.device_count()))
    return opt
71
-
72
-
73
def find_model_using_name(model_name):
    """Import the module "deepliif/models/[model_name]_model.py" and return its model class.

    The class looked up is the one whose name equals
    model_name (underscores removed) + 'model', compared case-insensitively,
    and which is a subclass of BaseModel. If several attributes match, the
    last one in the module namespace wins.

    If no such class exists, a message is printed and the process exits with
    status 1 (SystemExit).
    """
    model_filename = "deepliif.models." + model_name + "_model"
    modellib = importlib.import_module(model_filename)
    target_model_name = model_name.replace('_', '') + 'model'
    wanted = target_model_name.lower()

    model = None
    for name, cls in modellib.__dict__.items():
        # isinstance guard: issubclass() raises TypeError on non-class objects
        # that merely happen to carry a matching name.
        if name.lower() == wanted and isinstance(cls, type) and issubclass(cls, BaseModel):
            model = cls

    if model is None:
        print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (
            model_filename, target_model_name))
        # Non-zero status: this is a failure, and `exit` (site builtin) is not
        # guaranteed to exist in every interpreter setup.
        raise SystemExit(1)

    return model
95
-
96
-
97
def get_option_setter(model_name):
    """Look up the model class for `model_name` and return its <modify_commandline_options> attribute."""
    return find_model_using_name(model_name).modify_commandline_options
101
-
102
-
103
def create_model(opt):
    """Instantiate the model class selected by `opt.model`.

    The class is resolved with find_model_using_name(), constructed with
    `opt`, and its class name is printed before the instance is returned.
    This is the main interface between this package and 'train.py'/'test.py'.

    Example:
        >>> from deepliif.models import create_model
        >>> model = create_model(opt)
    """
    model_cls = find_model_using_name(opt.model)
    created = model_cls(opt)
    print("model [%s] was created" % type(created).__name__)
    return created
117
-
118
-
119
def load_torchscript_model(model_pt_path, device):
    """Load a TorchScript file onto `device` and return it in eval mode.

    The loaded module is passed through disable_batchnorm_tracking_stats()
    before eval() is called.
    """
    scripted = torch.jit.load(model_pt_path, map_location=device)
    scripted = disable_batchnorm_tracking_stats(scripted)
    scripted.eval()
    return scripted
124
-
125
-
126
-
127
def load_eager_models(opt, devices=None):
    """Build the eager (non-TorchScript) model and return its networks by name.

    opt: options used to create and set up the model.
    devices: optional dict {net name: torch device}; when given, only these
             nets are returned and each is moved to its device. Otherwise
             model.model_names is used and nets stay where setup() left them.

    Returns a dict {net name: network}. Names that are not strings are skipped.
    """
    # Create the model, then run the regular setup (load networks, schedulers).
    model = create_model(opt)
    model.setup(opt)

    wanted_names = list(devices.keys()) if devices else model.model_names

    nets = {}
    for name in wanted_names:
        if not isinstance(name, str):
            continue

        if '_' in name:
            # 'X_k' refers to the k-th (1-based) entry of the attribute 'netX'.
            pieces = name.split('_')
            net = getattr(model, 'net' + pieces[0])[int(pieces[-1]) - 1]
        else:
            net = getattr(model, 'net' + name)

        if opt.phase != 'train':
            net.eval()
            net = disable_batchnorm_tracking_stats(net)

        # SDG models are still wrapped in DataParallel after loading; unwrap.
        if isinstance(net, torch.nn.DataParallel):
            net = net.module

        nets[name] = net
        if devices:
            net.to(devices[name])

    return nets
159
-
160
-
161
@lru_cache
def init_nets(model_dir, eager_mode=False, opt=None, phase='test'):
    """
    Init DeepLIIF networks so that every net in
    the same group is deployed on the same GPU.

    model_dir: directory holding the serialized '<net name>.pt' files.
    eager_mode: if True, load through load_eager_models() instead of TorchScript.
    opt: options object; when None it is read with get_opt(model_dir, mode=phase)
         and data parallelism is disabled on it. Since this function is wrapped
         in lru_cache, every argument (including opt) must be hashable.
    phase: mode handed to get_opt() when opt is None.

    Returns a dict {net name: network}.
    Raises Exception for a model type other than DeepLIIF, DeepLIIFExt, SDG, CycleGAN.
    """
    if opt is None:
        opt = get_opt(model_dir, mode=phase)
        opt.use_dp = False

    # Nets listed in the same tuple must end up on the same device.
    if opt.model == 'DeepLIIF':
        net_groups = [
            ('G1', 'G52'),
            ('G2', 'G53'),
            ('G3', 'G54'),
            ('G4', 'G55'),
            ('G51',),
        ]
    elif opt.model in ('DeepLIIFExt', 'SDG'):
        with_seg = opt.seg_gen
        net_groups = [
            (f'G_{n}', f'GS_{n}') if with_seg else (f'G_{n}',)
            for n in range(1, opt.modalities_no + 1)
        ]
    elif opt.model == 'CycleGAN':
        prefix = 'GB' if opt.BtoA else 'GA'
        net_groups = [(f'{prefix}_{n}',) for n in range(1, opt.modalities_no + 1)]
    else:
        raise Exception(f'init_nets() not implemented for model {opt.model}')

    available_gpus = torch.cuda.device_count()
    usable_gpus = min(len(opt.gpu_ids), available_gpus)

    devices = {}
    if usable_gpus > 0:
        # The i-th chunk of groups goes to the i-th entry of opt.gpu_ids.
        gpu_for_slot = dict(enumerate(opt.gpu_ids))
        for slot, group_chunk in enumerate(chunker(net_groups, usable_gpus)):
            for net_name in itertools.chain.from_iterable(group_chunk):
                devices[net_name] = torch.device(f'cuda:{gpu_for_slot[slot]}')
    else:
        for net_name in itertools.chain.from_iterable(net_groups):
            devices[net_name] = torch.device('cpu')

    if eager_mode:
        return load_eager_models(opt, devices)

    nets = {}
    for net_name, device in devices.items():
        nets[net_name] = load_torchscript_model(os.path.join(model_dir, f'{net_name}.pt'), device=device)
    return nets
214
-
215
-
216
def compute_overlap(img_size, tile_size):
    """Return the tile overlap to use for an image of `img_size` (width, height).

    The overlap is 0 when both dimensions round to exactly one tile,
    otherwise a quarter of the tile size (integer division).
    """
    width, height = img_size
    single_tile = all(round(side / tile_size) == 1 for side in (width, height))
    return 0 if single_tile else tile_size // 4
222
-
223
-
224
def run_torchserve(img, model_path=None, eager_mode=False, opt=None, seg_only=False):
    """
    Run inference for one tile through a TorchServe workflow endpoint.

    img: image to process; it is resized to (opt.scale_size, opt.scale_size).
    model_path: not used in this function; put in place to be consistent with run_dask
                so that run_wrapper() could call either this function or run_dask with
                same syntax
    eager_mode: same as model_path (not used)
    opt: options object; only opt.scale_size is read here, so it must not be None
    seg_only: same as model_path (not used)

    The server address comes from the TORCHSERVE_HOST environment variable
    (default 'http://localhost'). Returns a dict built from the JSON response,
    with every value deserialized and converted by tensor_to_pil().
    Raises whatever requests raises for a non-success HTTP status.
    """
    side = opt.scale_size
    payload = BytesIO()
    torch.save(transform(img.resize((side, side))), payload)
    encoded_img = base64.b64encode(payload.getvalue()).decode('utf-8')

    host = os.getenv('TORCHSERVE_HOST', 'http://localhost')
    response = requests.post(f'{host}/wfpredict/deepliif', json={'img': encoded_img})
    response.raise_for_status()

    cpu = torch.device('cpu')
    images = {}
    for key, blob in response.json().items():
        # NOTE(review): torch.load unpickles the server response; only point
        # TORCHSERVE_HOST at a trusted server.
        tensor = torch.load(BytesIO(base64.b64decode(blob.encode())), map_location=cpu)
        images[key] = tensor_to_pil(tensor)
    return images
247
-
248
-
249
def run_dask(img, model_path, eager_mode=False, opt=None, seg_only=False):
    """
    Run the networks of a model on one tile (or a list of tiles for multi-input models).

    img: input image; a list of images when opt.input_no > 1 or the model is SDG.
    model_path: model directory, overridden by the DEEPLIIF_MODEL_DIR environment
                variable when that is set.
    eager_mode: passed through to init_nets().
    opt: options object (must not be None; norm, input_no, model, scale_size and,
         depending on the model, modalities_no / seg_gen / BtoA are read).
    seg_only: DeepLIIF only - skip generators whose segmentation weight is 0 and
              return just 'G4' (if computed) and the aggregated 'G5'.

    Returns a dict {net name: image} with values produced by tensor_to_pil().
    Raises Exception for unsupported model types.

    For DeepLIIF, the module-level counters time_g, time_gs and time_stack are
    increased by the time spent in the respective stage.
    """
    global time_g, time_gs, time_stack

    model_dir = os.getenv('DEEPLIIF_MODEL_DIR', model_path)
    nets = init_nets(model_dir, eager_mode, opt)
    # Some train settings like spectral norm are not compatible with dask at
    # inference time, so those run eagerly.
    use_dask = opt.norm != 'spectral'
    cpu = torch.device('cpu')

    scaled = (opt.scale_size, opt.scale_size)
    if opt.input_no > 1 or opt.model == 'SDG':
        ts = torch.cat([transform(img_i.resize(scaled)) for img_i in img], dim=1)
    else:
        ts = transform(img.resize(scaled))

    def forward(input, model):
        # Run the model on whatever device its parameters live on.
        with torch.no_grad():
            return model(input.to(next(model.parameters()).device))

    if use_dask:
        forward = delayed(forward)

    if opt.model == 'DeepLIIF':
        # Contribution of each segmentation net to the aggregated mask 'G5'.
        weights = {
            'G51': 0.5,  # IHC
            'G52': 0.0,  # Hema
            'G53': 0.0,  # DAPI
            'G54': 0.0,  # Lap2
            'G55': 0.5,  # Marker
        }

        # generator -> segmentation net consuming that generator's output
        seg_map = {'G1': 'G52', 'G2': 'G53', 'G3': 'G54', 'G4': 'G55'}
        if seg_only:
            seg_map = {k: v for k, v in seg_map.items() if weights[v] != 0}

        lazy_gens = {k: forward(ts, nets[k]) for k in seg_map}
        if 'G4' not in seg_map:
            # The marker image is always produced.
            lazy_gens['G4'] = forward(ts, nets['G4'])
        tstart = time.time()
        gens = compute(lazy_gens)[0]
        time_g += time.time() - tstart

        lazy_segs = {v: forward(gens[k], nets[v]).to(cpu) for k, v in seg_map.items()}
        if not seg_only or weights['G51'] != 0:
            lazy_segs['G51'] = forward(ts, nets['G51']).to(cpu)
        tstart = time.time()
        segs = compute(lazy_segs)[0]
        time_gs += time.time() - tstart

        tstart = time.time()
        seg = torch.stack([torch.mul(segs[k], weights[k]) for k in segs.keys()]).sum(dim=0)
        time_stack += time.time() - tstart

        if seg_only:
            res = {'G4': tensor_to_pil(gens['G4'])} if 'G4' in gens else {}
        else:
            res = {k: tensor_to_pil(v) for k, v in gens.items()}
            res.update({k: tensor_to_pil(v) for k, v in segs.items()})
        res['G5'] = tensor_to_pil(seg)

        return res

    if opt.model in ['DeepLIIFExt', 'SDG', 'CycleGAN']:
        if opt.model == 'CycleGAN':
            prefix = 'GB' if opt.BtoA else 'GA'
            seg_map = {f'{prefix}_{i+1}': None for i in range(opt.modalities_no)}
        else:
            seg_map = {'G_' + str(i): 'GS_' + str(i) for i in range(1, opt.modalities_no + 1)}

        gens = {k: forward(ts, nets[k]) for k in seg_map}
        if use_dask:
            gens = compute(gens)[0]

        res = {k: tensor_to_pil(v) for k, v in gens.items()}

        if opt.seg_gen:
            first_gen = next(iter(seg_map))

            def seg_input(gen_name):
                # input tile + first generator output + this generator's output
                return torch.cat([ts.to(cpu), gens[first_gen].to(cpu), gens[gen_name].to(cpu)], 1)

            segs = {v: forward(seg_input(k), nets[v]).to(cpu) for k, v in seg_map.items()}
            if use_dask:
                segs = compute(segs)[0]
            res.update({k: tensor_to_pil(v) for k, v in segs.items()})

        return res

    raise Exception(f'run_dask() not fully implemented for {opt.model}')
352
-
353
-
354
def is_empty(tile):
    """Return True when a tile (or every tile of a list) is considered empty.

    A tile counts as empty when the maximum of image_variance_rgb(tile) is
    below a fixed threshold of 15. For a list of tiles (paired inputs) the
    result is True only if ALL tiles are empty, i.e. no prediction is needed.
    """
    thresh = 15

    def below_threshold(single_tile):
        return bool(np.max(image_variance_rgb(single_tile)) < thresh)

    if not isinstance(tile, list):
        return below_threshold(tile)
    return all([below_threshold(t) for t in tile])
360
-
361
-
362
def run_wrapper(tile, run_fn, model_path, eager_mode=False, opt=None, seg_only=False):
    """
    Run `run_fn` on a tile unless the tile is empty, in which case placeholder
    images are returned without invoking any network.

    tile: tile image (or list of tile images) handed to is_empty() and run_fn.
    run_fn: run_dask or run_torchserve (called with the same positional syntax).
    model_path, eager_mode: passed through to run_fn.
    opt: options object (must not be None; opt.model selects the behavior).
    seg_only: only forwarded / honored for the DeepLIIF model.

    Returns a dict {net name: image}. Placeholders are 512x512 RGB images.
    Raises Exception for unsupported model types.
    """
    def blank(color=0):
        return Image.new(mode='RGB', size=(512, 512), color=color)

    if opt.model == 'DeepLIIF':
        if not is_empty(tile):
            return run_fn(tile, model_path, eager_mode, opt, seg_only)
        if seg_only:
            return {
                'G4': blank((10, 10, 10)),
                'G5': blank((0, 0, 0)),
            }
        return {
            'G1': blank((201, 211, 208)),
            'G2': blank((10, 10, 10)),
            'G3': blank((0, 0, 0)),
            'G4': blank((10, 10, 10)),
            'G5': blank((0, 0, 0)),
            'G51': blank((0, 0, 0)),
            'G52': blank((0, 0, 0)),
            'G53': blank((0, 0, 0)),
            'G54': blank((0, 0, 0)),
            'G55': blank((0, 0, 0)),
        }

    if opt.model in ['DeepLIIFExt', 'SDG']:
        if not is_empty(tile):
            return run_fn(tile, model_path, eager_mode, opt)
        res = {'G_' + str(i): blank() for i in range(1, opt.modalities_no + 1)}
        res.update({'GS_' + str(i): blank() for i in range(1, opt.modalities_no + 1)})
        return res

    if opt.model in ['CycleGAN']:
        if not is_empty(tile):
            return run_fn(tile, model_path, eager_mode, opt)
        # The BtoA names used to be built from a plain string lacking the `f`
        # prefix, producing the single literal key 'GB_{i+1}'.
        prefix = 'GB' if opt.BtoA else 'GA'
        return {f'{prefix}_{i+1}': blank() for i in range(opt.modalities_no)}

    raise Exception(f'run_wrapper() not implemented for model {opt.model}')
401
-
402
-
403
def inference(img, tile_size, overlap_size, model_path, use_torchserve=False,
              eager_mode=False, color_dapi=False, color_marker=False, opt=None,
              return_seg_intermediate=False, seg_only=False):
    """
    Tile an image, run the model on every tile and stitch the outputs.

    img: input image; for SDG it holds opt.input_no modality images side by side.
    tile_size, overlap_size: passed to InferenceTiler.
    model_path: model directory (also used to read the options when opt is falsy).
    use_torchserve: use run_torchserve instead of run_dask.
    eager_mode: passed through to the run function.
    color_dapi, color_marker: DeepLIIF only - recolor the DAPI / Marker image.
    opt: options object; loaded with get_opt(model_path) when not given.
    return_seg_intermediate: DeepLIIF only - also return per-modality segmentations.
    seg_only: DeepLIIF only - return just 'Seg' (and 'Marker' if available).

    Returns a dict of stitched images. Keys are renamed for DeepLIIF,
    DeepLIIFExt and SDG; for any other model the net names are kept.
    """
    if not opt:
        opt = get_opt(model_path)

    run_fn = run_torchserve if use_torchserve else run_dask

    if opt.model == 'SDG':
        # SDG could have multiple input images/modalities, hence the input could be a rectangle.
        # Split it into one image per modality; tiles are then created for each set of inputs.
        part_w, part_h = int(img.width / opt.input_no), img.height
        orig = [img.crop((part_w * i, 0, part_w * (i + 1), part_h)) for i in range(opt.input_no)]
    else:
        # Otherwise expect a single input image, which is used directly.
        orig = img

    tiler = InferenceTiler(orig, tile_size, overlap_size)
    for tile in tiler:
        tile_result = run_wrapper(tile, run_fn, model_path, eager_mode, opt, seg_only)
        tiler.stitch(tile_result)
    results = tiler.results()

    if opt.model == 'DeepLIIF':
        if seg_only:
            # None of the optional extras below apply in seg_only mode.
            images = {'Seg': results['G5']}
            if 'G4' in results:
                images['Marker'] = results['G4']
            return images

        images = {
            label: results[net]
            for label, net in (('Hema', 'G1'), ('DAPI', 'G2'), ('Lap2', 'G3'),
                               ('Marker', 'G4'), ('Seg', 'G5'))
        }

        if return_seg_intermediate:
            for label, net in (('IHC_s', 'G51'), ('Hema_s', 'G52'), ('DAPI_s', 'G53'),
                               ('Lap2_s', 'G54'), ('Marker_s', 'G55')):
                images[label] = results[net]

        # One row of the 3x4 conversion matrix given to Image.convert.
        luma_row = (299/1000, 587/1000, 114/1000, 0)
        zero_row = (0, 0, 0, 0)
        if color_dapi:
            images['DAPI'] = images['DAPI'].convert('RGB', zero_row + luma_row + luma_row)
        if color_marker:
            images['Marker'] = images['Marker'].convert('RGB', luma_row + luma_row + zero_row)
        return images

    if opt.model == 'DeepLIIFExt':
        images = {f'mod{i}': results[f'G_{i}'] for i in range(1, opt.modalities_no + 1)}
        if opt.seg_gen:
            for i in range(1, opt.modalities_no + 1):
                images[f'Seg{i}'] = results[f'GS_{i}']
        return images

    if opt.model == 'SDG':
        return {f'mod{i}': results[f'G_{i}'] for i in range(1, opt.modalities_no + 1)}

    # Any other model: return result images with default key names (i.e., net names).
    return results
472
-
473
-
474
def postprocess(orig, images, tile_size, model, seg_thresh=150, size_thresh='default', marker_thresh=None, size_thresh_upper=None):
    """
    Compute overlaid / refined segmentation images and scoring from inference output.

    orig: original input image, passed to compute_final_results().
    images: dict as returned by inference().
    tile_size: used to pick the resolution label ('40x', '20x' or '10x').
    model: model type name; 'DeepLIIF', 'DeepLIIFExt' and 'SDG' are supported.
    seg_thresh, size_thresh, marker_thresh, size_thresh_upper: passed through
        to compute_final_results().

    Returns (processed_images, scoring). For DeepLIIF the images are keyed
    'SegOverlaid' / 'SegRefined'; for DeepLIIFExt / SDG every input key
    containing 'Seg' yields '<key>_Overlaid' / '<key>_Refined' and a
    scoring entry under '<key>'.
    Raises Exception for any other model.
    """
    if model == 'DeepLIIF':
        if tile_size > 384:
            resolution = '40x'
        elif tile_size > 192:
            resolution = '20x'
        else:
            resolution = '10x'
        overlay, refined, scoring = compute_final_results(
            orig, images['Seg'], images.get('Marker'), resolution,
            size_thresh, marker_thresh, size_thresh_upper, seg_thresh)
        processed_images = {
            'SegOverlaid': Image.fromarray(overlay),
            'SegRefined': Image.fromarray(refined),
        }
        return processed_images, scoring

    if model in ['DeepLIIFExt', 'SDG']:
        if tile_size > 768:
            resolution = '40x'
        elif tile_size > 384:
            resolution = '20x'
        else:
            resolution = '10x'
        processed_images = {}
        scoring = {}
        seg_names = [key for key in images.keys() if 'Seg' in key]
        for seg_name in seg_names:
            overlay, refined, score = compute_final_results(
                orig, images[seg_name], None, resolution,
                size_thresh, marker_thresh, size_thresh_upper, seg_thresh)
            processed_images[seg_name + '_Overlaid'] = Image.fromarray(overlay)
            processed_images[seg_name + '_Refined'] = Image.fromarray(refined)
            scoring[seg_name] = score
        return processed_images, scoring

    raise Exception(f'postprocess() not implemented for model {model}')
503
-
504
-
505
def infer_modalities(img, tile_size, model_dir, eager_mode=False,
                     color_dapi=False, color_marker=False, opt=None,
                     return_seg_intermediate=False, seg_only=False):
    """
    This function is used to infer modalities for the given image using a trained model.

    :param img: The input image.
    :param tile_size: The tile size; the tile overlap is tile_size // 16.
    :param model_dir: The directory containing serialized model files.
    :param eager_mode: Passed through to inference().
    :param color_dapi: Passed through to inference().
    :param color_marker: Passed through to inference().
    :param opt: Options object; read with get_opt(model_dir) when None
                (data parallelism is then disabled on it).
    :param return_seg_intermediate: Passed through to inference().
    :param seg_only: Passed through to inference(); after postprocessing only
                     images whose name contains 'Seg' are kept.
    :return: (images, scoring). scoring is None when postprocessing is
             skipped, i.e. when opt has a falsy seg_gen attribute.
    """
    if opt is None:
        opt = get_opt(model_dir)
        opt.use_dp = False

    images = inference(
        img,
        tile_size=tile_size,
        overlap_size=tile_size // 16,
        model_path=model_dir,
        eager_mode=eager_mode,
        color_dapi=color_dapi,
        color_marker=color_marker,
        opt=opt,
        return_seg_intermediate=return_seg_intermediate,
        seg_only=seg_only
    )

    # A missing seg_gen attribute accounts for old settings of deepliif and
    # means "postprocess"; for deepliifext models seg_gen decides.
    if not getattr(opt, 'seg_gen', True):
        return images, None

    post_images, scoring = postprocess(img, images, tile_size, opt.model)
    images = {**images, **post_images}
    if seg_only:
        images = {name: image for name, image in images.items() if 'Seg' in name}
    return images, scoring
548
-
549
-
550
def infer_results_for_wsi(input_dir, filename, output_dir, model_dir, tile_size, region_size=20000, color_dapi=False, color_marker=False, seg_intermediate=False, seg_only=False):
    """
    This function infers modalities and segmentation mask for the given WSI image. It
    processes the slide region by region, writes every resulting modality as
    '<basename>_<name>.ome.tiff' and, when scoring is available, the summed
    scoring as '<basename>.json' into '<output_dir>/<basename>'.

    :param input_dir: The directory containing the WSI.
    :param filename: The WSI name.
    :param output_dir: The directory for saving the inferred modalities.
    :param model_dir: The directory containing the serialized model files.
    :param tile_size: The tile size.
    :param region_size: The size of each individual region to be processed at once.
    :param color_dapi: Passed through to infer_modalities().
    :param color_marker: Passed through to infer_modalities().
    :param seg_intermediate: Passed to infer_modalities() as return_seg_intermediate.
    :param seg_only: Passed through to infer_modalities().
    :return: None. The javabridge VM is killed at the end.
    """
    basename, _ = os.path.splitext(filename)
    results_dir = os.path.join(output_dir, basename)
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(results_dir, exist_ok=True)

    slide_path = os.path.join(input_dir, filename)
    size_x, size_y, size_z, size_c, size_t, pixel_type = get_information(slide_path)
    rescale = (pixel_type != 'uint8')
    print(filename, size_x, size_y, size_z, size_c, size_t, pixel_type, flush=True)

    results = {}
    scoring = None

    # javabridge already set up from previous call to get_information()
    with bioformats.ImageReader(slide_path) as reader:
        for start_x in range(0, size_x, region_size):
            for start_y in range(0, size_y, region_size):
                print(start_x, start_y, flush=True)
                width = min(region_size, size_x - start_x)
                height = min(region_size, size_y - start_y)
                region = reader.read(XYWH=(start_x, start_y, width, height), rescale=rescale)
                # Rescaled reads are multiplied back up to 8-bit before conversion.
                region_img = Image.fromarray((region * 255).astype(np.uint8)) if rescale else Image.fromarray(region)

                region_modalities, region_scoring = infer_modalities(
                    region_img, tile_size, model_dir,
                    color_dapi=color_dapi, color_marker=color_marker,
                    return_seg_intermediate=seg_intermediate, seg_only=seg_only)

                if region_scoring is not None:
                    if scoring is None:
                        scoring = {'num_pos': 0, 'num_neg': 0}
                    scoring['num_pos'] += region_scoring['num_pos']
                    scoring['num_neg'] += region_scoring['num_neg']

                for name, modality_img in region_modalities.items():
                    if name not in results:
                        results[name] = np.zeros((size_y, size_x, 3), dtype=np.uint8)
                    results[name][start_y: start_y + height,
                                  start_x: start_x + width] = np.array(modality_img)

    for name, full_img in results.items():
        write_big_tiff_file(os.path.join(results_dir, f'{basename}_{name}.ome.tiff'), full_img, tile_size)

    if scoring is not None:
        scoring['num_total'] = scoring['num_pos'] + scoring['num_neg']
        # Guarding on num_pos also avoids dividing by a zero total.
        scoring['percent_pos'] = round(scoring['num_pos'] / scoring['num_total'] * 100, 1) if scoring['num_pos'] > 0 else 0
        with open(os.path.join(results_dir, f'{basename}.json'), 'w') as f:
            json.dump(scoring, f, indent=2)

    javabridge.kill_vm()
617
-
618
-
619
def _search_magnification(metadata, fields=('AppMag', 'NominalMagnification')):
    """Return the first run of digits/dots following the first occurrence of a field name, or None."""
    import re

    number = re.compile(r'[\d.]+')
    for field in fields:
        idx = metadata.find(field)
        if idx < 0:
            continue
        match = number.search(metadata, idx)
        if match:
            # The full run is returned even when it ends the string (the
            # previous hand-written scan dropped the last character then).
            return match.group()
    return None


def get_wsi_resolution(filename):
    """
    Try to get the resolution (magnification) of the slide and
    the corresponding tile size to use by default for DeepLIIF.
    If it cannot be found, return (None, None) instead.

    Note: This will start the javabridge VM, but not kill it.
    It must be killed elsewhere.

    Parameters
    ----------
    filename : str
        Full path to the file.

    Returns
    -------
    str :
        Magnification (objective power) from image metadata.
    int :
        Corresponding tile size for DeepLIIF.
    """

    # make sure javabridge is already set up with a call to get_information()
    get_information(filename)

    mag = None
    metadata = bioformats.get_omexml_metadata(filename)
    try:
        omexml = bioformats.OMEXML(metadata)
        mag = omexml.instrument().Objective.NominalMagnification
    except Exception:
        # Structured lookup failed: fall back to a best-effort text search.
        try:
            mag = _search_magnification(metadata)
        except Exception:
            # Deliberately best-effort: an unusable metadata value simply
            # means the magnification is unknown.
            mag = None

    if mag is None:
        return None, None

    try:
        # 512 pixel tiles at 40x, scaled linearly with the magnification.
        tile_size = round((float(mag) / 40) * 512)
    except Exception:
        return None, None
    return mag, tile_size
676
-
677
-
678
def infer_cells_for_wsi(filename, model_dir, tile_size, region_size=20000, version=3, print_log=False):
    """
    Perform inference on a slide and collect the resulting individual cell data.

    The slide is split into evenly sized regions no larger than region_size;
    cell coordinates of every region are shifted to slide coordinates.
    The javabridge VM is killed before returning, and the accumulated
    timing counters are printed.

    Parameters
    ----------
    filename : str
        Full path to the file.
    model_dir : str
        Full path to the directory with the DeepLIIF model files.
    tile_size : int
        Size of tiles to extract and perform inference on.
    region_size : int
        Maximum size to split the slide for processing.
    version : int
        Version of cell data to return (3 or 4).
    print_log : bool
        Whether or not to print updates while processing.

    Returns
    -------
    dict :
        Individual cell data and associated values.
    """
    global time_g, time_gs

    def print_info(*args):
        if print_log:
            print(*args, flush=True)

    if tile_size > 384:
        resolution = '40x'
    elif tile_size > 192:
        resolution = '20x'
    else:
        resolution = '10x'

    size_x, size_y, size_z, size_c, size_t, pixel_type = get_information(filename)
    rescale = (pixel_type != 'uint8')
    print_info('Info:', size_x, size_y, size_z, size_c, size_t, pixel_type)

    # Split each axis into the fewest regions that respect region_size,
    # then spread the slide evenly over them.
    num_regions_x = math.ceil(size_x / region_size)
    num_regions_y = math.ceil(size_y / region_size)
    stride_x = math.ceil(size_x / num_regions_x)
    stride_y = math.ceil(size_y / num_regions_y)
    print_info('Strides:', stride_x, stride_y)

    data = None
    marker_thresh_sum, marker_thresh_count = 0, 0
    size_thresh_sum, size_thresh_count = 0, 0

    # javabridge already set up from previous call to get_information()
    with bioformats.ImageReader(filename) as reader:
        for start_y in range(0, size_y, stride_y):
            for start_x in range(0, size_x, stride_x):
                region_XYWH = (start_x, start_y, min(stride_x, size_x-start_x), min(stride_y, size_y-start_y))
                print_info('Region:', region_XYWH)

                region = reader.read(XYWH=region_XYWH, rescale=rescale)
                print_info(region.shape, region.dtype)
                img = Image.fromarray((region * 255).astype(np.uint8)) if rescale else Image.fromarray(region)
                print_info(img.size, img.mode)

                images = inference(
                    img,
                    tile_size=tile_size,
                    overlap_size=tile_size//16,
                    model_path=model_dir,
                    eager_mode=False,
                    color_dapi=False,
                    color_marker=False,
                    opt=None,
                    return_seg_intermediate=False,
                    seg_only=True,
                )
                region_data = compute_cell_results(images['Seg'], images.get('Marker'), resolution, version=version)

                if start_x != 0 or start_y != 0:
                    # Translate region-local coordinates to slide coordinates.
                    def to_slide(point):
                        return (point[0] + start_x, point[1] + start_y)

                    cells = region_data['cells']
                    for idx, stored in enumerate(cells):
                        cell = decode_cell_data_v4(stored) if version == 4 else stored
                        for corner in range(2):
                            cell['bbox'][corner] = to_slide(cell['bbox'][corner])
                        cell['centroid'] = to_slide(cell['centroid'])
                        for k in range(len(cell['boundary'])):
                            cell['boundary'][k] = to_slide(cell['boundary'][k])
                        cells[idx] = encode_cell_data_v4(cell) if version == 4 else cell

                if data is None:
                    data = region_data
                else:
                    data['cells'] += region_data['cells']

                settings = region_data['settings']
                # Regions without a usable default threshold do not enter the average.
                if settings['default_marker_thresh'] is not None and settings['default_marker_thresh'] != 0:
                    marker_thresh_sum += settings['default_marker_thresh']
                    marker_thresh_count += 1
                if settings['default_size_thresh'] != 0:
                    size_thresh_sum += settings['default_size_thresh']
                    size_thresh_count += 1

    javabridge.kill_vm()

    # Average the per-region defaults; a zero count falls back to dividing by 1.
    data['settings']['default_marker_thresh'] = round(marker_thresh_sum / max(marker_thresh_count, 1))
    data['settings']['default_size_thresh'] = round(size_thresh_sum / max(size_thresh_count, 1))

    print('Time for g:', round(time_g, 3), flush=True)
    print('Time for gs:', round(time_gs, 3), flush=True)
    print('Time to stack:', round(time_stack, 3), flush=True)

    return data