deepliif 1.2.1__py3-none-any.whl → 1.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,792 +0,0 @@
- """This package contains modules related to objective functions, optimizations, and network architectures.
-
- To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel that inherits from BaseModel.
- You need to implement the following five functions:
- -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
- -- <set_input>: unpack data from the dataset and apply preprocessing.
- -- <forward>: produce intermediate results.
- -- <optimize_parameters>: calculate losses and gradients, and update network weights.
- -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
-
- In the function <__init__>, you need to define four lists:
- -- self.loss_names (str list): specify the training losses that you want to plot and save.
- -- self.model_names (str list): specify the networks used in training.
- -- self.visual_names (str list): specify the images that you want to display and save.
- -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See CycleGAN_model.py for example usage.
-
- Now you can use the model class by specifying the flag '--model dummy'.
- See our template model class 'template_model.py' for more details.
- """
- import base64
- import os
- import itertools
- import importlib
- from functools import lru_cache
- from io import BytesIO
- import json
- import math
-
- import requests
- import torch
- from PIL import Image
- Image.MAX_IMAGE_PIXELS = None  # allow arbitrarily large (WSI-scale) images
-
- import numpy as np
- from dask import delayed, compute
-
- from deepliif.util import *
- from deepliif.util.util import tensor_to_pil
- from deepliif.data import transform
- from deepliif.postprocessing import compute_final_results, compute_cell_results
- from deepliif.postprocessing import encode_cell_data_v4, decode_cell_data_v4
- from deepliif.options import Options, print_options
-
- from .base_model import BaseModel
-
- # imported so the model classes are registered on package init; not used directly in this module
- from .DeepLIIF_model import DeepLIIFModel
- from .DeepLIIFExt_model import DeepLIIFExtModel
-
-
- @lru_cache
- def get_opt(model_dir, mode='test'):
-     """
-     mode: 'test' or 'train'. Currently only functions used for inference call
-     get_opt, so it defaults to 'test'.
-     """
-     if mode == 'train':
-         opt = Options(path_file=os.path.join(model_dir, 'train_opt.txt'), mode=mode)
-     elif mode == 'test':
-         try:
-             opt = Options(path_file=os.path.join(model_dir, 'test_opt.txt'), mode=mode)
-         except Exception:
-             # fall back to the training options if no test options were saved
-             opt = Options(path_file=os.path.join(model_dir, 'train_opt.txt'), mode=mode)
-             opt.use_dp = False
-             opt.gpu_ids = list(range(torch.cuda.device_count()))
-     return opt
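A usage sketch (the model path is illustrative). Note that because `get_opt` is wrapped in `@lru_cache`, repeated calls with the same arguments return the *same* `Options` object, so attribute mutations such as the `opt.use_dp = False` in `init_nets` below persist for later callers.

```python
# hypothetical caller; the path is illustrative
opt = get_opt('./model_files/DeepLIIF_Latest_Model')        # reads test_opt.txt or train_opt.txt
opt_again = get_opt('./model_files/DeepLIIF_Latest_Model')  # cached: same object
assert opt is opt_again
```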
-
-
- def find_model_using_name(model_name):
-     """Import the module "models/[model_name]_model.py".
-
-     In that file, the class called [ModelName]Model will be instantiated.
-     It has to be a subclass of BaseModel, and the name matching is
-     case-insensitive.
-     """
-     model_filename = "deepliif.models." + model_name + "_model"
-     modellib = importlib.import_module(model_filename)
-     model = None
-     target_model_name = model_name.replace('_', '') + 'model'
-     for name, cls in modellib.__dict__.items():
-         if name.lower() == target_model_name.lower() \
-                 and issubclass(cls, BaseModel):
-             model = cls
-
-     if model is None:
-         print("In %s.py, there should be a subclass of BaseModel with a class name that matches %s in lowercase." % (
-             model_filename, target_model_name))
-         exit(1)  # exit with a non-zero status to signal the error
-
-     return model
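Worked through for `model_name='DeepLIIF'`: the module `deepliif.models.DeepLIIF_model` is imported, the target name becomes `'DeepLIIFmodel'`, and the case-insensitive scan matches the class `DeepLIIFModel`.

```python
model_cls = find_model_using_name('DeepLIIF')
# target_model_name == 'DeepLIIFmodel'; 'deepliifmodel' matches
# 'DeepLIIFModel'.lower(), so the imported class is returned
assert model_cls.__name__ == 'DeepLIIFModel'
```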
-
-
- def get_option_setter(model_name):
-     """Return the static method <modify_commandline_options> of the model class."""
-     model_class = find_model_using_name(model_name)
-     return model_class.modify_commandline_options
-
-
- def create_model(opt):
-     """Create a model given the options.
-
-     This function wraps the model class selected by opt.model.
-     It is the main interface between this package and 'train.py'/'test.py'.
-
-     Example:
-         >>> from deepliif.models import create_model
-         >>> model = create_model(opt)
-     """
-     model = find_model_using_name(opt.model)
-     instance = model(opt)
-     print("model [%s] was created" % type(instance).__name__)
-     return instance
-
-
- def load_torchscript_model(model_pt_path, device):
-     net = torch.jit.load(model_pt_path, map_location=device)
-     net = disable_batchnorm_tracking_stats(net)
-     net.eval()
-     return net
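`disable_batchnorm_tracking_stats` comes from `deepliif.util` via the star import above. A plausible sketch of what such a helper does for eager modules is shown below; this is an assumption for illustration, not the package's actual implementation.

```python
# sketch only: force BatchNorm layers to use batch statistics at inference time
def disable_batchnorm_tracking_stats_sketch(net):
    for m in net.modules():
        if isinstance(m, torch.nn.modules.batchnorm._BatchNorm):
            m.track_running_stats = False
            m.running_mean = None  # with stats removed, eval() falls back to batch statistics
            m.running_var = None
    return net
```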
-
-
- def load_eager_models(opt, devices=None):
-     # create a model given the model name and other options
-     model = create_model(opt)
-     # regular setup: load and print networks; create schedulers
-     model.setup(opt)
-
-     nets = {}
-     if devices:
-         model_names = list(devices.keys())
-     else:
-         model_names = model.model_names
-
-     for name in model_names:
-         if isinstance(name, str):
-             if '_' in name:
-                 # e.g. 'G_2' -> the second entry of model.netG
-                 net = getattr(model, 'net' + name.split('_')[0])[int(name.split('_')[-1]) - 1]
-             else:
-                 net = getattr(model, 'net' + name)
-
-             if opt.phase != 'train':
-                 net.eval()
-                 net = disable_batchnorm_tracking_stats(net)
-
-             # SDG models are still wrapped in DataParallel when loaded; unwrap them here
-             if isinstance(net, torch.nn.DataParallel):
-                 net = net.module
-
-             nets[name] = net
-             if devices:
-                 nets[name].to(devices[name])
-
-     return nets
-
-
- @lru_cache
- def init_nets(model_dir, eager_mode=False, opt=None, phase='test'):
-     """
-     Init DeepLIIF networks so that every net in the same group is deployed
-     on the same GPU.
-
-     opt: overrides the options loaded from train_opt.txt; typically used in the
-     inference stage (when None, the options are read from model_dir for the
-     given phase)
-     """
-     if opt is None:
-         opt = get_opt(model_dir, mode=phase)
-         opt.use_dp = False
-
-     if opt.model == 'DeepLIIF':
-         net_groups = [
-             ('G1', 'G52'),
-             ('G2', 'G53'),
-             ('G3', 'G54'),
-             ('G4', 'G55'),
-             ('G51',)
-         ]
-     elif opt.model in ['DeepLIIFExt', 'SDG']:
-         if opt.seg_gen:
-             net_groups = [(f'G_{i+1}', f'GS_{i+1}') for i in range(opt.modalities_no)]
-         else:
-             net_groups = [(f'G_{i+1}',) for i in range(opt.modalities_no)]
-     elif opt.model == 'CycleGAN':
-         if opt.BtoA:
-             net_groups = [(f'GB_{i+1}',) for i in range(opt.modalities_no)]
-         else:
-             net_groups = [(f'GA_{i+1}',) for i in range(opt.modalities_no)]
-     else:
-         raise Exception(f'init_nets() not implemented for model {opt.model}')
-
-     number_of_gpus_all = torch.cuda.device_count()
-     number_of_gpus = min(len(opt.gpu_ids), number_of_gpus_all)
-
-     if number_of_gpus > 0:
-         mapping_gpu_ids = {i: idx for i, idx in enumerate(opt.gpu_ids)}
-         chunks = [itertools.chain.from_iterable(c) for c in chunker(net_groups, number_of_gpus)]
-         devices = {n: torch.device(f'cuda:{mapping_gpu_ids[i]}') for i, g in enumerate(chunks) for n in g}
-     else:
-         devices = {n: torch.device('cpu') for n in itertools.chain.from_iterable(net_groups)}
-
-     if eager_mode:
-         return load_eager_models(opt, devices)
-
-     return {
-         n: load_torchscript_model(os.path.join(model_dir, f'{n}.pt'), device=d)
-         for n, d in devices.items()
-     }
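To see how the grouping keeps paired nets together, assume `chunker` (imported from `deepliif.util`; its exact implementation is not shown here) splits a sequence into `n` near-equal chunks. For the five DeepLIIF groups on two GPUs:

```python
import itertools

def chunker_sketch(seq, n):  # assumed behavior of deepliif.util.chunker
    k, m = divmod(len(seq), n)
    return [seq[i * k + min(i, m):(i + 1) * k + min(i + 1, m)] for i in range(n)]

net_groups = [('G1', 'G52'), ('G2', 'G53'), ('G3', 'G54'), ('G4', 'G55'), ('G51',)]
chunks = [list(itertools.chain.from_iterable(c)) for c in chunker_sketch(net_groups, 2)]
print(chunks)
# [['G1', 'G52', 'G2', 'G53', 'G3', 'G54'], ['G4', 'G55', 'G51']]
# -> a generator and its downstream segmenter (e.g. G1 and G52) always share a device
```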
210
-
211
-
212
- def compute_overlap(img_size, tile_size):
213
- w, h = img_size
214
- if round(w / tile_size) == 1 and round(h / tile_size) == 1:
215
- return 0
216
-
217
- return tile_size // 4
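Worked examples of the rule above: overlap is skipped only when the image is roughly a single tile in both dimensions.

```python
compute_overlap((512, 512), 512)    # round(1) == 1 in both dims -> 0
compute_overlap((600, 520), 512)    # still rounds to one tile   -> 0
compute_overlap((2048, 1024), 512)  # multiple tiles -> 512 // 4 == 128
```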
218
-
219
-
220
- def run_torchserve(img, model_path=None, eager_mode=False, opt=None, seg_only=False):
221
- """
222
- eager_mode: not used in this function; put in place to be consistent with run_dask
223
- so that run_wrapper() could call either this function or run_dask with
224
- same syntax
225
- opt: same as eager_mode
226
- seg_only: same as eager_mode
227
- """
228
- buffer = BytesIO()
229
- torch.save(transform(img.resize((opt.scale_size, opt.scale_size))), buffer)
230
-
231
- torchserve_host = os.getenv('TORCHSERVE_HOST', 'http://localhost')
232
- res = requests.post(
233
- f'{torchserve_host}/wfpredict/deepliif',
234
- json={'img': base64.b64encode(buffer.getvalue()).decode('utf-8')}
235
- )
236
-
237
- res.raise_for_status()
238
-
239
- def deserialize_tensor(bs):
240
- return torch.load(BytesIO(base64.b64decode(bs.encode())), map_location=torch.device('cpu'))
241
-
242
- return {k: tensor_to_pil(deserialize_tensor(v)) for k, v in res.json().items()}
-
-
- def run_dask(img, model_path, eager_mode=False, opt=None, seg_only=False):
-     model_dir = os.getenv('DEEPLIIF_MODEL_DIR', model_path)
-     nets = init_nets(model_dir, eager_mode, opt)
-     use_dask = opt.norm != 'spectral'
-
-     if opt.input_no > 1 or opt.model == 'SDG':
-         l_ts = [transform(img_i.resize((opt.scale_size, opt.scale_size))) for img_i in img]
-         ts = torch.cat(l_ts, dim=1)
-     else:
-         ts = transform(img.resize((opt.scale_size, opt.scale_size)))
-
-     if use_dask:
-         @delayed
-         def forward(input, model):
-             with torch.no_grad():
-                 return model(input.to(next(model.parameters()).device))
-     else:  # some training settings (e.g., spectral norm) are not compatible with dask at inference time
-         def forward(input, model):
-             with torch.no_grad():
-                 return model(input.to(next(model.parameters()).device))
-
-     if opt.model == 'DeepLIIF':
-         # alternative weighting (unused):
-         # weights = {'G51': 0.25, 'G52': 0.25, 'G53': 0.25, 'G54': 0.00, 'G55': 0.25}
-         weights = {
-             'G51': 0.5,  # IHC
-             'G52': 0.0,  # Hema
-             'G53': 0.0,  # DAPI
-             'G54': 0.0,  # Lap2
-             'G55': 0.5,  # Marker
-         }
-
-         seg_map = {'G1': 'G52', 'G2': 'G53', 'G3': 'G54', 'G4': 'G55'}
-         if seg_only:
-             seg_map = {k: v for k, v in seg_map.items() if weights[v] != 0}
-
-         lazy_gens = {k: forward(ts, nets[k]) for k in seg_map}
-         if 'G4' not in seg_map:
-             lazy_gens['G4'] = forward(ts, nets['G4'])
-         gens = compute(lazy_gens)[0]
-
-         lazy_segs = {v: forward(gens[k], nets[v]).to(torch.device('cpu')) for k, v in seg_map.items()}
-         if not seg_only or weights['G51'] != 0:
-             lazy_segs['G51'] = forward(ts, nets['G51']).to(torch.device('cpu'))
-         segs = compute(lazy_segs)[0]
-
-         seg = torch.stack([torch.mul(segs[k], weights[k]) for k in segs.keys()]).sum(dim=0)
-
-         if seg_only:
-             res = {'G4': tensor_to_pil(gens['G4'])} if 'G4' in gens else {}
-         else:
-             res = {k: tensor_to_pil(v) for k, v in gens.items()}
-             res.update({k: tensor_to_pil(v) for k, v in segs.items()})
-         res['G5'] = tensor_to_pil(seg)
-
-         return res
-     elif opt.model in ['DeepLIIFExt', 'SDG', 'CycleGAN']:
-         if opt.model == 'CycleGAN':
-             seg_map = {f'GB_{i+1}': None for i in range(opt.modalities_no)} if opt.BtoA \
-                 else {f'GA_{i+1}': None for i in range(opt.modalities_no)}
-         else:
-             seg_map = {'G_' + str(i): 'GS_' + str(i) for i in range(1, opt.modalities_no + 1)}
-
-         if use_dask:
-             lazy_gens = {k: forward(ts, nets[k]) for k in seg_map}
-             gens = compute(lazy_gens)[0]
-         else:
-             gens = {k: forward(ts, nets[k]) for k in seg_map}
-
-         res = {k: tensor_to_pil(v) for k, v in gens.items()}
-
-         if opt.seg_gen:
-             first_gen = next(iter(seg_map))
-             if use_dask:
-                 lazy_segs = {v: forward(torch.cat([ts.to(torch.device('cpu')),
-                                                    gens[first_gen].to(torch.device('cpu')),
-                                                    gens[k].to(torch.device('cpu'))], 1),
-                                         nets[v]).to(torch.device('cpu'))
-                              for k, v in seg_map.items()}
-                 segs = compute(lazy_segs)[0]
-             else:
-                 segs = {v: forward(torch.cat([ts.to(torch.device('cpu')),
-                                               gens[first_gen].to(torch.device('cpu')),
-                                               gens[k].to(torch.device('cpu'))], 1),
-                                    nets[v]).to(torch.device('cpu'))
-                         for k, v in seg_map.items()}
-             res.update({k: tensor_to_pil(v) for k, v in segs.items()})
-
-         return res
-     else:
-         raise Exception(f'run_dask() not fully implemented for {opt.model}')
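The DeepLIIF branch above produces the final `G5` mask as a weighted sum of the per-net segmentation tensors. A self-contained illustration of that ensembling step with dummy tensors:

```python
import torch

segs = {'G51': torch.ones(1, 3, 8, 8), 'G55': torch.full((1, 3, 8, 8), 2.0)}
weights = {'G51': 0.5, 'G55': 0.5}  # the nonzero weights used above
seg = torch.stack([torch.mul(segs[k], weights[k]) for k in segs.keys()]).sum(dim=0)
assert torch.allclose(seg, torch.full((1, 3, 8, 8), 1.5))
```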
-
-
- def is_empty(tile):
-     thresh = 15
-     if isinstance(tile, list):
-         # for a set of paired tiles, mark as empty (no prediction needed) only if ALL tiles are empty
-         return all(np.max(image_variance_rgb(t)) < thresh for t in tile)
-     else:
-         return np.max(image_variance_rgb(tile)) < thresh
-     # alternatives tried (unused): grayscale variance, optionally computed on a
-     # tile downsampled to 128x128 or 64x64 with nearest-neighbor resampling
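`image_variance_rgb` is imported from `deepliif.util`; it is assumed here to return the per-channel variance of the tile. A quick standalone check of the thresholding idea, with the helper re-sketched in numpy:

```python
import numpy as np
from PIL import Image

flat_tile = Image.new('RGB', (512, 512), color=(240, 240, 240))
arr = np.asarray(flat_tile, dtype=np.float32)
per_channel_var = arr.reshape(-1, 3).var(axis=0)  # sketch of the assumed helper
print(np.max(per_channel_var) < 15)               # True -> tile treated as empty
```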
-
-
- count_run_tiles = 0  # module-level counter of tiles actually run through the networks
- def run_wrapper(tile, run_fn, model_path, eager_mode=False, opt=None, seg_only=False):
-     if opt.model == 'DeepLIIF':
-         if is_empty(tile):
-             if seg_only:
-                 return {
-                     'G4': Image.new(mode='RGB', size=(512, 512), color=(10, 10, 10)),
-                     'G5': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)),
-                 }
-             else:
-                 return {
-                     'G1': Image.new(mode='RGB', size=(512, 512), color=(201, 211, 208)),
-                     'G2': Image.new(mode='RGB', size=(512, 512), color=(10, 10, 10)),
-                     'G3': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)),
-                     'G4': Image.new(mode='RGB', size=(512, 512), color=(10, 10, 10)),
-                     'G5': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)),
-                     'G51': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)),
-                     'G52': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)),
-                     'G53': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)),
-                     'G54': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)),
-                     'G55': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)),
-                 }
-         else:
-             global count_run_tiles
-             count_run_tiles += 1
-             return run_fn(tile, model_path, eager_mode, opt, seg_only)
-     elif opt.model in ['DeepLIIFExt', 'SDG']:
-         if is_empty(tile):
-             res = {'G_' + str(i): Image.new(mode='RGB', size=(512, 512)) for i in range(1, opt.modalities_no + 1)}
-             res.update({'GS_' + str(i): Image.new(mode='RGB', size=(512, 512)) for i in range(1, opt.modalities_no + 1)})
-             return res
-         else:
-             return run_fn(tile, model_path, eager_mode, opt)
-     elif opt.model in ['CycleGAN']:
-         if is_empty(tile):
-             net_names = [f'GB_{i+1}' for i in range(opt.modalities_no)] if opt.BtoA else [f'GA_{i+1}' for i in range(opt.modalities_no)]
-             res = {net_name: Image.new(mode='RGB', size=(512, 512)) for net_name in net_names}
-             return res
-         else:
-             return run_fn(tile, model_path, eager_mode, opt)
-     else:
-         raise Exception(f'run_wrapper() not implemented for model {opt.model}')
-
-
- def inference(img, tile_size, overlap_size, model_path, use_torchserve=False,
-               eager_mode=False, color_dapi=False, color_marker=False, opt=None,
-               return_seg_intermediate=False, seg_only=False):
-     if not opt:
-         opt = get_opt(model_path)
-         #print_options(opt)
-
-     run_fn = run_torchserve if use_torchserve else run_dask
-
-     if opt.model == 'SDG':
-         # SDG can have multiple input images/modalities, so the input may be a rectangle.
-         # Split the input into its modality images, then tile each set of input images.
-         w, h = int(img.width / opt.input_no), img.height
-         orig = [img.crop((w * i, 0, w * (i+1), h)) for i in range(opt.input_no)]
-     else:
-         # otherwise expect a single input image, which is used directly
-         orig = img
-
-     tiler = InferenceTiler(orig, tile_size, overlap_size)
-     for tile in tiler:
-         tiler.stitch(run_wrapper(tile, run_fn, model_path, eager_mode, opt, seg_only))
-     results = tiler.results()
-
-     if opt.model == 'DeepLIIF':
-         if seg_only:
-             images = {'Seg': results['G5']}
-             if 'G4' in results:
-                 images.update({'Marker': results['G4']})
-         else:
-             images = {
-                 'Hema': results['G1'],
-                 'DAPI': results['G2'],
-                 'Lap2': results['G3'],
-                 'Marker': results['G4'],
-                 'Seg': results['G5'],
-             }
-
-         if return_seg_intermediate and not seg_only:
-             images.update({'IHC_s': results['G51'],
-                            'Hema_s': results['G52'],
-                            'DAPI_s': results['G53'],
-                            'Lap2_s': results['G54'],
-                            'Marker_s': results['G55']})
-
-         if color_dapi and not seg_only:
-             # R = 0, G = B = luminance: gives the DAPI image a blue/cyan tint
-             matrix = (       0,        0,        0, 0,
-                       299/1000, 587/1000, 114/1000, 0,
-                       299/1000, 587/1000, 114/1000, 0)
-             images['DAPI'] = images['DAPI'].convert('RGB', matrix)
-         if color_marker and not seg_only:
-             # R = G = luminance, B = 0: gives the marker image a yellow tint
-             matrix = (299/1000, 587/1000, 114/1000, 0,
-                       299/1000, 587/1000, 114/1000, 0,
-                              0,        0,        0, 0)
-             images['Marker'] = images['Marker'].convert('RGB', matrix)
-         return images
-
-     elif opt.model == 'DeepLIIFExt':
-         images = {f'mod{i}': results[f'G_{i}'] for i in range(1, opt.modalities_no + 1)}
-         if opt.seg_gen:
-             images.update({f'Seg{i}': results[f'GS_{i}'] for i in range(1, opt.modalities_no + 1)})
-         return images
-
-     elif opt.model == 'SDG':
-         images = {f'mod{i}': results[f'G_{i}'] for i in range(1, opt.modalities_no + 1)}
-         return images
-
-     else:
-         # return result images under their default key names (i.e., the net names)
-         return results
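A hypothetical end-to-end call for a DeepLIIF model (paths and file names are illustrative):

```python
from PIL import Image

img = Image.open('ihc_tile.png').convert('RGB')
images = inference(img, tile_size=512, overlap_size=512 // 16,
                   model_path='./model_files/DeepLIIF_Latest_Model',
                   color_dapi=True, color_marker=True)
images['Seg'].save('seg.png')  # keys follow the DeepLIIF branch above
```

The 12-value matrices route the source luminance (0.299R + 0.587G + 0.114B) into selected output channels, which is what tints DAPI blue/cyan and the marker yellow.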
-
-
- def postprocess(orig, images, tile_size, model, seg_thresh=150, size_thresh='default', marker_thresh=None, size_thresh_upper=None):
-     if model == 'DeepLIIF':
-         resolution = '40x' if tile_size > 384 else ('20x' if tile_size > 192 else '10x')
-         overlay, refined, scoring = compute_final_results(
-             orig, images['Seg'], images.get('Marker'), resolution,
-             size_thresh, marker_thresh, size_thresh_upper, seg_thresh)
-         processed_images = {}
-         processed_images['SegOverlaid'] = Image.fromarray(overlay)
-         processed_images['SegRefined'] = Image.fromarray(refined)
-         return processed_images, scoring
-
-     elif model in ['DeepLIIFExt', 'SDG']:
-         resolution = '40x' if tile_size > 768 else ('20x' if tile_size > 384 else '10x')
-         processed_images = {}
-         scoring = {}
-         for img_name in list(images.keys()):
-             if 'Seg' in img_name:
-                 seg_img = images[img_name]
-                 overlay, refined, score = compute_final_results(
-                     orig, seg_img, None, resolution,
-                     size_thresh, marker_thresh, size_thresh_upper, seg_thresh)
-
-                 processed_images[img_name + '_Overlaid'] = Image.fromarray(overlay)
-                 processed_images[img_name + '_Refined'] = Image.fromarray(refined)
-                 scoring[img_name] = score
-         return processed_images, scoring
-
-     else:
-         raise Exception(f'postprocess() not implemented for model {model}')
-
-
- def infer_modalities(img, tile_size, model_dir, eager_mode=False,
-                      color_dapi=False, color_marker=False, opt=None,
-                      return_seg_intermediate=False, seg_only=False):
-     """
-     Infer the modalities for the given image using a trained model.
-     :param img: The input image.
-     :param tile_size: The tile size.
-     :param model_dir: The directory containing the serialized model files.
-     :return: The inferred modalities and, when a segmentation mask is produced, its scoring.
-     """
-     if opt is None:
-         opt = get_opt(model_dir)
-         opt.use_dp = False
-         #print_options(opt)
-
-     # for models with multiple input modalities, find the per-modality size
-     # (used by the commented-out compute_overlap alternative below)
-     input_no = opt.input_no if hasattr(opt, 'input_no') else 1
-     img_size = (img.size[0] / input_no, img.size[1])  # (width, height)
-
-     images = inference(
-         img,
-         tile_size=tile_size,
-         #overlap_size=compute_overlap(img_size, tile_size),
-         overlap_size=tile_size // 16,
-         model_path=model_dir,
-         eager_mode=eager_mode,
-         color_dapi=color_dapi,
-         color_marker=color_marker,
-         opt=opt,
-         return_seg_intermediate=return_seg_intermediate,
-         seg_only=seg_only
-     )
-
-     # the first condition accounts for old DeepLIIF settings without seg_gen;
-     # the second refers to DeepLIIFExt models
-     if not hasattr(opt, 'seg_gen') or (hasattr(opt, 'seg_gen') and opt.seg_gen):
-         post_images, scoring = postprocess(img, images, tile_size, opt.model)
-         images = {**images, **post_images}
-         if seg_only:
-             delete_keys = [k for k in images.keys() if 'Seg' not in k]
-             for name in delete_keys:
-                 del images[name]
-         return images, scoring
-     else:
-         return images, None
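A hypothetical usage sketch; `scoring` is populated only when a segmentation head is present (otherwise it is None):

```python
images, scoring = infer_modalities(img, tile_size=512,
                                   model_dir='./model_files/DeepLIIF_Latest_Model')
if scoring is not None:
    print(scoring['num_pos'], scoring['num_neg'])  # keys used by the WSI aggregation below
```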
-
-
- def infer_results_for_wsi(input_dir, filename, output_dir, model_dir, tile_size, region_size=20000, color_dapi=False, color_marker=False, seg_intermediate=False, seg_only=False):
-     """
-     This function infers the modalities and segmentation mask for the given WSI.
-     It processes the slide one region at a time and writes each inferred modality
-     to an OME-TIFF, along with a JSON file of the aggregated scoring.
-
-     :param input_dir: The directory containing the WSI.
-     :param filename: The WSI name.
-     :param output_dir: The directory for saving the inferred modalities.
-     :param model_dir: The directory containing the serialized model files.
-     :param tile_size: The tile size.
-     :param region_size: The size of each individual region to be processed at once.
-     :return: None. Results are written to output_dir.
-     """
-     basename, _ = os.path.splitext(filename)
-     results_dir = os.path.join(output_dir, basename)
-     if not os.path.exists(results_dir):
-         os.makedirs(results_dir)
-     size_x, size_y, size_z, size_c, size_t, pixel_type = get_information(os.path.join(input_dir, filename))
-     rescale = (pixel_type != 'uint8')
-     print(filename, size_x, size_y, size_z, size_c, size_t, pixel_type, flush=True)
-
-     results = {}
-     scoring = None
-
-     # javabridge is already set up by the previous call to get_information()
-     with bioformats.ImageReader(os.path.join(input_dir, filename)) as reader:
-         start_x, start_y = 0, 0
-
-         while start_x < size_x:
-             while start_y < size_y:
-                 print(start_x, start_y, flush=True)
-                 region_XYWH = (start_x, start_y, min(region_size, size_x - start_x), min(region_size, size_y - start_y))
-                 region = reader.read(XYWH=region_XYWH, rescale=rescale)
-                 img = Image.fromarray((region * 255).astype(np.uint8)) if rescale else Image.fromarray(region)
-
-                 region_modalities, region_scoring = infer_modalities(img, tile_size, model_dir, color_dapi=color_dapi, color_marker=color_marker, return_seg_intermediate=seg_intermediate, seg_only=seg_only)
-                 if region_scoring is not None:
-                     if scoring is None:
-                         scoring = {
-                             'num_pos': region_scoring['num_pos'],
-                             'num_neg': region_scoring['num_neg'],
-                         }
-                     else:
-                         scoring['num_pos'] += region_scoring['num_pos']
-                         scoring['num_neg'] += region_scoring['num_neg']
-
-                 for name, img in region_modalities.items():
-                     if name not in results:
-                         results[name] = np.zeros((size_y, size_x, 3), dtype=np.uint8)
-                     results[name][region_XYWH[1]: region_XYWH[1] + region_XYWH[3],
-                                   region_XYWH[0]: region_XYWH[0] + region_XYWH[2]] = np.array(img)
-                 start_y += region_size
-             start_y = 0
-             start_x += region_size
-
-     # write_results_to_pickle_file(os.path.join(results_dir, "results.pickle"), results)
-     # read_results_from_pickle_file(os.path.join(results_dir, "results.pickle"))
-
-     for name, img in results.items():
-         write_big_tiff_file(os.path.join(results_dir, f'{basename}_{name}.ome.tiff'), img, tile_size)
-
-     if scoring is not None:
-         scoring['num_total'] = scoring['num_pos'] + scoring['num_neg']
-         scoring['percent_pos'] = round(scoring['num_pos'] / scoring['num_total'] * 100, 1) if scoring['num_pos'] > 0 else 0
-         with open(os.path.join(results_dir, f'{basename}.json'), 'w') as f:
-             json.dump(scoring, f, indent=2)
-
-     javabridge.kill_vm()
-
-
- def get_wsi_resolution(filename):
-     """
-     Try to get the resolution (magnification) of the slide and
-     the corresponding tile size to use by default for DeepLIIF.
-     If it cannot be found, return (None, None) instead.
-
-     Note: This will start the javabridge VM, but not kill it.
-     It must be killed elsewhere.
-
-     Parameters
-     ----------
-     filename : str
-         Full path to the file.
-
-     Returns
-     -------
-     str :
-         Magnification (objective power) from image metadata.
-     int :
-         Corresponding tile size for DeepLIIF.
-     """
-
-     # this call to get_information() also sets up javabridge
-     size_x, size_y, size_z, size_c, size_t, pixel_type = get_information(filename)
-
-     mag = None
-     metadata = bioformats.get_omexml_metadata(filename)
-     try:
-         omexml = bioformats.OMEXML(metadata)
-         mag = omexml.instrument().Objective.NominalMagnification
-     except Exception:
-         # fall back to scanning the raw metadata for the numeric value
-         # that follows one of these field names
-         fields = ['AppMag', 'NominalMagnification']
-         try:
-             for field in fields:
-                 idx = metadata.find(field)
-                 if idx >= 0:
-                     for i in range(idx, len(metadata)):
-                         if metadata[i].isdigit() or metadata[i] == '.':
-                             break
-                     for j in range(i, len(metadata)):
-                         if not metadata[j].isdigit() and metadata[j] != '.':
-                             break
-                     if i == j:
-                         continue
-                     mag = metadata[i:j]
-                     break
-         except Exception:
-             pass
-
-     if mag is None:
-         return None, None
-
-     try:
-         tile_size = round((float(mag) / 40) * 512)
-         return mag, tile_size
-     except Exception:
-         return None, None
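The tile size scales linearly with magnification relative to the 40x/512 baseline, so:

```python
round((float('40') / 40) * 512)  # 512
round((float('20') / 40) * 512)  # 256
round((float('10') / 40) * 512)  # 128
```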
-
-
- # TiffFile and zarr are only needed by the commented-out region reader in
- # infer_cells_for_wsi() below; time is used for the read timing
- from tifffile import TiffFile
- import zarr
- import time
-
-
- def infer_cells_for_wsi(filename, model_dir, tile_size, region_size=20000, version=3, print_log=False):
-     """
-     Perform inference on a slide and return the resulting individual cell data.
-
-     Parameters
-     ----------
-     filename : str
-         Full path to the file.
-     model_dir : str
-         Full path to the directory with the DeepLIIF model files.
-     tile_size : int
-         Size of tiles to extract and perform inference on.
-     region_size : int
-         Maximum size to split the slide for processing.
-     version : int
-         Version of cell data to return (3 or 4).
-     print_log : bool
-         Whether or not to print updates while processing.
-
-     Returns
-     -------
-     dict :
-         Individual cell data and associated values.
-     """
-
-     def print_info(*args):
-         if print_log:
-             print(*args, flush=True)
-
-     resolution = '40x' if tile_size > 384 else ('20x' if tile_size > 192 else '10x')
-
-     size_x, size_y, size_z, size_c, size_t, pixel_type = get_information(filename)
-     rescale = (pixel_type != 'uint8')
-     print_info('Info:', size_x, size_y, size_z, size_c, size_t, pixel_type)
-
-     num_regions_x = math.ceil(size_x / region_size)
-     num_regions_y = math.ceil(size_y / region_size)
-     stride_x = math.ceil(size_x / num_regions_x)
-     stride_y = math.ceil(size_y / num_regions_y)
-     print_info('Strides:', stride_x, stride_y)
-
-     data = None
-     default_marker_thresh, count_marker_thresh = 0, 0
-     default_size_thresh, count_size_thresh = 0, 0
-
-     # javabridge is already set up by the previous call to get_information()
-     with bioformats.ImageReader(filename) as reader:
-         start_x, start_y = 0, 0
-
-         while start_y < size_y:
-             while start_x < size_x:
-                 region_XYWH = (start_x, start_y, min(stride_x, size_x - start_x), min(stride_y, size_y - start_y))
-                 print_info('Region:', region_XYWH)
-
-                 tstart_read = time.time()
-
-                 region = reader.read(XYWH=region_XYWH, rescale=rescale)
-                 print_info(region.shape, region.dtype)
-                 img = Image.fromarray((region * 255).astype(np.uint8)) if rescale else Image.fromarray(region)
-                 print_info(img.size, img.mode)
-
-                 # alternative zarr-based reader (unused):
-                 #x, y, w, h = region_XYWH
-                 #tif = TiffFile(filename)
-                 #z = zarr.open(tif.pages[0].aszarr(), mode='r')
-                 #img = Image.fromarray(z[y:y+h, x:x+w])
-
-                 tend_read = time.time()
-                 print_info('Time to read region:', round(tend_read - tstart_read, 1), 'sec.')
-
-                 images = inference(
-                     img,
-                     tile_size=tile_size,
-                     overlap_size=tile_size // 16,
-                     model_path=model_dir,
-                     eager_mode=False,
-                     color_dapi=False,
-                     color_marker=False,
-                     opt=None,
-                     return_seg_intermediate=False,
-                     seg_only=True,
-                 )
-                 region_data = compute_cell_results(images['Seg'], images.get('Marker'), resolution, version=version)
-
-                 # shift cell coordinates from region-local to slide-global
-                 if start_x != 0 or start_y != 0:
-                     for i in range(len(region_data['cells'])):
-                         cell = decode_cell_data_v4(region_data['cells'][i]) if version == 4 else region_data['cells'][i]
-                         for j in range(2):
-                             cell['bbox'][j] = (cell['bbox'][j][0] + start_x, cell['bbox'][j][1] + start_y)
-                         cell['centroid'] = (cell['centroid'][0] + start_x, cell['centroid'][1] + start_y)
-                         for j in range(len(cell['boundary'])):
-                             cell['boundary'][j] = (cell['boundary'][j][0] + start_x, cell['boundary'][j][1] + start_y)
-                         region_data['cells'][i] = encode_cell_data_v4(cell) if version == 4 else cell
-
-                 if data is None:
-                     data = region_data
-                 else:
-                     data['cells'] += region_data['cells']
-
-                 if region_data['settings']['default_marker_thresh'] is not None and region_data['settings']['default_marker_thresh'] != 0:
-                     default_marker_thresh += region_data['settings']['default_marker_thresh']
-                     count_marker_thresh += 1
-                 if region_data['settings']['default_size_thresh'] != 0:
-                     default_size_thresh += region_data['settings']['default_size_thresh']
-                     count_size_thresh += 1
-
-                 start_x += stride_x
-
-             start_x = 0
-             start_y += stride_y
-
-     javabridge.kill_vm()
-     global count_run_tiles
-     print_info('Num tiles run:', count_run_tiles)
-
-     # average the per-region default thresholds (guarding against zero counts)
-     if count_marker_thresh == 0:
-         count_marker_thresh = 1
-     if count_size_thresh == 0:
-         count_size_thresh = 1
-     data['settings']['default_marker_thresh'] = round(default_marker_thresh / count_marker_thresh)
-     data['settings']['default_size_thresh'] = round(default_size_thresh / count_size_thresh)
-
-     return data
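A hypothetical call; the slide path is illustrative. The returned dict carries the per-cell records plus averaged default thresholds under `data['settings']`:

```python
data = infer_cells_for_wsi('slide.svs', './model_files/DeepLIIF_Latest_Model',
                           tile_size=512, version=4, print_log=True)
print(len(data['cells']), data['settings']['default_size_thresh'])
```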