deepliif 1.2.0__py3-none-any.whl → 1.2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- deepliif/postprocessing.py +41 -13
- deepliif/util/__init__.py +6 -4
- {deepliif-1.2.0.dist-info → deepliif-1.2.2.dist-info}/METADATA +3 -3
- {deepliif-1.2.0.dist-info → deepliif-1.2.2.dist-info}/RECORD +8 -17
- deepliif/models/__init__ - different weighted.py +0 -762
- deepliif/models/__init__ - multiprocessing (failure).py +0 -980
- deepliif/models/__init__ - run_dask_multi dev.py +0 -943
- deepliif/models/__init__ - time gens.py +0 -792
- deepliif/models/__init__ - timings.py +0 -764
- deepliif/models/__init__ - weights, empty, zarr, tile count.py +0 -792
- deepliif/postprocessing__OLD__DELETE.py +0 -440
- deepliif/train.py +0 -280
- deepliif/util/util - modified tensor_to_pil.py +0 -255
- {deepliif-1.2.0.dist-info → deepliif-1.2.2.dist-info}/LICENSE.md +0 -0
- {deepliif-1.2.0.dist-info → deepliif-1.2.2.dist-info}/WHEEL +0 -0
- {deepliif-1.2.0.dist-info → deepliif-1.2.2.dist-info}/entry_points.txt +0 -0
- {deepliif-1.2.0.dist-info → deepliif-1.2.2.dist-info}/top_level.txt +0 -0
|
@@ -1,764 +0,0 @@
|
|
|
1
|
-
"""This package contains modules related to objective functions, optimizations, and network architectures.
|
|
2
|
-
|
|
3
|
-
To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel.
|
|
4
|
-
You need to implement the following five functions:
|
|
5
|
-
-- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
|
|
6
|
-
-- <set_input>: unpack data from dataset and apply preprocessing.
|
|
7
|
-
-- <forward>: produce intermediate results.
|
|
8
|
-
-- <optimize_parameters>: calculate loss, gradients, and update network weights.
|
|
9
|
-
-- <modify_commandline_options>: (optionally) add model-specific options and set default options.
|
|
10
|
-
|
|
11
|
-
In the function <__init__>, you need to define four lists:
|
|
12
|
-
-- self.loss_names (str list): specify the training losses that you want to plot and save.
|
|
13
|
-
-- self.model_names (str list): define networks used in our training.
|
|
14
|
-
-- self.visual_names (str list): specify the images that you want to display and save.
|
|
15
|
-
-- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See CycleGAN_model.py for a usage example.
|
|
16
|
-
|
|
17
|
-
Now you can use the model class by specifying flag '--model dummy'.
|
|
18
|
-
See our template model class 'template_model.py' for more details.
|
|
19
|
-
"""
|
|
20
|
-
import base64
|
|
21
|
-
import os
|
|
22
|
-
import itertools
|
|
23
|
-
import importlib
|
|
24
|
-
from functools import lru_cache
|
|
25
|
-
from io import BytesIO
|
|
26
|
-
import json
|
|
27
|
-
import math
|
|
28
|
-
|
|
29
|
-
import time
|
|
30
|
-
time_init = 0
|
|
31
|
-
time_empty = 0
|
|
32
|
-
time_transform = 0
|
|
33
|
-
time_compute = 0
|
|
34
|
-
|
|
35
|
-
import requests
|
|
36
|
-
import torch
|
|
37
|
-
from PIL import Image
|
|
38
|
-
Image.MAX_IMAGE_PIXELS = None
|
|
39
|
-
|
|
40
|
-
import numpy as np
|
|
41
|
-
from dask import delayed, compute
|
|
42
|
-
import openslide
|
|
43
|
-
|
|
44
|
-
from deepliif.util import *
|
|
45
|
-
from deepliif.util.util import tensor_to_pil
|
|
46
|
-
from deepliif.data import transform
|
|
47
|
-
from deepliif.postprocessing import compute_final_results, compute_cell_results
|
|
48
|
-
from deepliif.postprocessing import encode_cell_data_v4, decode_cell_data_v4
|
|
49
|
-
from deepliif.options import Options, print_options
|
|
50
|
-
|
|
51
|
-
from .base_model import BaseModel
|
|
52
|
-
|
|
53
|
-
# import for init purpose, not used in this script
|
|
54
|
-
from .DeepLIIF_model import DeepLIIFModel
|
|
55
|
-
from .DeepLIIFExt_model import DeepLIIFExtModel
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
@lru_cache
def get_opt(model_dir, mode='test'):
    """Load the stored option file of a model directory.

    Parameters
    ----------
    model_dir : str
        Directory expected to contain ``train_opt.txt`` and, optionally,
        ``test_opt.txt``.
    mode : str
        ``'train'`` or ``'test'``. Only inference code currently calls this
        function, hence the ``'test'`` default.

    Returns
    -------
    Options
        The parsed options. In ``'test'`` mode ``use_dp`` is forced to
        ``False`` and ``gpu_ids`` is reset to every CUDA device visible to
        torch.

    Raises
    ------
    ValueError
        If ``mode`` is neither ``'train'`` nor ``'test'`` (previously this
        surfaced as an ``UnboundLocalError``).

    Notes
    -----
    The result is memoized by ``lru_cache``: repeated calls with the same
    arguments return the *same* object, so mutating it affects every caller.
    """
    if mode not in ('train', 'test'):
        raise ValueError(f"get_opt() expects mode 'train' or 'test', got {mode!r}")

    train_opt_path = os.path.join(model_dir, 'train_opt.txt')
    if mode == 'train':
        return Options(path_file=train_opt_path, mode=mode)

    try:
        opt = Options(path_file=os.path.join(model_dir, 'test_opt.txt'), mode=mode)
    except Exception:
        # Best-effort fallback: when test_opt.txt cannot be loaded, reuse the
        # training options. Catching Exception (not a bare except) keeps
        # KeyboardInterrupt / SystemExit propagating.
        opt = Options(path_file=train_opt_path, mode=mode)
    opt.use_dp = False
    opt.gpu_ids = list(range(torch.cuda.device_count()))
    return opt
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
def find_model_using_name(model_name):
    """Import ``deepliif.models.<model_name>_model`` and return its model class.

    The class is located by name: ``model_name`` with underscores removed plus
    the suffix ``model``, compared case-insensitively, and it must be a
    subclass of BaseModel. If several attributes match, the last one wins.
    When nothing matches, a message is printed and the process exits via
    ``exit(0)``.
    """
    module_path = "deepliif.models." + model_name + "_model"
    module = importlib.import_module(module_path)

    target_model_name = model_name.replace('_', '') + 'model'
    wanted = target_model_name.lower()

    model = None
    for attr_name, attr in vars(module).items():
        if attr_name.lower() != wanted:
            continue
        if issubclass(attr, BaseModel):
            model = attr

    if model is None:
        print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (
            module_path, target_model_name))
        exit(0)

    return model
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
def get_option_setter(model_name):
    """Look up the model class for ``model_name`` and return its
    ``modify_commandline_options`` attribute."""
    return find_model_using_name(model_name).modify_commandline_options
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
def create_model(opt):
    """Instantiate the model class selected by ``opt.model``.

    The class is resolved with ``find_model_using_name`` and constructed with
    ``opt`` as its only argument; the class name of the new instance is
    printed before it is returned.

    Example:
        >>> from deepliif.models import create_model
        >>> model = create_model(opt)
    """
    model_class = find_model_using_name(opt.model)
    created = model_class(opt)
    print("model [%s] was created" % type(created).__name__)
    return created
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
def load_torchscript_model(model_pt_path, device):
    """Load a TorchScript file onto ``device`` and prepare it for inference.

    The loaded module is passed through ``disable_batchnorm_tracking_stats``
    and then switched to eval mode before being returned.
    """
    scripted = disable_batchnorm_tracking_stats(
        torch.jit.load(model_pt_path, map_location=device)
    )
    scripted.eval()
    return scripted
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
def load_eager_models(opt, devices=None):
    """Create the eager (non-TorchScript) model and return its networks by name.

    Parameters
    ----------
    opt : options object
        Passed to ``create_model`` and ``model.setup``; ``opt.phase`` decides
        whether the networks are put into eval mode.
    devices : dict, optional
        Maps network name -> torch device. When given (and non-empty) its keys
        select which networks are returned and each network is moved to the
        mapped device; otherwise ``model.model_names`` is used and no network
        is moved.

    Returns
    -------
    dict
        Network name -> network module.
    """
    model = create_model(opt)
    # Regular setup: load and print networks; create schedulers.
    model.setup(opt)

    model_names = list(devices.keys()) if devices else model.model_names

    nets = {}
    for name in model_names:
        if not isinstance(name, str):
            continue

        if '_' in name:
            # Names like "G_2" index into a list attribute: model.netG[1].
            parts = name.split('_')
            net = getattr(model, 'net' + parts[0])[int(parts[-1]) - 1]
        else:
            net = getattr(model, 'net' + name)

        if opt.phase != 'train':
            net.eval()
            net = disable_batchnorm_tracking_stats(net)

        # SDG models are still wrapped in DataParallel when loaded; unwrap them.
        if isinstance(net, torch.nn.DataParallel):
            net = net.module

        nets[name] = net
        if devices:
            net.to(devices[name])

    return nets
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
@lru_cache
def init_nets(model_dir, eager_mode=False, opt=None, phase='test'):
    """Load the networks of a model, placing each net group on one device.

    Networks listed in the same group always share a device. With at least
    one usable GPU the groups are split with ``chunker`` and each chunk is
    assigned the GPU id at the same position in ``opt.gpu_ids``; otherwise
    everything goes to the CPU.

    Parameters
    ----------
    model_dir : str
        Directory with the option files and the serialized ``<net>.pt`` files.
    eager_mode : bool
        If true, build the networks through ``load_eager_models`` instead of
        loading TorchScript files.
    opt : options object, optional
        When omitted, options are read with ``get_opt(model_dir, mode=phase)``
        and ``use_dp`` is set to ``False``. Because this function is wrapped
        in ``lru_cache``, every argument (including ``opt``) must be hashable.
    phase : str
        Mode forwarded to ``get_opt`` when ``opt`` is omitted.

    Returns
    -------
    dict
        Network name -> loaded network.

    Raises
    ------
    Exception
        If ``opt.model`` is not DeepLIIF, DeepLIIFExt, SDG or CycleGAN.
    """
    if opt is None:
        opt = get_opt(model_dir, mode=phase)
        opt.use_dp = False

    model_kind = opt.model
    if model_kind == 'DeepLIIF':
        net_groups = [
            ('G1', 'G52'),
            ('G2', 'G53'),
            ('G3', 'G54'),
            ('G4', 'G55'),
            ('G51',),
        ]
    elif model_kind in ['DeepLIIFExt', 'SDG']:
        if opt.seg_gen:
            net_groups = [(f'G_{k}', f'GS_{k}') for k in range(1, opt.modalities_no + 1)]
        else:
            net_groups = [(f'G_{k}',) for k in range(1, opt.modalities_no + 1)]
    elif model_kind == 'CycleGAN':
        if opt.BtoA:
            net_groups = [(f'GB_{k}',) for k in range(1, opt.modalities_no + 1)]
        else:
            net_groups = [(f'GA_{k}',) for k in range(1, opt.modalities_no + 1)]
    else:
        raise Exception(f'init_nets() not implemented for model {opt.model}')

    available_gpus = torch.cuda.device_count()
    usable_gpus = min(len(opt.gpu_ids), available_gpus)

    devices = {}
    if usable_gpus > 0:
        # Position of a chunk -> configured GPU id at that position.
        gpu_id_at = dict(enumerate(opt.gpu_ids))
        for position, group_chunk in enumerate(chunker(net_groups, usable_gpus)):
            for net_name in itertools.chain.from_iterable(group_chunk):
                devices[net_name] = torch.device(f'cuda:{gpu_id_at[position]}')
    else:
        for net_name in itertools.chain.from_iterable(net_groups):
            devices[net_name] = torch.device('cpu')

    if eager_mode:
        return load_eager_models(opt, devices)

    nets = {}
    for net_name, device in devices.items():
        nets[net_name] = load_torchscript_model(
            os.path.join(model_dir, f'{net_name}.pt'), device=device
        )
    return nets
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
def compute_overlap(img_size, tile_size):
    """Return the tile overlap to use for an image.

    ``img_size`` is a ``(width, height)`` pair. When both dimensions round to
    exactly one tile, no overlap is needed and 0 is returned; otherwise the
    overlap is a quarter of ``tile_size`` (integer division).
    """
    w, h = img_size
    fits_single_tile = all(round(side / tile_size) == 1 for side in (w, h))
    return 0 if fits_single_tile else tile_size // 4
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
def run_torchserve(img, model_path=None, eager_mode=False, opt=None, seg_only=False):
    """Run inference on one image through a TorchServe workflow endpoint.

    The image is resized to ``opt.scale_size`` squared, transformed,
    serialized with ``torch.save`` and posted base64-encoded to
    ``<TORCHSERVE_HOST>/wfpredict/deepliif`` (host defaults to
    ``http://localhost``). Every entry of the JSON response is decoded back
    into a tensor and converted to a PIL image.

    ``model_path``, ``eager_mode`` and ``seg_only`` are not used here; they
    exist so that run_wrapper() can call this function and run_dask() with
    the same syntax. ``opt`` is required despite its ``None`` default because
    ``opt.scale_size`` is read.

    Raises whatever ``requests`` raises for a failed request, including the
    HTTP error from ``raise_for_status()``.
    """
    payload = BytesIO()
    resized = img.resize((opt.scale_size, opt.scale_size))
    torch.save(transform(resized), payload)

    host = os.getenv('TORCHSERVE_HOST', 'http://localhost')
    encoded = base64.b64encode(payload.getvalue()).decode('utf-8')
    response = requests.post(f'{host}/wfpredict/deepliif', json={'img': encoded})
    response.raise_for_status()

    images = {}
    for key, blob in response.json().items():
        raw = base64.b64decode(blob.encode())
        # NOTE(review): torch.load unpickles the response body; this is only
        # safe if the TorchServe host is trusted.
        tensor = torch.load(BytesIO(raw), map_location=torch.device('cpu'))
        images[key] = tensor_to_pil(tensor)
    return images
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
def run_dask(img, model_path, eager_mode=False, opt=None, seg_only=False):
    """Run the networks of a model on one tile and return PIL images by net name.

    Networks come from ``init_nets`` (model directory taken from the
    ``DEEPLIIF_MODEL_DIR`` environment variable, falling back to
    ``model_path``). Forward passes are scheduled through dask ``delayed``
    unless ``opt.norm == 'spectral'``, in which case they run directly.

    For DeepLIIF the result holds the generated modalities, the per-modality
    segmentations and their weighted sum under ``'G5'``; with ``seg_only``
    only ``'G4'`` (when computed) and ``'G5'`` are returned. For DeepLIIFExt,
    SDG and CycleGAN the result holds the generator outputs plus, when
    ``opt.seg_gen`` is set, the segmentation outputs.

    Side effects: adds elapsed wall time to the module-level counters
    ``time_init``, ``time_transform`` and (DeepLIIF only) ``time_compute``.

    Raises ``Exception`` for any other ``opt.model``.
    """
    global time_init, time_transform, time_compute

    model_dir = os.getenv('DEEPLIIF_MODEL_DIR', model_path)

    started = time.time()
    nets = init_nets(model_dir, eager_mode, opt)
    time_init += time.time() - started

    # Some train settings (spectral norm) are not compatible with dask at
    # inference time, so dask is bypassed for them.
    use_dask = opt.norm != 'spectral'

    def to_tensor(pil_img):
        return transform(pil_img.resize((opt.scale_size, opt.scale_size)))

    started = time.time()
    if opt.input_no > 1 or opt.model == 'SDG':
        # Multiple input images: transform each and concatenate along dim 1.
        ts = torch.cat([to_tensor(part) for part in img], dim=1)
    else:
        ts = to_tensor(img)
    time_transform += time.time() - started

    def forward(tensor, model):
        # Run the net on whichever device its parameters live on.
        with torch.no_grad():
            return model(tensor.to(next(model.parameters()).device))

    if use_dask:
        forward = delayed(forward)

    cpu = torch.device('cpu')

    if opt.model == 'DeepLIIF':
        started = time.time()
        weights = {
            'G51': 0.25,  # IHC
            'G52': 0.25,  # Hema
            'G53': 0.25,  # DAPI
            'G54': 0.00,  # Lap2
            'G55': 0.25,  # Marker
        }

        # Modality generator -> segmentation net consuming its output.
        seg_map = {'G1': 'G52', 'G2': 'G53', 'G3': 'G54', 'G4': 'G55'}
        if seg_only:
            # Skip branches whose segmentation does not contribute to the sum.
            seg_map = {gen: seg for gen, seg in seg_map.items() if weights[seg] != 0}

        pending_gens = {}
        for gen_name in seg_map:
            pending_gens[gen_name] = forward(ts, nets[gen_name])
        if 'G4' not in seg_map:
            pending_gens['G4'] = forward(ts, nets['G4'])
        gens = compute(pending_gens)[0]

        pending_segs = {}
        for gen_name, seg_name in seg_map.items():
            pending_segs[seg_name] = forward(gens[gen_name], nets[seg_name]).to(cpu)
        if not seg_only or weights['G51'] != 0:
            pending_segs['G51'] = forward(ts, nets['G51']).to(cpu)
        segs = compute(pending_segs)[0]

        weighted = [torch.mul(segs[name], weights[name]) for name in segs.keys()]
        seg = torch.stack(weighted).sum(dim=0)

        if seg_only:
            res = {'G4': tensor_to_pil(gens['G4'])} if 'G4' in gens else {}
        else:
            res = {name: tensor_to_pil(out) for name, out in gens.items()}
            res.update({name: tensor_to_pil(out) for name, out in segs.items()})
        res['G5'] = tensor_to_pil(seg)
        time_compute += time.time() - started

        return res

    if opt.model in ['DeepLIIFExt', 'SDG', 'CycleGAN']:
        if opt.model == 'CycleGAN':
            prefix = 'GB' if opt.BtoA else 'GA'
            seg_map = {f'{prefix}_{k}': None for k in range(1, opt.modalities_no + 1)}
        else:
            seg_map = {'G_' + str(k): 'GS_' + str(k) for k in range(1, opt.modalities_no + 1)}

        gens = {gen_name: forward(ts, nets[gen_name]) for gen_name in seg_map}
        if use_dask:
            gens = compute(gens)[0]

        res = {name: tensor_to_pil(out) for name, out in gens.items()}

        if opt.seg_gen:
            def seg_input(gen_name):
                # Input, first generator output and this generator's output,
                # all on CPU, concatenated along dim 1.
                first_gen = gens[next(iter(seg_map))]
                return torch.cat([ts.to(cpu), first_gen.to(cpu), gens[gen_name].to(cpu)], 1)

            segs = {
                seg_name: forward(seg_input(gen_name), nets[seg_name]).to(cpu)
                for gen_name, seg_name in seg_map.items()
            }
            if use_dask:
                segs = compute(segs)[0]
            res.update({name: tensor_to_pil(out) for name, out in segs.items()})

        return res

    raise Exception(f'run_dask() not fully implemented for {opt.model}')
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
def is_empty(tile):
    """Tell whether a tile (or every tile of a list) is below the variance threshold.

    A tile counts as empty when the maximum of ``image_variance_rgb(tile)`` is
    below 15. For a list of tiles the result is true only if ALL of them are
    empty, so a prediction is skipped only when no tile carries content.
    """
    thresh = 15

    def below_threshold(single_tile):
        return bool(np.max(image_variance_rgb(single_tile)) < thresh)

    if isinstance(tile, list):
        return all(below_threshold(t) for t in tile)
    return below_threshold(tile)
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
def run_wrapper(tile, run_fn, model_path, eager_mode=False, opt=None, seg_only=False):
    """Run inference on a tile, short-circuiting empty tiles with blank images.

    Parameters
    ----------
    tile : image or list of images
        Tile(s) to process; emptiness is decided by ``is_empty``.
    run_fn : callable
        Inference function (run_dask or run_torchserve), called as
        ``run_fn(tile, model_path, eager_mode, opt[, seg_only])``.
    model_path, eager_mode : forwarded unchanged to ``run_fn``.
    opt : options object
        Required; ``opt.model`` selects the branch, ``opt.modalities_no`` and
        ``opt.BtoA`` are read where relevant.
    seg_only : bool
        DeepLIIF only: restrict the outputs to ``'G4'`` and ``'G5'``.

    Returns
    -------
    dict
        Net name -> PIL image. For an empty tile these are fixed-color
        512x512 RGB placeholders keyed like the real outputs.

    Raises
    ------
    Exception
        If ``opt.model`` is not DeepLIIF, DeepLIIFExt, SDG or CycleGAN.

    Side effects: for DeepLIIF, the time spent in ``is_empty`` is added to the
    module-level counter ``time_empty``.
    """
    global time_empty

    def blank(color=(0, 0, 0)):
        return Image.new(mode='RGB', size=(512, 512), color=color)

    if opt.model == 'DeepLIIF':
        tstart = time.time()
        empty = is_empty(tile)
        time_empty += time.time() - tstart

        if not empty:
            return run_fn(tile, model_path, eager_mode, opt, seg_only)
        if seg_only:
            return {
                'G4': blank((10, 10, 10)),
                'G5': blank(),
            }
        return {
            'G1': blank((201, 211, 208)),
            'G2': blank((10, 10, 10)),
            'G3': blank(),
            'G4': blank((10, 10, 10)),
            'G5': blank(),
            'G51': blank(),
            'G52': blank(),
            'G53': blank(),
            'G54': blank(),
            'G55': blank(),
        }

    if opt.model in ['DeepLIIFExt', 'SDG']:
        if not is_empty(tile):
            return run_fn(tile, model_path, eager_mode, opt)
        res = {'G_' + str(i): blank() for i in range(1, opt.modalities_no + 1)}
        res.update({'GS_' + str(i): blank() for i in range(1, opt.modalities_no + 1)})
        return res

    if opt.model in ['CycleGAN']:
        if not is_empty(tile):
            return run_fn(tile, model_path, eager_mode, opt)
        # Fix: the BtoA names were built from a plain string missing its
        # f-prefix, so every empty tile came back under the literal key
        # 'GB_{i+1}' instead of GB_1..GB_n.
        prefix = 'GB' if opt.BtoA else 'GA'
        return {f'{prefix}_{i}': blank() for i in range(1, opt.modalities_no + 1)}

    raise Exception(f'run_wrapper() not implemented for model {opt.model}')
|
|
398
|
-
|
|
399
|
-
|
|
400
|
-
def inference(img, tile_size, overlap_size, model_path, use_torchserve=False,
              eager_mode=False, color_dapi=False, color_marker=False, opt=None,
              return_seg_intermediate=False, seg_only=False):
    """Tile an image, run the model on every tile and stitch the outputs.

    Parameters
    ----------
    img : PIL image
        Input image. For SDG it is split horizontally into ``opt.input_no``
        equally wide images, one per input modality.
    tile_size, overlap_size : int
        Forwarded to ``InferenceTiler``.
    model_path : str
        Model directory; also used to load ``opt`` when it is not supplied.
    use_torchserve : bool
        Use run_torchserve instead of run_dask for each tile.
    eager_mode, seg_only : forwarded to the per-tile run function.
    color_dapi, color_marker : bool
        DeepLIIF only (ignored with ``seg_only``): recolor the DAPI / Marker
        image through a fixed RGB conversion matrix.
    return_seg_intermediate : bool
        DeepLIIF only (ignored with ``seg_only``): also return the
        per-modality segmentation images.

    Returns
    -------
    dict
        Image name -> PIL image. Names are remapped for DeepLIIF, DeepLIIFExt
        and SDG; any other model gets the stitched results under their net
        names.
    """
    if not opt:
        opt = get_opt(model_path)

    run_fn = run_torchserve if use_torchserve else run_dask

    if opt.model == 'SDG':
        # SDG may take several input modalities laid out side by side, so the
        # input can be a rectangle: cut it into one image per modality.
        w, h = int(img.width / opt.input_no), img.height
        orig = [img.crop((w * i, 0, w * (i + 1), h)) for i in range(opt.input_no)]
    else:
        # Otherwise a single input image is expected and used directly.
        orig = img

    tiler = InferenceTiler(orig, tile_size, overlap_size)
    for tile in tiler:
        tile_outputs = run_wrapper(tile, run_fn, model_path, eager_mode, opt, seg_only)
        tiler.stitch(tile_outputs)
    results = tiler.results()

    if opt.model == 'DeepLIIF':
        if seg_only:
            images = {'Seg': results['G5']}
            if 'G4' in results:
                images['Marker'] = results['G4']
        else:
            images = {
                label: results[net]
                for label, net in (('Hema', 'G1'), ('DAPI', 'G2'), ('Lap2', 'G3'),
                                   ('Marker', 'G4'), ('Seg', 'G5'))
            }

        full_output = not seg_only
        if return_seg_intermediate and full_output:
            for label, net in (('IHC_s', 'G51'), ('Hema_s', 'G52'), ('DAPI_s', 'G53'),
                               ('Lap2_s', 'G54'), ('Marker_s', 'G55')):
                images[label] = results[net]

        # One row of the 3x4 conversion matrix: weighted sum of R, G, B.
        luma_row = (299/1000, 587/1000, 114/1000, 0)
        zero_row = (0, 0, 0, 0)
        if color_dapi and full_output:
            # Zero the red channel; put the weighted sum into green and blue.
            images['DAPI'] = images['DAPI'].convert('RGB', zero_row + luma_row + luma_row)
        if color_marker and full_output:
            # Put the weighted sum into red and green; zero the blue channel.
            images['Marker'] = images['Marker'].convert('RGB', luma_row + luma_row + zero_row)
        return images

    if opt.model == 'DeepLIIFExt':
        images = {f'mod{i}': results[f'G_{i}'] for i in range(1, opt.modalities_no + 1)}
        if opt.seg_gen:
            images.update({f'Seg{i}': results[f'GS_{i}'] for i in range(1, opt.modalities_no + 1)})
        return images

    if opt.model == 'SDG':
        return {f'mod{i}': results[f'G_{i}'] for i in range(1, opt.modalities_no + 1)}

    # Any other model: return the result images under their net names.
    return results
|
|
469
|
-
|
|
470
|
-
|
|
471
|
-
def postprocess(orig, images, tile_size, model, seg_thresh=150, size_thresh='default', marker_thresh=None, size_thresh_upper=None):
    """Turn segmentation images into overlay / refined images plus scoring.

    A resolution label (``'40x'``, ``'20x'`` or ``'10x'``) is derived from
    ``tile_size`` and handed to ``compute_final_results`` together with the
    thresholds.

    Returns
    -------
    (dict, dict)
        For DeepLIIF: ``{'SegOverlaid', 'SegRefined'}`` images and the scoring
        returned by ``compute_final_results``. For DeepLIIFExt / SDG: one
        ``<name>_Overlaid`` / ``<name>_Refined`` pair and one scoring entry
        per image whose name contains ``'Seg'``.

    Raises
    ------
    Exception
        For any other ``model``.
    """
    def pick_resolution(limit_40x, limit_20x):
        if tile_size > limit_40x:
            return '40x'
        return '20x' if tile_size > limit_20x else '10x'

    if model == 'DeepLIIF':
        resolution = pick_resolution(384, 192)
        overlay, refined, scoring = compute_final_results(
            orig, images['Seg'], images.get('Marker'), resolution,
            size_thresh, marker_thresh, size_thresh_upper, seg_thresh)
        processed_images = {
            'SegOverlaid': Image.fromarray(overlay),
            'SegRefined': Image.fromarray(refined),
        }
        return processed_images, scoring

    if model in ['DeepLIIFExt', 'SDG']:
        resolution = pick_resolution(768, 384)
        processed_images = {}
        scoring = {}
        for img_name in list(images.keys()):
            if 'Seg' not in img_name:
                continue
            overlay, refined, score = compute_final_results(
                orig, images[img_name], None, resolution,
                size_thresh, marker_thresh, size_thresh_upper, seg_thresh)
            processed_images[img_name + '_Overlaid'] = Image.fromarray(overlay)
            processed_images[img_name + '_Refined'] = Image.fromarray(refined)
            scoring[img_name] = score
        return processed_images, scoring

    raise Exception(f'postprocess() not implemented for model {model}')
|
|
500
|
-
|
|
501
|
-
|
|
502
|
-
def infer_modalities(img, tile_size, model_dir, eager_mode=False,
                     color_dapi=False, color_marker=False, opt=None,
                     return_seg_intermediate=False):
    """Infer the modalities of an image with a trained model, then postprocess.

    :param img: The input image.
    :param tile_size: The tile size; the tile overlap is ``tile_size // 16``.
    :param model_dir: The directory containing serialized model files.
    :param eager_mode: Forwarded to ``inference``.
    :param color_dapi: Forwarded to ``inference``.
    :param color_marker: Forwarded to ``inference``.
    :param opt: Options object; loaded from ``model_dir`` (with ``use_dp``
        disabled) when omitted.
    :param return_seg_intermediate: Forwarded to ``inference``.
    :return: ``(images, scoring)``. Postprocessing runs when ``opt`` has no
        ``seg_gen`` attribute or ``seg_gen`` is truthy; otherwise ``scoring``
        is ``None`` and only the inferred images are returned.
    """
    if opt is None:
        opt = get_opt(model_dir)
        opt.use_dp = False

    # For models with several input modalities this is the size of a single
    # input; it is only needed by the currently disabled compute_overlap() call.
    input_no = getattr(opt, 'input_no', 1)
    img_size = (img.size[0] / input_no, img.size[1])  # (width, height)

    images = inference(
        img,
        tile_size=tile_size,
        overlap_size=tile_size // 16,
        model_path=model_dir,
        eager_mode=eager_mode,
        color_dapi=color_dapi,
        color_marker=color_marker,
        opt=opt,
        return_seg_intermediate=return_seg_intermediate,
    )

    # A missing seg_gen attribute covers old DeepLIIF settings; an explicit
    # flag is what DeepLIIFExt models carry.
    if not getattr(opt, 'seg_gen', True):
        return images, None

    post_images, scoring = postprocess(img, images, tile_size, opt.model)
    return {**images, **post_images}, scoring
|
|
540
|
-
|
|
541
|
-
|
|
542
|
-
def infer_results_for_wsi(input_dir, filename, output_dir, model_dir, tile_size, region_size=20000):
    """
    Infer modalities and the segmentation mask for a whole slide image (WSI).

    The slide is read region by region; each region goes through
    ``infer_modalities`` and is copied into a full-size array per output
    name. Each array is then written as ``<basename>_<name>.ome.tiff`` under
    ``<output_dir>/<basename>``, and the accumulated positive / negative
    counts are written to ``<basename>.json`` when scoring is available.
    The Java VM is shut down at the end.

    :param input_dir: The directory containing the WSI.
    :param filename: The WSI name.
    :param output_dir: The directory for saving the inferred modalities.
    :param model_dir: The directory containing the serialized model files.
    :param tile_size: The tile size.
    :param region_size: The size of each individual region to be processed at once.
    :return: None
    """
    basename = os.path.splitext(filename)[0]
    results_dir = os.path.join(output_dir, basename)
    if not os.path.exists(results_dir):
        os.makedirs(results_dir)

    slide_path = os.path.join(input_dir, filename)
    size_x, size_y, size_z, size_c, size_t, pixel_type = get_information(slide_path)
    rescale = pixel_type != 'uint8'
    print(filename, size_x, size_y, size_z, size_c, size_t, pixel_type)

    results = {}
    scoring = None

    # javabridge is already set up by the get_information() call above.
    with bioformats.ImageReader(slide_path) as reader:
        start_x = 0
        while start_x < size_x:
            start_y = 0
            while start_y < size_y:
                print(start_x, start_y)
                width = min(region_size, size_x - start_x)
                height = min(region_size, size_y - start_y)
                region = reader.read(XYWH=(start_x, start_y, width, height), rescale=rescale)
                if rescale:
                    img = Image.fromarray((region * 255).astype(np.uint8))
                else:
                    img = Image.fromarray(region)

                region_modalities, region_scoring = infer_modalities(img, tile_size, model_dir)
                if region_scoring is not None:
                    if scoring is None:
                        scoring = {
                            'num_pos': region_scoring['num_pos'],
                            'num_neg': region_scoring['num_neg'],
                        }
                    else:
                        for key in ('num_pos', 'num_neg'):
                            scoring[key] += region_scoring[key]

                for name, modality_img in region_modalities.items():
                    if name not in results:
                        results[name] = np.zeros((size_y, size_x, 3), dtype=np.uint8)
                    results[name][start_y:start_y + height,
                                  start_x:start_x + width] = np.array(modality_img)

                start_y += region_size
            start_x += region_size

    for name, full_image in results.items():
        write_big_tiff_file(os.path.join(results_dir, f'{basename}_{name}.ome.tiff'), full_image, tile_size)

    if scoring is not None:
        scoring['num_total'] = scoring['num_pos'] + scoring['num_neg']
        if scoring['num_pos'] > 0:
            scoring['percent_pos'] = round(scoring['num_pos'] / scoring['num_total'] * 100, 1)
        else:
            scoring['percent_pos'] = 0
        with open(os.path.join(results_dir, f'{basename}.json'), 'w') as f:
            json.dump(scoring, f, indent=2)

    javabridge.kill_vm()
|
|
609
|
-
|
|
610
|
-
|
|
611
|
-
def get_wsi_resolution(filename):
    """
    Use OpenSlide to get the resolution (magnification) of the slide
    and the corresponding tile size to use by default for DeepLIIF.
    If it cannot be found, return (None, None) instead.

    Parameters
    ----------
    filename : str
        Full path to the file.

    Returns
    -------
    str :
        Magnification (objective power) as found by OpenSlide.
    int :
        Corresponding tile size for DeepLIIF (512 at 40x, scaled linearly).
    """
    try:
        # The context manager closes the slide handle, which was previously
        # left open on every call.
        with openslide.OpenSlide(filename) as image:
            mag = image.properties.get(openslide.PROPERTY_NAME_OBJECTIVE_POWER)
        # A missing property gives mag=None; float(None) raises and is
        # reported as "not found" below.
        tile_size = round((float(mag) / 40) * 512)
        return mag, tile_size
    except Exception:
        # Deliberate best-effort: any failure to open or parse means the
        # resolution is unknown.
        return None, None
|
|
636
|
-
|
|
637
|
-
|
|
638
|
-
def infer_cells_for_wsi(filename, model_dir, tile_size, region_size=20000, version=3, print_log=False):
    """
    Perform inference on a slide and return the individual cell data.

    The slide is cut into a grid of equally sized regions no larger than
    ``region_size``. Each region is run through ``inference`` in seg-only
    mode and ``compute_cell_results``; cell coordinates are shifted from
    region to slide coordinates and the cells of all regions are merged.
    The default marker / size thresholds are the rounded mean of the
    non-zero per-region values. Timing totals are always printed at the end,
    and the Java VM is shut down.

    Parameters
    ----------
    filename : str
        Full path to the file.
    model_dir : str
        Full path to the directory with the DeepLIIF model files.
    tile_size : int
        Size of tiles to extract and perform inference on.
    region_size : int
        Maximum size to split the slide for processing.
    version : int
        Version of cell data to return (3 or 4).
    print_log : bool
        Whether or not to print updates while processing.

    Returns
    -------
    dict :
        Individual cell data and associated values.
    """
    global time_empty

    def print_info(*args):
        if print_log:
            print(*args, flush=True)

    if tile_size > 384:
        resolution = '40x'
    elif tile_size > 192:
        resolution = '20x'
    else:
        resolution = '10x'

    size_x, size_y, size_z, size_c, size_t, pixel_type = get_information(filename)
    rescale = pixel_type != 'uint8'
    print_info('Info:', size_x, size_y, size_z, size_c, size_t, pixel_type)

    # Even-sized strides so that all regions have (almost) the same size.
    num_regions_x = math.ceil(size_x / region_size)
    num_regions_y = math.ceil(size_y / region_size)
    stride_x = math.ceil(size_x / num_regions_x)
    stride_y = math.ceil(size_y / num_regions_y)
    print_info('Strides:', stride_x, stride_y)

    data = None
    default_marker_thresh, count_marker_thresh = 0, 0
    default_size_thresh, count_size_thresh = 0, 0

    time_inference = 0
    time_postprocess = 0

    # javabridge is already set up by the get_information() call above.
    with bioformats.ImageReader(filename) as reader:
        start_y = 0
        while start_y < size_y:
            start_x = 0
            while start_x < size_x:
                region_XYWH = (start_x, start_y,
                               min(stride_x, size_x - start_x),
                               min(stride_y, size_y - start_y))
                print_info('Region:', region_XYWH)

                region = reader.read(XYWH=region_XYWH, rescale=rescale)
                print_info(region.shape, region.dtype)
                if rescale:
                    img = Image.fromarray((region * 255).astype(np.uint8))
                else:
                    img = Image.fromarray(region)
                print_info(img.size, img.mode)

                started = time.time()
                images = inference(
                    img,
                    tile_size=tile_size,
                    overlap_size=tile_size // 16,
                    model_path=model_dir,
                    eager_mode=False,
                    color_dapi=False,
                    color_marker=False,
                    opt=None,
                    return_seg_intermediate=False,
                    seg_only=True,
                )
                time_inference += time.time() - started

                started = time.time()
                region_data = compute_cell_results(images['Seg'], images.get('Marker'), resolution, version=version)
                time_postprocess += time.time() - started

                if start_x != 0 or start_y != 0:
                    # Move cell coordinates from region space to slide space.
                    def shift(point):
                        return (point[0] + start_x, point[1] + start_y)

                    cells = region_data['cells']
                    for i, stored in enumerate(cells):
                        cell = decode_cell_data_v4(stored) if version == 4 else stored
                        for j in range(2):
                            cell['bbox'][j] = shift(cell['bbox'][j])
                        cell['centroid'] = shift(cell['centroid'])
                        for j, point in enumerate(cell['boundary']):
                            cell['boundary'][j] = shift(point)
                        cells[i] = encode_cell_data_v4(cell) if version == 4 else cell

                if data is None:
                    data = region_data
                else:
                    data['cells'] += region_data['cells']

                settings = region_data['settings']
                region_marker_thresh = settings['default_marker_thresh']
                if region_marker_thresh is not None and region_marker_thresh != 0:
                    default_marker_thresh += region_marker_thresh
                    count_marker_thresh += 1
                region_size_thresh = settings['default_size_thresh']
                if region_size_thresh != 0:
                    default_size_thresh += region_size_thresh
                    count_size_thresh += 1

                start_x += stride_x
            start_y += stride_y

    javabridge.kill_vm()

    # Avoid dividing by zero when no region reported a threshold.
    count_marker_thresh = max(count_marker_thresh, 1)
    count_size_thresh = max(count_size_thresh, 1)
    data['settings']['default_marker_thresh'] = round(default_marker_thresh / count_marker_thresh)
    data['settings']['default_size_thresh'] = round(default_size_thresh / count_size_thresh)

    timings = (
        ('Time init:', time_init),
        ('Time empty:', time_empty),
        ('Time transform:', time_transform),
        ('Time compute:', time_compute),
        ('Time inference:', time_inference),
        ('Time postprocess:', time_postprocess),
    )
    for label, seconds in timings:
        print(label, round(seconds, 3), flush=True)

    return data
|