deepliif 1.2.1__py3-none-any.whl → 1.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,943 +0,0 @@
1
- """This package contains modules related to objective functions, optimizations, and network architectures.
2
-
3
- To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel that inherits from BaseModel.
4
- You need to implement the following five functions:
5
- -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
6
- -- <set_input>: unpack data from dataset and apply preprocessing.
7
- -- <forward>: produce intermediate results.
8
- -- <optimize_parameters>: calculate loss, gradients, and update network weights.
9
- -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
10
-
11
- In the function <__init__>, you need to define four lists:
12
- -- self.loss_names (str list): specify the training losses that you want to plot and save.
13
- -- self.model_names (str list): define networks used in our training.
14
- -- self.visual_names (str list): specify the images that you want to display and save.
15
- -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See CycleGAN_model.py for an example.
16
-
17
- Now you can use the model class by specifying flag '--model dummy'.
18
- See our template model class 'template_model.py' for more details.
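- 
- A minimal sketch of such a 'dummy' model (illustrative only; the method bodies
- below are assumptions, not code from this package):
- 
-     from .base_model import BaseModel
- 
-     class DummyModel(BaseModel):
-         def __init__(self, opt):
-             BaseModel.__init__(self, opt)
-             self.loss_names = ['G']                # losses to plot and save
-             self.model_names = ['G']               # networks used in training
-             self.visual_names = ['real', 'fake']   # images to display and save
-             self.optimizers = []                   # one optimizer per network
- 
-         def set_input(self, input):
-             self.real = input['A'].to(self.device)
- 
-         def forward(self):
-             self.fake = self.netG(self.real)
- 
-         def optimize_parameters(self):
-             self.forward()  # then compute losses, call backward(), and step the optimizers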
19
- """
20
- import base64
21
- import os
22
- import itertools
23
- import importlib
24
- from functools import lru_cache
25
- from io import BytesIO
26
- import json
27
- import math
28
-
29
- import requests
30
- import torch
31
- from PIL import Image
32
- Image.MAX_IMAGE_PIXELS = None  # disable PIL's decompression-bomb limit so large WSI regions can be opened
33
-
34
- import numpy as np
35
- from dask import delayed, compute
36
- import openslide
37
-
38
- from deepliif.util import *
39
- from deepliif.util.util import tensor_to_pil
40
- from deepliif.data import transform
41
- from deepliif.postprocessing import compute_final_results, compute_cell_results
42
- from deepliif.postprocessing import encode_cell_data_v4, decode_cell_data_v4
43
- from deepliif.options import Options, print_options
44
-
45
- from .base_model import BaseModel
46
-
47
- # imported so these model classes are registered at package init; not used directly in this module
48
- from .DeepLIIF_model import DeepLIIFModel
49
- from .DeepLIIFExt_model import DeepLIIFExtModel
50
-
51
-
52
- @lru_cache
53
- def get_opt(model_dir, mode='test'):
54
- """
55
- mode: 'test' or 'train'; currently only inference functions use get_opt, so it
56
- defaults to 'test'
57
- """
58
- if mode == 'train':
59
- opt = Options(path_file=os.path.join(model_dir,'train_opt.txt'), mode=mode)
60
- elif mode == 'test':
61
- try:
62
- opt = Options(path_file=os.path.join(model_dir,'test_opt.txt'), mode=mode)
63
- except Exception:  # fall back to train_opt.txt if test_opt.txt is missing
64
- opt = Options(path_file=os.path.join(model_dir,'train_opt.txt'), mode=mode)
65
- opt.use_dp = False
66
- opt.gpu_ids = list(range(torch.cuda.device_count()))
67
- return opt
68
-
69
-
70
- def find_model_using_name(model_name):
71
- """Import the module "models/[model_name]_model.py".
72
-
73
- In the file, the class called [ModelName]Model will
74
- be instantiated (e.g. DummyModel for model name 'dummy'). It has to be a subclass of BaseModel,
75
- and the name matching is case-insensitive.
76
- """
77
- model_filename = "deepliif.models." + model_name + "_model"
78
- modellib = importlib.import_module(model_filename)
79
- model = None
80
- target_model_name = model_name.replace('_', '') + 'model'
81
- for name, cls in modellib.__dict__.items():
82
- if name.lower() == target_model_name.lower() \
83
- and issubclass(cls, BaseModel):
84
- model = cls
85
-
86
- if model is None:
87
- print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (
88
- model_filename, target_model_name))
89
- exit(0)
90
-
91
- return model
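- # For example (illustrative), find_model_using_name('DeepLIIF') imports the module
- # deepliif.models.DeepLIIF_model and returns its DeepLIIFModel class.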
92
-
93
-
94
- def get_option_setter(model_name):
95
- """Return the static method <modify_commandline_options> of the model class."""
96
- model_class = find_model_using_name(model_name)
97
- return model_class.modify_commandline_options
98
-
99
-
100
- def create_model(opt):
101
- """Create a model given the option.
102
-
103
- This function instantiates the model class found by <find_model_using_name>.
104
- This is the main interface between this package and 'train.py'/'test.py'
105
-
106
- Example:
107
- >>> from deepliif.models import create_model
108
- >>> model = create_model(opt)
109
- """
110
- model = find_model_using_name(opt.model)
111
- instance = model(opt)
112
- print("model [%s] was created" % type(instance).__name__)
113
- return instance
114
-
115
-
116
- def load_torchscript_model(model_pt_path, device):
117
- net = torch.jit.load(model_pt_path, map_location=device)
118
- net = disable_batchnorm_tracking_stats(net)
119
- net.eval()
120
- return net
121
-
122
-
123
-
124
- def load_eager_models(opt, devices=None):
125
- # create a model given model and other options
126
- model = create_model(opt)
127
- # regular setup: load and print networks; create schedulers
128
- model.setup(opt)
129
-
130
- nets = {}
131
- if devices:
132
- model_names = list(devices.keys())
133
- else:
134
- model_names = model.model_names
135
-
136
- for name in model_names:
137
- if isinstance(name, str):
138
- if '_' in name:
139
- net = getattr(model, 'net' + name.split('_')[0])[int(name.split('_')[-1]) - 1]
140
- else:
141
- net = getattr(model, 'net' + name)
142
-
143
- if opt.phase != 'train':
144
- net.eval()
145
- net = disable_batchnorm_tracking_stats(net)
146
-
147
- # SDG models are still wrapped in DataParallel after loading (cause unclear); unwrap to the underlying module
148
- if isinstance(net, torch.nn.DataParallel):
149
- net = net.module
150
-
151
- nets[name] = net
152
- if devices:
153
- nets[name].to(devices[name])
154
-
155
- return nets
156
-
157
-
158
- @lru_cache
159
- def init_nets(model_dir, eager_mode=False, opt=None, phase='test'):
160
- """
161
- Init DeepLIIF networks so that every net in
162
- the same group is deployed on the same GPU
163
-
164
- opt: optional Options object overriding the arguments in train_opt.txt, typically supplied at inference time,
166
- for example with opt.phase set to 'test'
166
- """
167
- if opt is None:
168
- opt = get_opt(model_dir, mode=phase)
169
- opt.use_dp = False
170
-
171
- if opt.model == 'DeepLIIF':
172
- net_groups = [
173
- ('G1', 'G52'),
174
- ('G2', 'G53'),
175
- ('G3', 'G54'),
176
- ('G4', 'G55'),
177
- ('G51',)
178
- ]
179
- elif opt.model in ['DeepLIIFExt','SDG']:
180
- if opt.seg_gen:
181
- net_groups = [(f'G_{i+1}',f'GS_{i+1}') for i in range(opt.modalities_no)]
182
- else:
183
- net_groups = [(f'G_{i+1}',) for i in range(opt.modalities_no)]
184
- elif opt.model == 'CycleGAN':
185
- if opt.BtoA:
186
- net_groups = [(f'GB_{i+1}',) for i in range(opt.modalities_no)]
187
- else:
188
- net_groups = [(f'GA_{i+1}',) for i in range(opt.modalities_no)]
189
- else:
190
- raise Exception(f'init_nets() not implemented for model {opt.model}')
191
-
192
- number_of_gpus_all = torch.cuda.device_count()
193
- number_of_gpus = min(len(opt.gpu_ids),number_of_gpus_all)
194
-
195
- if number_of_gpus > 0:
196
- mapping_gpu_ids = {i:idx for i,idx in enumerate(opt.gpu_ids)}
197
- chunks = [itertools.chain.from_iterable(c) for c in chunker(net_groups, number_of_gpus)]
198
- # chunks = chunks[1:]
199
- devices = {n: torch.device(f'cuda:{mapping_gpu_ids[i]}') for i, g in enumerate(chunks) for n in g}
200
- # devices = {n: torch.device(f'cuda:{i}') for i, g in enumerate(chunks) for n in g}
201
- else:
202
- devices = {n: torch.device('cpu') for n in itertools.chain.from_iterable(net_groups)}
203
-
204
- if eager_mode:
205
- return load_eager_models(opt, devices)
206
-
207
- return {
208
- n: load_torchscript_model(os.path.join(model_dir, f'{n}.pt'), device=d)
209
- for n, d in devices.items()
210
- }
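- # Hypothetical usage (the model directory path is illustrative):
- #   nets = init_nets('./model_files/DeepLIIF_Latest_Model')
- #   # 'nets' maps generator names such as 'G1' or 'G51' to loaded torchscript
- #   # modules, each already moved to its assigned device.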
211
-
212
-
213
- def compute_overlap(img_size, tile_size):
214
- w, h = img_size
215
- if round(w / tile_size) == 1 and round(h / tile_size) == 1:
216
- return 0
217
-
218
- return tile_size // 4
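- # Worked example: compute_overlap((2000, 2000), 512) == 128 (tile_size // 4),
- # while compute_overlap((600, 500), 512) == 0 since the image is roughly one tile.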
219
-
220
-
221
- def run_torchserve(img, model_path=None, eager_mode=False, opt=None, seg_only=False):
222
- """
223
- eager_mode: not used in this function; included to be consistent with run_dask,
225
- so that run_wrapper() can call either this function or run_dask with
226
- the same syntax
226
- opt: same as eager_mode
227
- seg_only: same as eager_mode
228
- """
229
- buffer = BytesIO()
230
- torch.save(transform(img.resize((opt.scale_size, opt.scale_size))), buffer)
231
-
232
- torchserve_host = os.getenv('TORCHSERVE_HOST', 'http://localhost')
233
- res = requests.post(
234
- f'{torchserve_host}/wfpredict/deepliif',
235
- json={'img': base64.b64encode(buffer.getvalue()).decode('utf-8')}
236
- )
237
-
238
- res.raise_for_status()
239
-
240
- def deserialize_tensor(bs):
241
- return torch.load(BytesIO(base64.b64decode(bs.encode())), map_location=torch.device('cpu'))
242
-
243
- return {k: tensor_to_pil(deserialize_tensor(v)) for k, v in res.json().items()}
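- # The payload round-trip used above, as a standalone sketch: a tensor is saved to
- # a buffer, base64-encoded for JSON transport, then decoded and loaded back.
- #   buf = BytesIO()
- #   torch.save(torch.zeros(1), buf)
- #   payload = base64.b64encode(buf.getvalue()).decode('utf-8')
- #   restored = torch.load(BytesIO(base64.b64decode(payload.encode())), map_location='cpu')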
244
-
245
-
246
- def run_dask(img, model_path, eager_mode=False, opt=None, seg_only=False):
247
- model_dir = os.getenv('DEEPLIIF_MODEL_DIR', model_path)
248
- nets = init_nets(model_dir, eager_mode, opt)
249
- use_dask = opt.norm != 'spectral'
250
-
251
- if opt.input_no > 1 or opt.model == 'SDG':
252
- l_ts = [transform(img_i.resize((opt.scale_size,opt.scale_size))) for img_i in img]
253
- ts = torch.cat(l_ts, dim=1)
254
- else:
255
- ts = transform(img.resize((opt.scale_size, opt.scale_size)))
256
-
257
-
258
- if use_dask:
259
- @delayed
260
- def forward(input, model):
261
- with torch.no_grad():
262
- return model(input.to(next(model.parameters()).device))
263
- else:  # some training settings (e.g. spectral norm) are somehow incompatible with dask at inference time
264
- def forward(input, model):
265
- with torch.no_grad():
266
- return model(input.to(next(model.parameters()).device))
267
-
268
- if opt.model == 'DeepLIIF':
269
- weights = {
270
- 'G51': 0.25, # IHC
271
- 'G52': 0.25, # Hema
272
- 'G53': 0.25, # DAPI
273
- 'G54': 0.00, # Lap2
274
- 'G55': 0.25, # Marker
275
- }
276
-
277
- seg_map = {'G1': 'G52', 'G2': 'G53', 'G3': 'G54', 'G4': 'G55'}
278
- if seg_only:
279
- seg_map = {k: v for k, v in seg_map.items() if weights[v] != 0}
280
-
281
- lazy_gens = {k: forward(ts, nets[k]) for k in seg_map}
282
- gens = compute(lazy_gens)[0]
283
-
284
- lazy_segs = {v: forward(gens[k], nets[v]).to(torch.device('cpu')) for k, v in seg_map.items()}
285
- if weights['G51'] != 0:
286
- lazy_segs['G51'] = forward(ts, nets['G51']).to(torch.device('cpu'))
287
- segs = compute(lazy_segs)[0]
288
-
289
- seg = torch.stack([torch.mul(segs[k], weights[k]) for k in segs.keys()]).sum(dim=0)
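- # i.e. the fused segmentation is a weighted sum of the per-branch outputs:
- # seg = 0.25*G51 + 0.25*G52 + 0.25*G53 + 0.00*G54 + 0.25*G55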
290
-
291
- if seg_only:
292
- res = {'G4': tensor_to_pil(gens['G4'])} if 'G4' in gens else {}
293
- else:
294
- res = {k: tensor_to_pil(v) for k, v in gens.items()}
295
- res.update({k: tensor_to_pil(v) for k, v in segs.items()})
296
- res['G5'] = tensor_to_pil(seg)
297
-
298
- return res
299
- elif opt.model in ['DeepLIIFExt','SDG','CycleGAN']:
300
- if opt.model == 'CycleGAN':
301
- seg_map = {f'GB_{i+1}':None for i in range(opt.modalities_no)} if opt.BtoA else {f'GA_{i+1}':None for i in range(opt.modalities_no)}
302
- else:
303
- seg_map = {'G_' + str(i): 'GS_' + str(i) for i in range(1, opt.modalities_no + 1)}
304
-
305
- if use_dask:
306
- lazy_gens = {k: forward(ts, nets[k]) for k in seg_map}
307
- gens = compute(lazy_gens)[0]
308
- else:
309
- gens = {k: forward(ts, nets[k]) for k in seg_map}
310
-
311
- res = {k: tensor_to_pil(v) for k, v in gens.items()}
312
-
313
- if opt.seg_gen:
314
- if use_dask:
315
- lazy_segs = {v: forward(torch.cat([ts.to(torch.device('cpu')), gens[next(iter(seg_map))].to(torch.device('cpu')), gens[k].to(torch.device('cpu'))], 1), nets[v]).to(torch.device('cpu')) for k, v in seg_map.items()}
316
- segs = compute(lazy_segs)[0]
317
- else:
318
- segs = {v: forward(torch.cat([ts.to(torch.device('cpu')), gens[next(iter(seg_map))].to(torch.device('cpu')), gens[k].to(torch.device('cpu'))], 1), nets[v]).to(torch.device('cpu')) for k, v in seg_map.items()}
319
- res.update({k: tensor_to_pil(v) for k, v in segs.items()})
320
-
321
- return res
322
- else:
323
- raise Exception(f'run_dask() not fully implemented for {opt.model}')
324
-
325
-
326
- def run_dask_multi(imgs, model_path, eager_mode=False, opt=None, seg_only=False):
327
- model_dir = os.getenv('DEEPLIIF_MODEL_DIR', model_path)
328
- nets = init_nets(model_dir, eager_mode, opt)
329
- use_dask = opt.norm != 'spectral'
330
-
331
- if opt.input_no > 1 or opt.model == 'SDG':
332
- raise Exception(f'run_dask_multi() not fully implemented for {opt.model}')
333
- #l_ts = [transform(img_i.resize((opt.scale_size,opt.scale_size))) for img_i in img]
334
- #ts = torch.cat(l_ts, dim=1)
335
- else:
336
- tss = [transform(img.resize((opt.scale_size, opt.scale_size))) for img in imgs]
337
-
338
-
339
- if use_dask:
340
- @delayed
341
- def forward(input, model):
342
- with torch.no_grad():
343
- return model(input.to(next(model.parameters()).device))
344
- else:  # some training settings (e.g. spectral norm) are somehow incompatible with dask at inference time
345
- def forward(input, model):
346
- with torch.no_grad():
347
- return model(input.to(next(model.parameters()).device))
348
-
349
- if opt.model == 'DeepLIIF':
350
- weights = {
351
- 'G51': 0.25, # IHC
352
- 'G52': 0.25, # Hema
353
- 'G53': 0.25, # DAPI
354
- 'G54': 0.00, # Lap2
355
- 'G55': 0.25, # Marker
356
- }
357
-
358
- seg_map = {'G1': 'G52', 'G2': 'G53', 'G3': 'G54', 'G4': 'G55'}
359
- if seg_only:
360
- seg_map = {k: v for k, v in seg_map.items() if weights[v] != 0}
361
-
362
- lazy_gens = [{k: forward(ts, nets[k]) for k in seg_map} for ts in tss]
363
- gens = compute(*lazy_gens)
364
-
365
- lazy_segs = [{v: forward(g[k], nets[v]).to(torch.device('cpu')) for k, v in seg_map.items()} for g in gens]
366
- if weights['G51'] != 0:
367
- for i in range(len(lazy_segs)):
368
- lazy_segs[i]['G51'] = forward(tss[i], nets['G51']).to(torch.device('cpu'))
369
- segs = compute(*lazy_segs)
370
-
371
- seg = [torch.stack([torch.mul(s[k], weights[k]) for k in s.keys()]).sum(dim=0) for s in segs]
372
-
373
- if seg_only:
374
- res = [{'G4': tensor_to_pil(g['G4'])} if 'G4' in g else {} for g in gens]
375
- else:
376
- raise Exception(f'run_dask_multi() not implemented for {opt.model} unless seg_only=True')
377
- #res = {k: tensor_to_pil(v) for k, v in gens.items()}
378
- #res.update({k: tensor_to_pil(v) for k, v in segs.items()})
379
- for i in range(len(seg)):
380
- res[i]['G5'] = tensor_to_pil(seg[i])
381
-
382
- return res
383
-
384
- else:
385
- raise Exception(f'run_dask_multi() not fully implemented for {opt.model}')
386
-
387
-
388
- def run_dask_multi2(imgs, model_path, eager_mode=False, opt=None, seg_only=False):
389
- model_dir = os.getenv('DEEPLIIF_MODEL_DIR', model_path)
390
- nets = init_nets(model_dir, eager_mode, opt)
391
- use_dask = opt.norm != 'spectral'
392
-
393
- if opt.input_no > 1 or opt.model == 'SDG':
394
- raise Exception(f'run_dask_multi2() not fully implemented for {opt.model}')
395
- #l_ts = [transform(img_i.resize((opt.scale_size,opt.scale_size))) for img_i in img]
396
- #ts = torch.cat(l_ts, dim=1)
397
- else:
398
- tss = [transform(img.resize((opt.scale_size, opt.scale_size))) for img in imgs]
399
-
400
-
401
- if use_dask:
402
- @delayed
403
- def forward(input, model):
404
- with torch.no_grad():
405
- return model(input.to(next(model.parameters()).device))
406
- else:  # some training settings (e.g. spectral norm) are somehow incompatible with dask at inference time
407
- def forward(input, model):
408
- with torch.no_grad():
409
- return model(input.to(next(model.parameters()).device))
410
-
411
- if opt.model == 'DeepLIIF':
412
- weights = {
413
- 'G51': 0.25, # IHC
414
- 'G52': 0.25, # Hema
415
- 'G53': 0.25, # DAPI
416
- 'G54': 0.00, # Lap2
417
- 'G55': 0.25, # Marker
418
- }
419
-
420
- seg_map = {'G1': 'G52', 'G2': 'G53', 'G3': 'G54', 'G4': 'G55'}
421
- if seg_only:
422
- seg_map = {k: v for k, v in seg_map.items() if weights[v] != 0}
423
-
424
- lazy_gens = {k: [forward(ts, nets[k]) for ts in tss] for k in seg_map}
425
- gens = compute(lazy_gens)[0]
426
-
427
- lazy_segs = {v: [forward(g, nets[v]).to(torch.device('cpu')) for g in gens[k]] for k, v in seg_map.items()}
428
- if weights['G51'] != 0:
429
- lazy_segs['G51'] = [forward(ts, nets['G51']).to(torch.device('cpu')) for ts in tss]
430
- segs = compute(lazy_segs)[0]
431
-
432
- seg = [torch.stack([torch.mul(segs[k][i], weights[k]) for k in segs.keys()]).sum(dim=0) for i in range(len(tss))]
433
-
434
- if seg_only:
435
- res = [{'G4': tensor_to_pil(gens['G4'][i])} if 'G4' in gens else {} for i in range(len(tss))]
436
- else:
437
- raise Exception(f'run_dask_multi2() not implemented for {opt.model} unless seg_only=True')
438
- #res = {k: tensor_to_pil(v) for k, v in gens.items()}
439
- #res.update({k: tensor_to_pil(v) for k, v in segs.items()})
440
- for i in range(len(seg)):
441
- res[i]['G5'] = tensor_to_pil(seg[i])
442
-
443
- return res
444
-
445
- else:
446
- raise Exception(f'run_dask_multi2() not fully implemented for {opt.model}')
447
-
448
-
449
- def run_dask_multi3(imgs, model_path, eager_mode=False, opt=None, seg_only=False):
450
- model_dir = os.getenv('DEEPLIIF_MODEL_DIR', model_path)
451
- nets = init_nets(model_dir, eager_mode, opt)
452
- #print(nets['G1'], flush=True)
453
- #print(next(nets['G1'].parameters()), flush=True)
454
- use_dask = opt.norm != 'spectral'
455
-
456
- '''
457
- with torch.no_grad():
458
- ts = [transform(img.resize((512, 512))) for img in imgs]
459
- #ts = [transform(imgs[0]), transform(imgs[0]), transform(imgs[0]), transform(imgs[0])]
460
- for i in range(0):
461
- for img in imgs:
462
- ts.append(transform(img.resize((512, 512))))
463
-
464
- num_iters = 100
465
-
466
- tstart = time.time()
467
- for i in range(num_iters):
468
- tcat = torch.cat(ts)
469
- rs = nets['G1'](tcat.to('cuda:0'))
470
- rs = torch.split(rs, 1)
471
- tend = time.time()
472
- for i, r in enumerate(rs):
473
- print('cat result size:', r.size(), flush=True)
474
- im = tensor_to_pil(r)
475
- im.save(f'test_wsi_G51_{i}_cat.png')
476
- print('time:', tend-tstart, 'sec.', flush=True)
477
-
478
- tstart = time.time()
479
- for i in range(num_iters):
480
- rs = []
481
- for t in ts:
482
- rs.append(nets['G1'](t.to('cuda:0')))
483
- tend = time.time()
484
- for i, r in enumerate(rs):
485
- print('ind result size:', r.size(), flush=True)
486
- im = tensor_to_pil(r)
487
- im.save(f'test_wsi_G51_{i}_ind.png')
488
- print('time:', tend-tstart, 'sec.', flush=True)
489
- '''
490
-
491
- if opt.input_no > 1 or opt.model == 'SDG':
492
- raise Exception(f'run_dask_multi3() not fully implemented for {opt.model}')
493
- #l_ts = [transform(img_i.resize((opt.scale_size,opt.scale_size))) for img_i in img]
494
- #ts = torch.cat(l_ts, dim=1)
495
- else:
496
- #ts = transform(img.resize((opt.scale_size, opt.scale_size)))
497
- ts = torch.cat([transform(img.resize((opt.scale_size, opt.scale_size))) for img in imgs])
498
-
499
-
500
- if use_dask:
501
- @delayed
502
- def forward(input, model):
503
- with torch.no_grad():
504
- return model(input.to(next(model.parameters()).device))
505
- else:  # some training settings (e.g. spectral norm) are somehow incompatible with dask at inference time
506
- def forward(input, model):
507
- with torch.no_grad():
508
- return model(input.to(next(model.parameters()).device))
509
-
510
- if opt.model == 'DeepLIIF':
511
- weights = {
512
- 'G51': 0.25, # IHC
513
- 'G52': 0.25, # Hema
514
- 'G53': 0.25, # DAPI
515
- 'G54': 0.00, # Lap2
516
- 'G55': 0.25, # Marker
517
- }
518
-
519
- seg_map = {'G1': 'G52', 'G2': 'G53', 'G3': 'G54', 'G4': 'G55'}
520
- if seg_only:
521
- seg_map = {k: v for k, v in seg_map.items() if weights[v] != 0}
522
-
523
- lazy_gens = {k: forward(ts, nets[k]) for k in seg_map}
524
- gens = compute(lazy_gens)[0]
525
-
526
- lazy_segs = {v: forward(gens[k], nets[v]).to(torch.device('cpu')) for k, v in seg_map.items()}
527
- if weights['G51'] != 0:
528
- lazy_segs['G51'] = forward(ts, nets['G51']).to(torch.device('cpu'))
529
- segs = compute(lazy_segs)[0]
530
-
531
- seg = torch.stack([torch.mul(segs[k], weights[k]) for k in segs.keys()]).sum(dim=0)
532
-
533
- if seg_only:
534
- #res = {'G4': tensor_to_pil(gens['G4'])} if 'G4' in gens else {}
535
- g4 = torch.split(gens['G4'], 1)  # TODO: check that 'G4' is in gens before splitting
537
- #print(type(g4), g4[0].size(), flush=True)
538
- res = [{'G4': tensor_to_pil(g)} for g in g4]  # TODO: same 'G4' presence check as above
538
- else:
539
- raise Exception(f'run_dask_multi3() not implemented for {opt.model} unless seg_only=True')
540
- #res = {k: tensor_to_pil(v) for k, v in gens.items()}
541
- #res.update({k: tensor_to_pil(v) for k, v in segs.items()})
542
- #res['G5'] = tensor_to_pil(seg)
543
- g5 = torch.split(seg, 1)
544
- for i in range(len(g5)):
545
- res[i]['G5'] = tensor_to_pil(g5[i])
546
-
547
- return res
548
-
549
- else:
550
- raise Exception(f'run_dask_multi3() not fully implemented for {opt.model}')
551
-
552
-
553
- def is_empty(tile):
554
- thresh = 15
555
- if isinstance(tile, list): # for a pair of tiles, mark as empty (skip prediction) only if ALL tiles are empty
557
- return all(np.max(image_variance_rgb(t)) < thresh for t in tile)
558
- else:
559
- return np.max(image_variance_rgb(tile)) < thresh
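- # Illustrative check (assuming image_variance_rgb returns per-channel variances):
- #   is_empty(Image.new('RGB', (512, 512), (255, 255, 255)))  # True: a flat tile has zero variance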
559
-
560
-
561
- def run_wrapper(tile, run_fn, model_path, eager_mode=False, opt=None, seg_only=False):
562
- if opt.model == 'DeepLIIF':
563
- if is_empty(tile):
564
- return {
565
- 'G1': Image.new(mode='RGB', size=(512, 512), color=(201, 211, 208)),
566
- 'G2': Image.new(mode='RGB', size=(512, 512), color=(10, 10, 10)),
567
- 'G3': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)),
568
- 'G4': Image.new(mode='RGB', size=(512, 512), color=(10, 10, 10)),
569
- 'G5': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)),
570
- 'G51': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)),
571
- 'G52': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)),
572
- 'G53': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)),
573
- 'G54': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)),
574
- 'G55': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)),
575
- }
576
- else:
577
- return run_fn(tile, model_path, eager_mode, opt, seg_only)
578
- elif opt.model in ['DeepLIIFExt', 'SDG']:
579
- if is_empty(tile):
580
- res = {'G_' + str(i): Image.new(mode='RGB', size=(512, 512)) for i in range(1, opt.modalities_no + 1)}
581
- res.update({'GS_' + str(i): Image.new(mode='RGB', size=(512, 512)) for i in range(1, opt.modalities_no + 1)})
582
- return res
583
- else:
584
- return run_fn(tile, model_path, eager_mode, opt)
585
- elif opt.model in ['CycleGAN']:
586
- if is_empty(tile):
587
- net_names = [f'GB_{i+1}' for i in range(opt.modalities_no)] if opt.BtoA else [f'GA_{i+1}' for i in range(opt.modalities_no)]
588
- res = {net_name: Image.new(mode='RGB', size=(512, 512)) for net_name in net_names}
589
- return res
590
- else:
591
- return run_fn(tile, model_path, eager_mode, opt)
592
- else:
593
- raise Exception(f'run_wrapper() not implemented for model {opt.model}')
594
-
595
-
596
- def inference(img, tile_size, overlap_size, model_path, use_torchserve=False,
597
- eager_mode=False, color_dapi=False, color_marker=False, opt=None,
598
- return_seg_intermediate=False, seg_only=False):
599
- if not opt:
600
- opt = get_opt(model_path)
601
- #print_options(opt)
602
-
603
- run_fn = run_torchserve if use_torchserve else run_dask
604
-
605
- if opt.model == 'SDG':
606
- # SDG can take multiple input images/modalities, so the input may be several images concatenated side by side.
608
- # We split the input to recover each modality image, then create tiles for each set of input images.
608
- w, h = int(img.width / opt.input_no), img.height
609
- orig = [img.crop((w * i, 0, w * (i+1), h)) for i in range(opt.input_no)]
610
- else:
611
- # Otherwise expect a single input image, which is used directly.
612
- orig = img
613
-
614
- tiler = InferenceTiler(orig, tile_size, overlap_size)
615
- for tile in tiler:
616
- tiler.stitch(run_wrapper(tile, run_fn, model_path, eager_mode, opt, seg_only))
617
- results = tiler.results()
618
-
619
- if opt.model == 'DeepLIIF':
620
- if seg_only:
621
- images = {'Seg': results['G5']}
622
- if 'G4' in results:
623
- images.update({'Marker': results['G4']})
624
- else:
625
- images = {
626
- 'Hema': results['G1'],
627
- 'DAPI': results['G2'],
628
- 'Lap2': results['G3'],
629
- 'Marker': results['G4'],
630
- 'Seg': results['G5'],
631
- }
632
-
633
- if return_seg_intermediate and not seg_only:
634
- images.update({'IHC_s':results['G51'],
635
- 'Hema_s':results['G52'],
636
- 'DAPI_s':results['G53'],
637
- 'Lap2_s':results['G54'],
638
- 'Marker_s':results['G55'],})
639
-
640
- if color_dapi and not seg_only:
641
- matrix = ( 0, 0, 0, 0,
642
- 299/1000, 587/1000, 114/1000, 0,
643
- 299/1000, 587/1000, 114/1000, 0)
644
- images['DAPI'] = images['DAPI'].convert('RGB', matrix)
645
- if color_marker and not seg_only:
646
- matrix = (299/1000, 587/1000, 114/1000, 0,
647
- 299/1000, 587/1000, 114/1000, 0,
648
- 0, 0, 0, 0)
649
- images['Marker'] = images['Marker'].convert('RGB', matrix)
650
- return images
651
-
652
- elif opt.model == 'DeepLIIFExt':
653
- images = {f'mod{i}': results[f'G_{i}'] for i in range(1, opt.modalities_no + 1)}
654
- if opt.seg_gen:
655
- images.update({f'Seg{i}': results[f'GS_{i}'] for i in range(1, opt.modalities_no + 1)})
656
- return images
657
-
658
- elif opt.model == 'SDG':
659
- images = {f'mod{i}': results[f'G_{i}'] for i in range(1, opt.modalities_no + 1)}
660
- return images
661
-
662
- else:
663
- #raise Exception(f'inference() not implemented for model {opt.model}')
664
- return results # return result images with default key names (i.e., net names)
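- # Hypothetical end-to-end usage (file names are illustrative):
- #   from PIL import Image
- #   img = Image.open('sample_ihc.png').convert('RGB')
- #   images = inference(img, tile_size=512, overlap_size=512 // 16,
- #                      model_path='./model_files/DeepLIIF_Latest_Model')
- #   images['Seg'].save('sample_seg.png')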
665
-
666
-
667
- def postprocess(orig, images, tile_size, model, seg_thresh=150, size_thresh='default', marker_thresh=None, size_thresh_upper=None):
668
- if model == 'DeepLIIF':
669
- resolution = '40x' if tile_size > 384 else ('20x' if tile_size > 192 else '10x')
670
- overlay, refined, scoring = compute_final_results(
671
- orig, images['Seg'], images.get('Marker'), resolution,
672
- size_thresh, marker_thresh, size_thresh_upper, seg_thresh)
673
- processed_images = {}
674
- processed_images['SegOverlaid'] = Image.fromarray(overlay)
675
- processed_images['SegRefined'] = Image.fromarray(refined)
676
- return processed_images, scoring
677
-
678
- elif model in ['DeepLIIFExt','SDG']:
679
- resolution = '40x' if tile_size > 768 else ('20x' if tile_size > 384 else '10x')
680
- processed_images = {}
681
- scoring = {}
682
- for img_name in list(images.keys()):
683
- if 'Seg' in img_name:
684
- seg_img = images[img_name]
685
- overlay, refined, score = compute_final_results(
686
- orig, images[img_name], None, resolution,
687
- size_thresh, marker_thresh, size_thresh_upper, seg_thresh)
688
-
689
- processed_images[img_name + '_Overlaid'] = Image.fromarray(overlay)
690
- processed_images[img_name + '_Refined'] = Image.fromarray(refined)
691
- scoring[img_name] = score
692
- return processed_images, scoring
693
-
694
- else:
695
- raise Exception(f'postprocess() not implemented for model {model}')
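- # Hypothetical usage, continuing the inference() example above:
- #   post_images, scoring = postprocess(img, images, tile_size=512, model='DeepLIIF')
- #   print(scoring['num_pos'], scoring['num_neg'])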
696
-
697
-
698
- def infer_modalities(img, tile_size, model_dir, eager_mode=False,
699
- color_dapi=False, color_marker=False, opt=None,
700
- return_seg_intermediate=False):
701
- """
702
- This function is used to infer modalities for the given image using a trained model.
703
- :param img: The input image.
704
- :param tile_size: The tile size.
705
- :param model_dir: The directory containing serialized model files.
706
- :return: The inferred modalities and the segmentation mask.
707
- """
708
- if opt is None:
709
- opt = get_opt(model_dir)
710
- opt.use_dp = False
711
- #print_options(opt)
712
-
713
- # for those with multiple input modalities, find the correct size to calculate overlap_size
714
- input_no = opt.input_no if hasattr(opt, 'input_no') else 1
715
- img_size = (img.size[0] / input_no, img.size[1]) # (width, height)
716
-
717
- images = inference(
718
- img,
719
- tile_size=tile_size,
720
- #overlap_size=compute_overlap(img_size, tile_size),
721
- overlap_size=tile_size//16,
722
- model_path=model_dir,
723
- eager_mode=eager_mode,
724
- color_dapi=color_dapi,
725
- color_marker=color_marker,
726
- opt=opt,
727
- return_seg_intermediate=return_seg_intermediate
728
- )
729
-
730
- if not hasattr(opt, 'seg_gen') or opt.seg_gen: # the first condition accounts for old DeepLIIF settings; the second refers to DeepLIIFExt models
731
- post_images, scoring = postprocess(img, images, tile_size, opt.model)
732
- images = {**images, **post_images}
733
- return images, scoring
734
- else:
735
- return images, None
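- # Hypothetical usage (paths are illustrative):
- #   images, scoring = infer_modalities(Image.open('sample_ihc.png'), 512,
- #                                      './model_files/DeepLIIF_Latest_Model')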
736
-
737
-
738
- def infer_results_for_wsi(input_dir, filename, output_dir, model_dir, tile_size, region_size=20000):
739
- """
740
- This function infers modalities and segmentation mask for the given WSI image, processing it region by region.
741
-
742
- :param input_dir: The directory containing the WSI.
743
- :param filename: The WSI name.
744
- :param output_dir: The directory for saving the inferred modalities.
745
- :param model_dir: The directory containing the serialized model files.
746
- :param tile_size: The tile size.
747
- :param region_size: The size of each individual region to be processed at once.
748
- :return: None; inferred modalities are written to output_dir as OME-TIFF files.
749
- """
750
- basename, _ = os.path.splitext(filename)
751
- results_dir = os.path.join(output_dir, basename)
752
- if not os.path.exists(results_dir):
753
- os.makedirs(results_dir)
754
- size_x, size_y, size_z, size_c, size_t, pixel_type = get_information(os.path.join(input_dir, filename))
755
- rescale = (pixel_type != 'uint8')
756
- print(filename, size_x, size_y, size_z, size_c, size_t, pixel_type)
757
-
758
- results = {}
759
- scoring = None
760
-
761
- # javabridge already set up from previous call to get_information()
762
- with bioformats.ImageReader(os.path.join(input_dir, filename)) as reader:
763
- start_x, start_y = 0, 0
764
-
765
- while start_x < size_x:
766
- while start_y < size_y:
767
- print(start_x, start_y)
768
- region_XYWH = (start_x, start_y, min(region_size, size_x - start_x), min(region_size, size_y - start_y))
769
- region = reader.read(XYWH=region_XYWH, rescale=rescale)
770
- img = Image.fromarray((region * 255).astype(np.uint8)) if rescale else Image.fromarray(region)
771
-
772
- region_modalities, region_scoring = infer_modalities(img, tile_size, model_dir)
773
- if region_scoring is not None:
774
- if scoring is None:
775
- scoring = {
776
- 'num_pos': region_scoring['num_pos'],
777
- 'num_neg': region_scoring['num_neg'],
778
- }
779
- else:
780
- scoring['num_pos'] += region_scoring['num_pos']
781
- scoring['num_neg'] += region_scoring['num_neg']
782
-
783
- for name, img in region_modalities.items():
784
- if name not in results:
785
- results[name] = np.zeros((size_y, size_x, 3), dtype=np.uint8)
786
- results[name][region_XYWH[1]: region_XYWH[1] + region_XYWH[3],
787
- region_XYWH[0]: region_XYWH[0] + region_XYWH[2]] = np.array(img)
788
- start_y += region_size
789
- start_y = 0
790
- start_x += region_size
791
-
792
- # write_results_to_pickle_file(os.path.join(results_dir, "results.pickle"), results)
793
- # read_results_from_pickle_file(os.path.join(results_dir, "results.pickle"))
794
-
795
- for name, img in results.items():
796
- write_big_tiff_file(os.path.join(results_dir, f'{basename}_{name}.ome.tiff'), img, tile_size)
797
-
798
- if scoring is not None:
799
- scoring['num_total'] = scoring['num_pos'] + scoring['num_neg']
800
- scoring['percent_pos'] = round(scoring['num_pos'] / scoring['num_total'] * 100, 1) if scoring['num_pos'] > 0 else 0
801
- with open(os.path.join(results_dir, f'{basename}.json'), 'w') as f:
802
- json.dump(scoring, f, indent=2)
803
-
804
- javabridge.kill_vm()
805
-
806
-
807
- def get_wsi_resolution(filename):
808
- """
809
- Use OpenSlide to get the resolution (magnification) of the slide
810
- and the corresponding tile size to use by default for DeepLIIF.
811
- If it cannot be found, return (None, None) instead.
812
-
813
- Parameters
814
- ----------
815
- filename : str
816
- Full path to the file.
817
-
818
- Returns
819
- -------
820
- str :
821
- Magnification (objective power) as found by OpenSlide.
822
- int :
823
- Corresponding tile size for DeepLIIF.
824
- """
825
- try:
826
- image = openslide.OpenSlide(filename)
827
- mag = image.properties.get(openslide.PROPERTY_NAME_OBJECTIVE_POWER)
828
- tile_size = round((float(mag) / 40) * 512)
829
- return mag, tile_size
830
- except Exception:
831
- return None, None
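- # Worked example: a slide scanned at 20x objective power gives
- # round((20 / 40) * 512) == 256 as the default DeepLIIF tile size.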
832
-
833
-
834
- def infer_cells_for_wsi(filename, model_dir, tile_size, region_size=20000, version=3, print_log=False):
835
- """
836
- Perform inference on a slide and get the resulting individual cell data.
837
-
838
- Parameters
839
- ----------
840
- filename : str
841
- Full path to the file.
842
- model_dir : str
843
- Full path to the directory with the DeepLIIF model files.
844
- tile_size : int
845
- Size of tiles to extract and perform inference on.
846
- region_size : int
847
- Maximum size to split the slide for processing.
848
- version : int
849
- Version of cell data to return (3 or 4).
850
- print_log : bool
851
- Whether or not to print updates while processing.
852
-
853
- Returns
854
- -------
855
- dict :
856
- Individual cell data and associated values.
857
- """
858
-
859
- def print_info(*args):
860
- if print_log:
861
- print(*args, flush=True)
862
-
863
- resolution = '40x' if tile_size > 384 else ('20x' if tile_size > 192 else '10x')
864
-
865
- size_x, size_y, size_z, size_c, size_t, pixel_type = get_information(filename)
866
- rescale = (pixel_type != 'uint8')
867
- print_info('Info:', size_x, size_y, size_z, size_c, size_t, pixel_type)
868
-
869
- num_regions_x = math.ceil(size_x / region_size)
870
- num_regions_y = math.ceil(size_y / region_size)
871
- stride_x = math.ceil(size_x / num_regions_x)
872
- stride_y = math.ceil(size_y / num_regions_y)
873
- print_info('Strides:', stride_x, stride_y)
874
-
875
- data = None
876
- default_marker_thresh, count_marker_thresh = 0, 0
877
- default_size_thresh, count_size_thresh = 0, 0
878
-
879
- # javabridge already set up from previous call to get_information()
880
- with bioformats.ImageReader(filename) as reader:
881
- start_x, start_y = 0, 0
882
-
883
- while start_y < size_y:
884
- while start_x < size_x:
885
- region_XYWH = (start_x, start_y, min(stride_x, size_x-start_x), min(stride_y, size_y-start_y))
886
- print_info('Region:', region_XYWH)
887
-
888
- region = reader.read(XYWH=region_XYWH, rescale=rescale)
889
- print_info(region.shape, region.dtype)
890
- img = Image.fromarray((region * 255).astype(np.uint8)) if rescale else Image.fromarray(region)
891
- print_info(img.size, img.mode)
892
-
893
- images = inference(
894
- img,
895
- tile_size=tile_size,
896
- overlap_size=tile_size//16,
897
- model_path=model_dir,
898
- eager_mode=False,
899
- color_dapi=False,
900
- color_marker=False,
901
- opt=None,
902
- return_seg_intermediate=False,
903
- seg_only=True,
904
- )
905
- region_data = compute_cell_results(images['Seg'], images.get('Marker'), resolution, version=version)
906
-
907
- if start_x != 0 or start_y != 0:
908
- for i in range(len(region_data['cells'])):
909
- cell = decode_cell_data_v4(region_data['cells'][i]) if version == 4 else region_data['cells'][i]
910
- for j in range(2):
911
- cell['bbox'][j] = (cell['bbox'][j][0] + start_x, cell['bbox'][j][1] + start_y)
912
- cell['centroid'] = (cell['centroid'][0] + start_x, cell['centroid'][1] + start_y)
913
- for j in range(len(cell['boundary'])):
914
- cell['boundary'][j] = (cell['boundary'][j][0] + start_x, cell['boundary'][j][1] + start_y)
915
- region_data['cells'][i] = encode_cell_data_v4(cell) if version == 4 else cell
916
-
917
- if data is None:
918
- data = region_data
919
- else:
920
- data['cells'] += region_data['cells']
921
-
922
- if region_data['settings']['default_marker_thresh'] is not None and region_data['settings']['default_marker_thresh'] != 0:
923
- default_marker_thresh += region_data['settings']['default_marker_thresh']
924
- count_marker_thresh += 1
925
- if region_data['settings']['default_size_thresh'] != 0:
926
- default_size_thresh += region_data['settings']['default_size_thresh']
927
- count_size_thresh += 1
928
-
929
- start_x += stride_x
930
-
931
- start_x = 0
932
- start_y += stride_y
933
-
934
- javabridge.kill_vm()
935
-
936
- if count_marker_thresh == 0:
937
- count_marker_thresh = 1
938
- if count_size_thresh == 0:
939
- count_size_thresh = 1
940
- data['settings']['default_marker_thresh'] = round(default_marker_thresh / count_marker_thresh)
941
- data['settings']['default_size_thresh'] = round(default_size_thresh / count_size_thresh)
942
-
943
- return data
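- 
- # Hypothetical usage (slide path is illustrative):
- #   data = infer_cells_for_wsi('slide.svs', './model_files/DeepLIIF_Latest_Model',
- #                              tile_size=512, print_log=True)
- #   print(len(data['cells']), data['settings']['default_size_thresh'])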