deepliif 1.2.0__py3-none-any.whl → 1.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,764 +0,0 @@
- """This package contains modules related to objective functions, optimizations, and network architectures.
-
- To add a custom model class called 'dummy', add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel.
- You need to implement the following five functions:
-     -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
-     -- <set_input>: unpack data from the dataset and apply preprocessing.
-     -- <forward>: produce intermediate results.
-     -- <optimize_parameters>: calculate losses, gradients, and update network weights.
-     -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
-
- In the function <__init__>, you need to define four lists:
-     -- self.loss_names (str list): specify the training losses that you want to plot and save.
-     -- self.model_names (str list): define the networks used in training.
-     -- self.visual_names (str list): specify the images that you want to display and save.
-     -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See CycleGAN_model.py for example usage.
-
- Now you can use the model class by specifying the flag '--model dummy'.
- See our template model class 'template_model.py' for more details.
- """
- import base64
- import os
- import itertools
- import importlib
- from functools import lru_cache
- from io import BytesIO
- import json
- import math
- import time
-
- # Cumulative timers used to profile the inference pipeline
- # (reported at the end of infer_cells_for_wsi).
- time_init = 0
- time_empty = 0
- time_transform = 0
- time_compute = 0
-
- import requests
- import torch
- from PIL import Image
- Image.MAX_IMAGE_PIXELS = None  # allow very large (whole-slide) images
-
- import numpy as np
- from dask import delayed, compute
- import openslide
-
- # The wildcard import supplies helpers used below (e.g. chunker, InferenceTiler,
- # disable_batchnorm_tracking_stats, image_variance_rgb, get_information,
- # write_big_tiff_file, and the bioformats/javabridge bindings).
- from deepliif.util import *
- from deepliif.util.util import tensor_to_pil
- from deepliif.data import transform
- from deepliif.postprocessing import compute_final_results, compute_cell_results
- from deepliif.postprocessing import encode_cell_data_v4, decode_cell_data_v4
- from deepliif.options import Options, print_options
-
- from .base_model import BaseModel
-
- # Imported so the model classes register on package init; not used directly in this module.
- from .DeepLIIF_model import DeepLIIFModel
- from .DeepLIIFExt_model import DeepLIIFExtModel
-
-
- @lru_cache
- def get_opt(model_dir, mode='test'):
-     """
-     mode: 'test' or 'train'. Currently only functions used for inference call get_opt,
-     so it defaults to 'test'.
-     """
-     if mode == 'train':
-         opt = Options(path_file=os.path.join(model_dir, 'train_opt.txt'), mode=mode)
-     elif mode == 'test':
-         try:
-             opt = Options(path_file=os.path.join(model_dir, 'test_opt.txt'), mode=mode)
-         except Exception:  # fall back to train options if test_opt.txt is missing or unreadable
-             opt = Options(path_file=os.path.join(model_dir, 'train_opt.txt'), mode=mode)
-     opt.use_dp = False
-     opt.gpu_ids = list(range(torch.cuda.device_count()))
-     return opt
-
-
- def find_model_using_name(model_name):
-     """Import the module "deepliif/models/[model_name]_model.py".
-
-     In that file, the class called [ModelName]Model will be
-     instantiated. It has to be a subclass of BaseModel,
-     and the name match is case-insensitive.
-     """
-     model_filename = "deepliif.models." + model_name + "_model"
-     modellib = importlib.import_module(model_filename)
-     model = None
-     target_model_name = model_name.replace('_', '') + 'model'
-     for name, cls in modellib.__dict__.items():
-         if name.lower() == target_model_name.lower() \
-            and issubclass(cls, BaseModel):
-             model = cls
-
-     if model is None:
-         print("In %s.py, there should be a subclass of BaseModel with a class name that matches %s in lowercase." % (
-             model_filename, target_model_name))
-         exit(1)  # exit with a non-zero status to signal the error
-
-     return model
-
-
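For instance, a usage sketch of the lookup above:

ModelClass = find_model_using_name('DeepLIIF')
# imports deepliif.models.DeepLIIF_model, then matches the normalized name
# 'deepliifmodel' against DeepLIIFModel case-insensitively and returns that class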
- def get_option_setter(model_name):
-     """Return the static method <modify_commandline_options> of the model class."""
-     model_class = find_model_using_name(model_name)
-     return model_class.modify_commandline_options
-
-
- def create_model(opt):
-     """Create a model given the options.
-
-     This function wraps the model class found by <find_model_using_name>.
-     It is the main interface between this package and 'train.py'/'test.py'.
-
-     Example:
-         >>> from deepliif.models import create_model
-         >>> model = create_model(opt)
-     """
-     model = find_model_using_name(opt.model)
-     instance = model(opt)
-     print("model [%s] was created" % type(instance).__name__)
-     return instance
-
-
- def load_torchscript_model(model_pt_path, device):
-     net = torch.jit.load(model_pt_path, map_location=device)
-     net = disable_batchnorm_tracking_stats(net)
-     net.eval()
-     return net
-
-
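The `.pt` files this loader expects are TorchScript archives; a minimal sketch of how one could be produced (a stand-in net and a hypothetical file name, not the packaged export procedure):

import torch

net = torch.nn.Conv2d(3, 3, kernel_size=1)    # stand-in for a trained generator
example = torch.randn(1, 3, 512, 512)         # example input for tracing
traced = torch.jit.trace(net.eval(), example)
torch.jit.save(traced, 'G1.pt')               # later consumed by load_torchscript_model('G1.pt', device)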
- def load_eager_models(opt, devices=None):
-     # create a model given the model name and other options
-     model = create_model(opt)
-     # regular setup: load and print networks; create schedulers
-     model.setup(opt)
-
-     nets = {}
-     if devices:
-         model_names = list(devices.keys())
-     else:
-         model_names = model.model_names
-
-     for name in model_names:
-         if isinstance(name, str):
-             if '_' in name:
-                 # names like 'G_1' index into a list attribute such as netG
-                 net = getattr(model, 'net' + name.split('_')[0])[int(name.split('_')[-1]) - 1]
-             else:
-                 net = getattr(model, 'net' + name)
-
-             if opt.phase != 'train':
-                 net.eval()
-                 net = disable_batchnorm_tracking_stats(net)
-
-             # SDG models are still wrapped in DataParallel when loaded; unwrap to the bare module
-             if isinstance(net, torch.nn.DataParallel):
-                 net = net.module
-
-             nets[name] = net
-             if devices:
-                 nets[name].to(devices[name])
-
-     return nets
-
-
- @lru_cache
- def init_nets(model_dir, eager_mode=False, opt=None, phase='test'):
-     """
-     Init DeepLIIF networks so that every net in
-     the same group is deployed on the same GPU.
-
-     opt: overrides the options read from model_dir; when None, the options are
-     loaded via get_opt(model_dir, mode=phase), typically at inference time.
-     """
-     if opt is None:
-         opt = get_opt(model_dir, mode=phase)
-         opt.use_dp = False
-
-     if opt.model == 'DeepLIIF':
-         net_groups = [
-             ('G1', 'G52'),
-             ('G2', 'G53'),
-             ('G3', 'G54'),
-             ('G4', 'G55'),
-             ('G51',)
-         ]
-     elif opt.model in ['DeepLIIFExt', 'SDG']:
-         if opt.seg_gen:
-             net_groups = [(f'G_{i+1}', f'GS_{i+1}') for i in range(opt.modalities_no)]
-         else:
-             net_groups = [(f'G_{i+1}',) for i in range(opt.modalities_no)]
-     elif opt.model == 'CycleGAN':
-         if opt.BtoA:
-             net_groups = [(f'GB_{i+1}',) for i in range(opt.modalities_no)]
-         else:
-             net_groups = [(f'GA_{i+1}',) for i in range(opt.modalities_no)]
-     else:
-         raise Exception(f'init_nets() not implemented for model {opt.model}')
-
-     number_of_gpus_all = torch.cuda.device_count()
-     number_of_gpus = min(len(opt.gpu_ids), number_of_gpus_all)
-
-     if number_of_gpus > 0:
-         mapping_gpu_ids = {i: idx for i, idx in enumerate(opt.gpu_ids)}
-         chunks = [itertools.chain.from_iterable(c) for c in chunker(net_groups, number_of_gpus)]
-         devices = {n: torch.device(f'cuda:{mapping_gpu_ids[i]}') for i, g in enumerate(chunks) for n in g}
-     else:
-         devices = {n: torch.device('cpu') for n in itertools.chain.from_iterable(net_groups)}
-
-     if eager_mode:
-         return load_eager_models(opt, devices)
-
-     return {
-         n: load_torchscript_model(os.path.join(model_dir, f'{n}.pt'), device=d)
-         for n, d in devices.items()
-     }
-
-
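To make the device assignment concrete, here is a self-contained sketch with a stand-in for chunker (the real helper lives in deepliif.util and its exact split may differ):

import itertools

def chunker_sketch(seq, n):
    # stand-in: split seq into n contiguous, roughly equal chunks
    k, m = divmod(len(seq), n)
    return [seq[i * k + min(i, m):(i + 1) * k + min(i + 1, m)] for i in range(n)]

net_groups = [('G1', 'G52'), ('G2', 'G53'), ('G3', 'G54'), ('G4', 'G55'), ('G51',)]
gpu_ids = [0, 1]
chunks = [itertools.chain.from_iterable(c) for c in chunker_sketch(net_groups, len(gpu_ids))]
devices = {n: f'cuda:{gpu_ids[i]}' for i, g in enumerate(chunks) for n in g}
# -> G1/G52/G2/G53/G3/G54 on cuda:0 and G4/G55/G51 on cuda:1, so each generator
#    shares a device with its downstream segmentation net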
- def compute_overlap(img_size, tile_size):
-     w, h = img_size
-     if round(w / tile_size) == 1 and round(h / tile_size) == 1:
-         return 0
-
-     return tile_size // 4
-
-
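Worked examples of this rule, assuming tile_size=512:

compute_overlap((512, 512), 512)    # both axes round to 1 tile -> 0 (no overlap)
compute_overlap((700, 700), 512)    # round(700/512) == 1 on both axes -> 0
compute_overlap((2048, 2048), 512)  # round(2048/512) == 4 -> 512 // 4 == 128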
- def run_torchserve(img, model_path=None, eager_mode=False, opt=None, seg_only=False):
-     """
-     model_path, eager_mode, seg_only: not used in this function; they exist so that
-     run_wrapper() can call either this function or run_dask with the same syntax.
-     opt: only opt.scale_size is used, to resize the tile before serialization.
-     """
-     buffer = BytesIO()
-     torch.save(transform(img.resize((opt.scale_size, opt.scale_size))), buffer)
-
-     torchserve_host = os.getenv('TORCHSERVE_HOST', 'http://localhost')
-     res = requests.post(
-         f'{torchserve_host}/wfpredict/deepliif',
-         json={'img': base64.b64encode(buffer.getvalue()).decode('utf-8')}
-     )
-
-     res.raise_for_status()
-
-     def deserialize_tensor(bs):
-         return torch.load(BytesIO(base64.b64decode(bs.encode())), map_location=torch.device('cpu'))
-
-     return {k: tensor_to_pil(deserialize_tensor(v)) for k, v in res.json().items()}
-
-
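The request and response bodies are simply base64-encoded torch.save buffers; a self-contained round trip of that encoding (independent of any running TorchServe instance):

import base64
from io import BytesIO
import torch

t = torch.rand(3, 512, 512)

buf = BytesIO()                      # encode, as done for the request payload
torch.save(t, buf)
payload = base64.b64encode(buf.getvalue()).decode('utf-8')

# decode, as done in deserialize_tensor for the response
t2 = torch.load(BytesIO(base64.b64decode(payload.encode())), map_location='cpu')
assert torch.equal(t, t2)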
- def run_dask(img, model_path, eager_mode=False, opt=None, seg_only=False):
-     model_dir = os.getenv('DEEPLIIF_MODEL_DIR', model_path)
-     tstart = time.time()
-     nets = init_nets(model_dir, eager_mode, opt)
-     tend = time.time()
-     global time_init
-     time_init += (tend - tstart)
-     # some training settings (e.g., spectral norm) are not compatible with dask at inference time
-     use_dask = opt.norm != 'spectral'
-
-     tstart = time.time()
-     if opt.input_no > 1 or opt.model == 'SDG':
-         l_ts = [transform(img_i.resize((opt.scale_size, opt.scale_size))) for img_i in img]
-         ts = torch.cat(l_ts, dim=1)
-     else:
-         ts = transform(img.resize((opt.scale_size, opt.scale_size)))
-     tend = time.time()
-     global time_transform
-     time_transform += (tend - tstart)
-
-     if use_dask:
-         @delayed
-         def forward(input, model):
-             with torch.no_grad():
-                 return model(input.to(next(model.parameters()).device))
-     else:
-         def forward(input, model):
-             with torch.no_grad():
-                 return model(input.to(next(model.parameters()).device))
-
-     if opt.model == 'DeepLIIF':
-         tstart = time.time()
-         weights = {
-             'G51': 0.25,  # IHC
-             'G52': 0.25,  # Hema
-             'G53': 0.25,  # DAPI
-             'G54': 0.00,  # Lap2
-             'G55': 0.25,  # Marker
-         }
-
-         seg_map = {'G1': 'G52', 'G2': 'G53', 'G3': 'G54', 'G4': 'G55'}
-         if seg_only:
-             seg_map = {k: v for k, v in seg_map.items() if weights[v] != 0}
-
-         lazy_gens = {k: forward(ts, nets[k]) for k in seg_map}
-         if 'G4' not in seg_map:
-             lazy_gens['G4'] = forward(ts, nets['G4'])
-         gens = compute(lazy_gens)[0]
-
-         lazy_segs = {v: forward(gens[k], nets[v]).to(torch.device('cpu')) for k, v in seg_map.items()}
-         if not seg_only or weights['G51'] != 0:
-             lazy_segs['G51'] = forward(ts, nets['G51']).to(torch.device('cpu'))
-         segs = compute(lazy_segs)[0]
-
-         # weighted average of the per-net segmentation outputs
-         seg = torch.stack([torch.mul(segs[k], weights[k]) for k in segs.keys()]).sum(dim=0)
-
-         if seg_only:
-             res = {'G4': tensor_to_pil(gens['G4'])} if 'G4' in gens else {}
-         else:
-             res = {k: tensor_to_pil(v) for k, v in gens.items()}
-             res.update({k: tensor_to_pil(v) for k, v in segs.items()})
-         res['G5'] = tensor_to_pil(seg)
-         tend = time.time()
-         global time_compute
-         time_compute += (tend - tstart)
-
-         return res
-     elif opt.model in ['DeepLIIFExt', 'SDG', 'CycleGAN']:
-         if opt.model == 'CycleGAN':
-             seg_map = {f'GB_{i+1}': None for i in range(opt.modalities_no)} if opt.BtoA else {f'GA_{i+1}': None for i in range(opt.modalities_no)}
-         else:
-             seg_map = {'G_' + str(i): 'GS_' + str(i) for i in range(1, opt.modalities_no + 1)}
-
-         if use_dask:
-             lazy_gens = {k: forward(ts, nets[k]) for k in seg_map}
-             gens = compute(lazy_gens)[0]
-         else:
-             gens = {k: forward(ts, nets[k]) for k in seg_map}
-
-         res = {k: tensor_to_pil(v) for k, v in gens.items()}
-
-         if opt.seg_gen:
-             # each segmentation net takes the input, the first generated modality, and its own modality
-             seg_input = lambda k: torch.cat([ts.to(torch.device('cpu')), gens[next(iter(seg_map))].to(torch.device('cpu')), gens[k].to(torch.device('cpu'))], 1)
-             if use_dask:
-                 lazy_segs = {v: forward(seg_input(k), nets[v]).to(torch.device('cpu')) for k, v in seg_map.items()}
-                 segs = compute(lazy_segs)[0]
-             else:
-                 segs = {v: forward(seg_input(k), nets[v]).to(torch.device('cpu')) for k, v in seg_map.items()}
-             res.update({k: tensor_to_pil(v) for k, v in segs.items()})
-
-         return res
-     else:
-         raise Exception(f'run_dask() not fully implemented for {opt.model}')
-
-
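The final DeepLIIF segmentation ('G5') is a weighted average of the per-net segmentation tensors; in miniature, with stand-in tensors and the weights above:

import torch

weights = {'G51': 0.25, 'G52': 0.25, 'G53': 0.25, 'G54': 0.00, 'G55': 0.25}
segs = {k: torch.rand(1, 3, 8, 8) for k in weights}   # stand-ins for the seg outputs

seg = torch.stack([torch.mul(segs[k], weights[k]) for k in segs]).sum(dim=0)
expected = 0.25 * (segs['G51'] + segs['G52'] + segs['G53'] + segs['G55'])  # Lap2 (G54) has weight 0
assert torch.allclose(seg, expected)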
- def is_empty(tile):
-     thresh = 15
-     if isinstance(tile, list):
-         # for a set of input tiles, only skip prediction if ALL of them are empty
-         return all(np.max(image_variance_rgb(t)) < thresh for t in tile)
-     else:
-         return np.max(image_variance_rgb(tile)) < thresh
-
-
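A quick illustration of the emptiness test, using a stand-in for image_variance_rgb (the real helper comes from deepliif.util; it is assumed here to return per-channel variances of an RGB array):

import numpy as np

def image_variance_rgb_sketch(img):
    arr = np.asarray(img, dtype=np.float64)
    return arr.reshape(-1, 3).var(axis=0)   # one variance per RGB channel

flat = np.full((512, 512, 3), 200, dtype=np.uint8)                # uniform tile
noisy = np.random.randint(0, 256, (512, 512, 3), dtype=np.uint8)  # tissue-like noise

print(np.max(image_variance_rgb_sketch(flat)) < 15)    # True  -> skip inference
print(np.max(image_variance_rgb_sketch(noisy)) < 15)   # False -> run inference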
- def run_wrapper(tile, run_fn, model_path, eager_mode=False, opt=None, seg_only=False):
-     if opt.model == 'DeepLIIF':
-         tstart = time.time()
-         empty = is_empty(tile)
-         tend = time.time()
-         global time_empty
-         time_empty += (tend - tstart)
-         if empty:
-             # skip inference and return flat placeholder tiles
-             if seg_only:
-                 return {
-                     'G4': Image.new(mode='RGB', size=(512, 512), color=(10, 10, 10)),
-                     'G5': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)),
-                 }
-             else:
-                 return {
-                     'G1': Image.new(mode='RGB', size=(512, 512), color=(201, 211, 208)),
-                     'G2': Image.new(mode='RGB', size=(512, 512), color=(10, 10, 10)),
-                     'G3': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)),
-                     'G4': Image.new(mode='RGB', size=(512, 512), color=(10, 10, 10)),
-                     'G5': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)),
-                     'G51': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)),
-                     'G52': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)),
-                     'G53': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)),
-                     'G54': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)),
-                     'G55': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)),
-                 }
-         else:
-             return run_fn(tile, model_path, eager_mode, opt, seg_only)
-     elif opt.model in ['DeepLIIFExt', 'SDG']:
-         if is_empty(tile):
-             res = {'G_' + str(i): Image.new(mode='RGB', size=(512, 512)) for i in range(1, opt.modalities_no + 1)}
-             res.update({'GS_' + str(i): Image.new(mode='RGB', size=(512, 512)) for i in range(1, opt.modalities_no + 1)})
-             return res
-         else:
-             return run_fn(tile, model_path, eager_mode, opt)
-     elif opt.model in ['CycleGAN']:
-         if is_empty(tile):
-             # note the f-string prefix; without it every name would literally be 'GB_{i+1}'
-             net_names = [f'GB_{i+1}' for i in range(opt.modalities_no)] if opt.BtoA else [f'GA_{i+1}' for i in range(opt.modalities_no)]
-             res = {net_name: Image.new(mode='RGB', size=(512, 512)) for net_name in net_names}
-             return res
-         else:
-             return run_fn(tile, model_path, eager_mode, opt)
-     else:
-         raise Exception(f'run_wrapper() not implemented for model {opt.model}')
-
-
- def inference(img, tile_size, overlap_size, model_path, use_torchserve=False,
-               eager_mode=False, color_dapi=False, color_marker=False, opt=None,
-               return_seg_intermediate=False, seg_only=False):
-     if not opt:
-         opt = get_opt(model_path)
-
-     run_fn = run_torchserve if use_torchserve else run_dask
-
-     if opt.model == 'SDG':
-         # SDG can have multiple input images/modalities, so the input may be a rectangle.
-         # Split the input to get each modality image, then create tiles for each set of input images.
-         w, h = int(img.width / opt.input_no), img.height
-         orig = [img.crop((w * i, 0, w * (i + 1), h)) for i in range(opt.input_no)]
-     else:
-         # Otherwise expect a single input image, which is used directly.
-         orig = img
-
-     tiler = InferenceTiler(orig, tile_size, overlap_size)
-     for tile in tiler:
-         tiler.stitch(run_wrapper(tile, run_fn, model_path, eager_mode, opt, seg_only))
-     results = tiler.results()
-
-     if opt.model == 'DeepLIIF':
-         if seg_only:
-             images = {'Seg': results['G5']}
-             if 'G4' in results:
-                 images.update({'Marker': results['G4']})
-         else:
-             images = {
-                 'Hema': results['G1'],
-                 'DAPI': results['G2'],
-                 'Lap2': results['G3'],
-                 'Marker': results['G4'],
-                 'Seg': results['G5'],
-             }
-
-         if return_seg_intermediate and not seg_only:
-             images.update({'IHC_s': results['G51'],
-                            'Hema_s': results['G52'],
-                            'DAPI_s': results['G53'],
-                            'Lap2_s': results['G54'],
-                            'Marker_s': results['G55']})
-
-         if color_dapi and not seg_only:
-             # map grayscale DAPI luminance into the green/blue channels
-             matrix = (0, 0, 0, 0,
-                       299/1000, 587/1000, 114/1000, 0,
-                       299/1000, 587/1000, 114/1000, 0)
-             images['DAPI'] = images['DAPI'].convert('RGB', matrix)
-         if color_marker and not seg_only:
-             # map grayscale marker luminance into the red/green channels
-             matrix = (299/1000, 587/1000, 114/1000, 0,
-                       299/1000, 587/1000, 114/1000, 0,
-                       0, 0, 0, 0)
-             images['Marker'] = images['Marker'].convert('RGB', matrix)
-         return images
-
-     elif opt.model == 'DeepLIIFExt':
-         images = {f'mod{i}': results[f'G_{i}'] for i in range(1, opt.modalities_no + 1)}
-         if opt.seg_gen:
-             images.update({f'Seg{i}': results[f'GS_{i}'] for i in range(1, opt.modalities_no + 1)})
-         return images
-
-     elif opt.model == 'SDG':
-         images = {f'mod{i}': results[f'G_{i}'] for i in range(1, opt.modalities_no + 1)}
-         return images
-
-     else:
-         # return result images with default key names (i.e., net names)
-         return results
-
-
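Putting the pieces together, a typical single-image run might look like this sketch (hypothetical paths; model_path is assumed to contain the serialized nets plus a train_opt.txt/test_opt.txt):

from PIL import Image

img = Image.open('example_ihc.png').convert('RGB')    # hypothetical input image
images = inference(img, tile_size=512, overlap_size=32, model_path='./model_files')
for name, im in images.items():                       # e.g. Hema, DAPI, Lap2, Marker, Seg
    im.save(f'output_{name}.png')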
- def postprocess(orig, images, tile_size, model, seg_thresh=150, size_thresh='default', marker_thresh=None, size_thresh_upper=None):
-     if model == 'DeepLIIF':
-         resolution = '40x' if tile_size > 384 else ('20x' if tile_size > 192 else '10x')
-         overlay, refined, scoring = compute_final_results(
-             orig, images['Seg'], images.get('Marker'), resolution,
-             size_thresh, marker_thresh, size_thresh_upper, seg_thresh)
-         processed_images = {}
-         processed_images['SegOverlaid'] = Image.fromarray(overlay)
-         processed_images['SegRefined'] = Image.fromarray(refined)
-         return processed_images, scoring
-
-     elif model in ['DeepLIIFExt', 'SDG']:
-         resolution = '40x' if tile_size > 768 else ('20x' if tile_size > 384 else '10x')
-         processed_images = {}
-         scoring = {}
-         for img_name in list(images.keys()):
-             if 'Seg' in img_name:
-                 overlay, refined, score = compute_final_results(
-                     orig, images[img_name], None, resolution,
-                     size_thresh, marker_thresh, size_thresh_upper, seg_thresh)
-
-                 processed_images[img_name + '_Overlaid'] = Image.fromarray(overlay)
-                 processed_images[img_name + '_Refined'] = Image.fromarray(refined)
-                 scoring[img_name] = score
-         return processed_images, scoring
-
-     else:
-         raise Exception(f'postprocess() not implemented for model {model}')
-
-
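For reference, the tile size acts as a proxy for scanner magnification in both branches above:

# DeepLIIF branch:        tile_size 512 -> '40x', 256 -> '20x', 128 -> '10x'
# DeepLIIFExt/SDG branch: tile_size 1024 -> '40x', 512 -> '20x', 256 -> '10x'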
- def infer_modalities(img, tile_size, model_dir, eager_mode=False,
-                      color_dapi=False, color_marker=False, opt=None,
-                      return_seg_intermediate=False):
-     """
-     Infer modalities for the given image using a trained model.
-
-     :param img: The input image.
-     :param tile_size: The tile size.
-     :param model_dir: The directory containing the serialized model files.
-     :param eager_mode: Load eager-mode models instead of serialized TorchScript models.
-     :param color_dapi: Colorize the inferred DAPI image.
-     :param color_marker: Colorize the inferred marker image.
-     :param opt: Pre-loaded options; if None, they are read from model_dir.
-     :param return_seg_intermediate: Also return the intermediate segmentation images.
-     :return: The inferred modalities and the segmentation mask, plus scoring (or None).
-     """
-     if opt is None:
-         opt = get_opt(model_dir)
-         opt.use_dp = False
-
-     # for models with multiple input modalities, compute the per-modality size
-     # (img_size would be needed if compute_overlap() were used for the overlap below)
-     input_no = opt.input_no if hasattr(opt, 'input_no') else 1
-     img_size = (img.size[0] / input_no, img.size[1])  # (width, height)
-
-     images = inference(
-         img,
-         tile_size=tile_size,
-         # overlap_size=compute_overlap(img_size, tile_size),
-         overlap_size=tile_size // 16,
-         model_path=model_dir,
-         eager_mode=eager_mode,
-         color_dapi=color_dapi,
-         color_marker=color_marker,
-         opt=opt,
-         return_seg_intermediate=return_seg_intermediate
-     )
-
-     # The first condition accounts for older DeepLIIF settings without a seg_gen
-     # attribute; the second covers DeepLIIFExt-style models where seg_gen is set.
-     if not hasattr(opt, 'seg_gen') or opt.seg_gen:
-         post_images, scoring = postprocess(img, images, tile_size, opt.model)
-         images = {**images, **post_images}
-         return images, scoring
-     else:
-         return images, None
-
-
- def infer_results_for_wsi(input_dir, filename, output_dir, model_dir, tile_size, region_size=20000):
-     """
-     This function infers the modalities and segmentation mask for the given WSI. It processes
-     the slide region by region, writes each inferred modality as an OME-TIFF, and saves the
-     aggregated scoring as JSON.
-
-     :param input_dir: The directory containing the WSI.
-     :param filename: The WSI name.
-     :param output_dir: The directory for saving the inferred modalities.
-     :param model_dir: The directory containing the serialized model files.
-     :param tile_size: The tile size.
-     :param region_size: The size of each individual region to be processed at once.
-     :return: None; all results are written to disk.
-     """
-     basename, _ = os.path.splitext(filename)
-     results_dir = os.path.join(output_dir, basename)
-     if not os.path.exists(results_dir):
-         os.makedirs(results_dir)
-     size_x, size_y, size_z, size_c, size_t, pixel_type = get_information(os.path.join(input_dir, filename))
-     rescale = (pixel_type != 'uint8')
-     print(filename, size_x, size_y, size_z, size_c, size_t, pixel_type)
-
-     results = {}
-     scoring = None
-
-     # javabridge already set up by the previous call to get_information()
-     with bioformats.ImageReader(os.path.join(input_dir, filename)) as reader:
-         start_x, start_y = 0, 0
-
-         while start_x < size_x:
-             while start_y < size_y:
-                 print(start_x, start_y)
-                 region_XYWH = (start_x, start_y, min(region_size, size_x - start_x), min(region_size, size_y - start_y))
-                 region = reader.read(XYWH=region_XYWH, rescale=rescale)
-                 img = Image.fromarray((region * 255).astype(np.uint8)) if rescale else Image.fromarray(region)
-
-                 region_modalities, region_scoring = infer_modalities(img, tile_size, model_dir)
-                 if region_scoring is not None:
-                     if scoring is None:
-                         scoring = {
-                             'num_pos': region_scoring['num_pos'],
-                             'num_neg': region_scoring['num_neg'],
-                         }
-                     else:
-                         scoring['num_pos'] += region_scoring['num_pos']
-                         scoring['num_neg'] += region_scoring['num_neg']
-
-                 for name, img in region_modalities.items():
-                     if name not in results:
-                         results[name] = np.zeros((size_y, size_x, 3), dtype=np.uint8)
-                     results[name][region_XYWH[1]: region_XYWH[1] + region_XYWH[3],
-                                   region_XYWH[0]: region_XYWH[0] + region_XYWH[2]] = np.array(img)
-                 start_y += region_size
-             start_y = 0
-             start_x += region_size
-
-     for name, img in results.items():
-         write_big_tiff_file(os.path.join(results_dir, f'{basename}_{name}.ome.tiff'), img, tile_size)
-
-     if scoring is not None:
-         scoring['num_total'] = scoring['num_pos'] + scoring['num_neg']
-         scoring['percent_pos'] = round(scoring['num_pos'] / scoring['num_total'] * 100, 1) if scoring['num_pos'] > 0 else 0
-         with open(os.path.join(results_dir, f'{basename}.json'), 'w') as f:
-             json.dump(scoring, f, indent=2)
-
-     javabridge.kill_vm()
-
-
- def get_wsi_resolution(filename):
-     """
-     Use OpenSlide to get the resolution (magnification) of the slide
-     and the corresponding tile size to use by default for DeepLIIF.
-     If it cannot be found, return (None, None) instead.
-
-     Parameters
-     ----------
-     filename : str
-         Full path to the file.
-
-     Returns
-     -------
-     str :
-         Magnification (objective power) as found by OpenSlide.
-     int :
-         Corresponding tile size for DeepLIIF.
-     """
-     try:
-         image = openslide.OpenSlide(filename)
-         mag = image.properties.get(openslide.PROPERTY_NAME_OBJECTIVE_POWER)
-         # float(mag) raises if the property is missing (None); the except below turns that into (None, None)
-         tile_size = round((float(mag) / 40) * 512)
-         return mag, tile_size
-     except Exception:
-         return None, None
-
-
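Usage sketch (hypothetical slide path): a 40x objective maps to tile size 512 and 20x to 256, with a caller-side fallback when nothing is found:

mag, tile_size = get_wsi_resolution('/data/slides/example.svs')  # hypothetical path
if tile_size is None:
    tile_size = 512  # a reasonable default when the objective power is unknown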
- def infer_cells_for_wsi(filename, model_dir, tile_size, region_size=20000, version=3, print_log=False):
-     """
-     Perform inference on a slide and get the resulting individual cell data.
-
-     Parameters
-     ----------
-     filename : str
-         Full path to the file.
-     model_dir : str
-         Full path to the directory with the DeepLIIF model files.
-     tile_size : int
-         Size of tiles to extract and perform inference on.
-     region_size : int
-         Maximum size to split the slide for processing.
-     version : int
-         Version of cell data to return (3 or 4).
-     print_log : bool
-         Whether or not to print updates while processing.
-
-     Returns
-     -------
-     dict :
-         Individual cell data and associated values.
-     """
-
-     def print_info(*args):
-         if print_log:
-             print(*args, flush=True)
-
-     resolution = '40x' if tile_size > 384 else ('20x' if tile_size > 192 else '10x')
-
-     size_x, size_y, size_z, size_c, size_t, pixel_type = get_information(filename)
-     rescale = (pixel_type != 'uint8')
-     print_info('Info:', size_x, size_y, size_z, size_c, size_t, pixel_type)
-
-     num_regions_x = math.ceil(size_x / region_size)
-     num_regions_y = math.ceil(size_y / region_size)
-     stride_x = math.ceil(size_x / num_regions_x)
-     stride_y = math.ceil(size_y / num_regions_y)
-     print_info('Strides:', stride_x, stride_y)
-
-     data = None
-     default_marker_thresh, count_marker_thresh = 0, 0
-     default_size_thresh, count_size_thresh = 0, 0
-
-     time_inference = 0
-     time_postprocess = 0
-
-     # javabridge already set up by the previous call to get_information()
-     with bioformats.ImageReader(filename) as reader:
-         start_x, start_y = 0, 0
-
-         while start_y < size_y:
-             while start_x < size_x:
-                 region_XYWH = (start_x, start_y, min(stride_x, size_x - start_x), min(stride_y, size_y - start_y))
-                 print_info('Region:', region_XYWH)
-
-                 region = reader.read(XYWH=region_XYWH, rescale=rescale)
-                 print_info(region.shape, region.dtype)
-                 img = Image.fromarray((region * 255).astype(np.uint8)) if rescale else Image.fromarray(region)
-                 print_info(img.size, img.mode)
-
-                 tstart = time.time()
-                 images = inference(
-                     img,
-                     tile_size=tile_size,
-                     overlap_size=tile_size // 16,
-                     model_path=model_dir,
-                     eager_mode=False,
-                     color_dapi=False,
-                     color_marker=False,
-                     opt=None,
-                     return_seg_intermediate=False,
-                     seg_only=True,
-                 )
-                 tend = time.time()
-                 time_inference += (tend - tstart)
-                 tstart = time.time()
-                 region_data = compute_cell_results(images['Seg'], images.get('Marker'), resolution, version=version)
-                 tend = time.time()
-                 time_postprocess += (tend - tstart)
-
-                 # shift per-region cell coordinates into whole-slide coordinates
-                 if start_x != 0 or start_y != 0:
-                     for i in range(len(region_data['cells'])):
-                         cell = decode_cell_data_v4(region_data['cells'][i]) if version == 4 else region_data['cells'][i]
-                         for j in range(2):
-                             cell['bbox'][j] = (cell['bbox'][j][0] + start_x, cell['bbox'][j][1] + start_y)
-                         cell['centroid'] = (cell['centroid'][0] + start_x, cell['centroid'][1] + start_y)
-                         for j in range(len(cell['boundary'])):
-                             cell['boundary'][j] = (cell['boundary'][j][0] + start_x, cell['boundary'][j][1] + start_y)
-                         region_data['cells'][i] = encode_cell_data_v4(cell) if version == 4 else cell
-
-                 if data is None:
-                     data = region_data
-                 else:
-                     data['cells'] += region_data['cells']
-
-                 if region_data['settings']['default_marker_thresh'] is not None and region_data['settings']['default_marker_thresh'] != 0:
-                     default_marker_thresh += region_data['settings']['default_marker_thresh']
-                     count_marker_thresh += 1
-                 if region_data['settings']['default_size_thresh'] != 0:
-                     default_size_thresh += region_data['settings']['default_size_thresh']
-                     count_size_thresh += 1
-
-                 start_x += stride_x
-
-             start_x = 0
-             start_y += stride_y
-
-     javabridge.kill_vm()
-
-     # average the per-region default thresholds (guarding against division by zero)
-     if count_marker_thresh == 0:
-         count_marker_thresh = 1
-     if count_size_thresh == 0:
-         count_size_thresh = 1
-     data['settings']['default_marker_thresh'] = round(default_marker_thresh / count_marker_thresh)
-     data['settings']['default_size_thresh'] = round(default_size_thresh / count_size_thresh)
-
-     print('Time init:', round(time_init, 3), flush=True)
-     print('Time empty:', round(time_empty, 3), flush=True)
-     print('Time transform:', round(time_transform, 3), flush=True)
-     print('Time compute:', round(time_compute, 3), flush=True)
-     print('Time inference:', round(time_inference, 3), flush=True)
-     print('Time postprocess:', round(time_postprocess, 3), flush=True)
-
-     return data
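To illustrate the coordinate shift applied above: a cell detected in a region whose origin is (start_x, start_y) = (20000, 0) has all of its geometry translated into whole-slide coordinates. A minimal sketch with a hypothetical cell record:

start_x, start_y = 20000, 0

cell = {
    'bbox': [(100, 250), (140, 290)],                  # two corner points
    'centroid': (120, 270),
    'boundary': [(100, 250), (140, 250), (120, 290)],
}

for j in range(2):
    cell['bbox'][j] = (cell['bbox'][j][0] + start_x, cell['bbox'][j][1] + start_y)
cell['centroid'] = (cell['centroid'][0] + start_x, cell['centroid'][1] + start_y)
cell['boundary'] = [(x + start_x, y + start_y) for x, y in cell['boundary']]

print(cell['centroid'])  # (20120, 270)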