deepliif 1.1.6__py3-none-any.whl → 1.1.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cli.py +76 -102
- deepliif/data/aligned_dataset.py +33 -7
- deepliif/models/DeepLIIFExt_model.py +297 -0
- deepliif/models/DeepLIIF_model.py +10 -5
- deepliif/models/__init__.py +262 -168
- deepliif/models/base_model.py +54 -8
- deepliif/options/__init__.py +101 -0
- deepliif/options/base_options.py +7 -6
- deepliif/postprocessing.py +285 -246
- {deepliif-1.1.6.dist-info → deepliif-1.1.8.dist-info}/METADATA +26 -12
- {deepliif-1.1.6.dist-info → deepliif-1.1.8.dist-info}/RECORD +15 -14
- {deepliif-1.1.6.dist-info → deepliif-1.1.8.dist-info}/LICENSE.md +0 -0
- {deepliif-1.1.6.dist-info → deepliif-1.1.8.dist-info}/WHEEL +0 -0
- {deepliif-1.1.6.dist-info → deepliif-1.1.8.dist-info}/entry_points.txt +0 -0
- {deepliif-1.1.6.dist-info → deepliif-1.1.8.dist-info}/top_level.txt +0 -0
deepliif/models/__init__.py
CHANGED
|
@@ -31,15 +31,32 @@ import numpy as np
|
|
|
31
31
|
from dask import delayed, compute
|
|
32
32
|
|
|
33
33
|
from deepliif.util import *
|
|
34
|
-
from deepliif.util.util import tensor_to_pil
|
|
34
|
+
from deepliif.util.util import tensor_to_pil, check_multi_scale
|
|
35
35
|
from deepliif.data import transform
|
|
36
|
-
from deepliif.postprocessing import
|
|
37
|
-
|
|
36
|
+
from deepliif.postprocessing import compute_results
|
|
37
|
+
from deepliif.options import Options, print_options
|
|
38
38
|
|
|
39
39
|
from .base_model import BaseModel
|
|
40
|
+
|
|
41
|
+
# import for init purpose, not used in this script
|
|
40
42
|
from .DeepLIIF_model import DeepLIIFModel
|
|
41
|
-
from .
|
|
43
|
+
from .DeepLIIFExt_model import DeepLIIFExtModel
|
|
44
|
+
|
|
42
45
|
|
|
46
|
+
@lru_cache
|
|
47
|
+
def get_opt(model_dir, mode='test'):
|
|
48
|
+
"""
|
|
49
|
+
mode: test or train, currently only functions used for inference utilize get_opt so it
|
|
50
|
+
defaults to test
|
|
51
|
+
"""
|
|
52
|
+
if mode == 'train':
|
|
53
|
+
opt = Options(path_file=os.path.join(model_dir,'train_opt.txt'), mode=mode)
|
|
54
|
+
elif mode == 'test':
|
|
55
|
+
try:
|
|
56
|
+
opt = Options(path_file=os.path.join(model_dir,'test_opt.txt'), mode=mode)
|
|
57
|
+
except:
|
|
58
|
+
opt = Options(path_file=os.path.join(model_dir,'train_opt.txt'), mode=mode)
|
|
59
|
+
return opt
|
|
43
60
|
|
|
44
61
|
def find_model_using_name(model_name):
|
|
45
62
|
"""Import the module "models/[model_name]_model.py".
|
|
@@ -94,87 +111,75 @@ def load_torchscript_model(model_pt_path, device):
|
|
|
94
111
|
return net
|
|
95
112
|
|
|
96
113
|
|
|
97
|
-
def read_model_params(file_addr):
|
|
98
|
-
with open(file_addr) as f:
|
|
99
|
-
lines = f.readlines()
|
|
100
|
-
param_dict = {}
|
|
101
|
-
for line in lines:
|
|
102
|
-
if ':' in line:
|
|
103
|
-
key = line.split(':')[0].strip()
|
|
104
|
-
val = line.split(':')[1].split('[')[0].strip()
|
|
105
|
-
param_dict[key] = val
|
|
106
|
-
print(param_dict)
|
|
107
|
-
return param_dict
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
def load_eager_models(model_dir, devices):
|
|
111
|
-
input_nc = 3
|
|
112
|
-
output_nc = 3
|
|
113
|
-
ngf = 64
|
|
114
|
-
norm = 'batch'
|
|
115
|
-
use_dropout = True
|
|
116
|
-
padding_type = 'zero'
|
|
117
|
-
|
|
118
|
-
files = os.listdir(model_dir)
|
|
119
|
-
for f in files:
|
|
120
|
-
if 'train_opt.txt' in f:
|
|
121
|
-
param_dict = read_model_params(os.path.join(model_dir, f))
|
|
122
|
-
input_nc = int(param_dict['input_nc'])
|
|
123
|
-
output_nc = int(param_dict['output_nc'])
|
|
124
|
-
ngf = int(param_dict['ngf'])
|
|
125
|
-
norm = param_dict['norm']
|
|
126
|
-
use_dropout = False if param_dict['no_dropout'] == 'True' else True
|
|
127
|
-
padding_type = param_dict['padding']
|
|
128
|
-
|
|
129
|
-
norm_layer = get_norm_layer(norm_type=norm)
|
|
130
114
|
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
map_location=devices[n]
|
|
137
|
-
))
|
|
138
|
-
nets[n] = disable_batchnorm_tracking_stats(net)
|
|
139
|
-
nets[n].eval()
|
|
140
|
-
|
|
141
|
-
for n in ['G51', 'G52', 'G53', 'G54', 'G55']:
|
|
142
|
-
net = UnetGenerator(input_nc, output_nc, 9, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
|
|
143
|
-
net.load_state_dict(torch.load(
|
|
144
|
-
os.path.join(model_dir, f'latest_net_{n}.pth'),
|
|
145
|
-
map_location=devices[n]
|
|
146
|
-
))
|
|
147
|
-
nets[n] = disable_batchnorm_tracking_stats(net)
|
|
148
|
-
nets[n].eval()
|
|
115
|
+
def load_eager_models(opt, devices):
|
|
116
|
+
# create a model given model and other options
|
|
117
|
+
model = create_model(opt)
|
|
118
|
+
# regular setup: load and print networks; create schedulers
|
|
119
|
+
model.setup(opt)
|
|
149
120
|
|
|
121
|
+
nets = {}
|
|
122
|
+
for name in model.model_names:
|
|
123
|
+
if isinstance(name, str):
|
|
124
|
+
if '_' in name:
|
|
125
|
+
net = getattr(model, 'net' + name.split('_')[0])[int(name.split('_')[-1]) - 1]
|
|
126
|
+
else:
|
|
127
|
+
net = getattr(model, 'net' + name)
|
|
128
|
+
|
|
129
|
+
if opt.phase != 'train':
|
|
130
|
+
net.eval()
|
|
131
|
+
net = disable_batchnorm_tracking_stats(net)
|
|
132
|
+
|
|
133
|
+
nets[name] = net
|
|
134
|
+
nets[name].to(devices[name])
|
|
135
|
+
|
|
150
136
|
return nets
|
|
151
137
|
|
|
152
138
|
|
|
153
139
|
@lru_cache
|
|
154
|
-
def init_nets(model_dir, eager_mode=False):
|
|
140
|
+
def init_nets(model_dir, eager_mode=False, opt=None, phase='test'):
|
|
155
141
|
"""
|
|
156
142
|
Init DeepLIIF networks so that every net in
|
|
157
143
|
the same group is deployed on the same GPU
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
(
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
144
|
+
|
|
145
|
+
opt_args: to overwrite opt arguments in train_opt.txt, typically used in inference stage
|
|
146
|
+
for example, opt_args={'phase':'test'}
|
|
147
|
+
"""
|
|
148
|
+
if opt is None:
|
|
149
|
+
opt = get_opt(model_dir, mode=phase)
|
|
150
|
+
opt.use_dp = False
|
|
151
|
+
print_options(opt)
|
|
152
|
+
|
|
153
|
+
if opt.model == 'DeepLIIF':
|
|
154
|
+
net_groups = [
|
|
155
|
+
('G1', 'G52'),
|
|
156
|
+
('G2', 'G53'),
|
|
157
|
+
('G3', 'G54'),
|
|
158
|
+
('G4', 'G55'),
|
|
159
|
+
('G51',)
|
|
160
|
+
]
|
|
161
|
+
elif opt.model == 'DeepLIIFExt':
|
|
162
|
+
if opt.seg_gen:
|
|
163
|
+
net_groups = [(f'G_{i+1}',f'GS_{i+1}') for i in range(opt.modalities_no)]
|
|
164
|
+
else:
|
|
165
|
+
net_groups = [(f'G_{i+1}',) for i in range(opt.modalities_no)]
|
|
166
|
+
else:
|
|
167
|
+
raise Exception(f'init_nets() not implemented for model {opt.model}')
|
|
168
|
+
|
|
169
|
+
number_of_gpus_all = torch.cuda.device_count()
|
|
170
|
+
number_of_gpus = len(opt.gpu_ids)
|
|
171
|
+
print(number_of_gpus)
|
|
172
|
+
if number_of_gpus > 0:
|
|
173
|
+
mapping_gpu_ids = {i:idx for i,idx in enumerate(opt.gpu_ids)}
|
|
170
174
|
chunks = [itertools.chain.from_iterable(c) for c in chunker(net_groups, number_of_gpus)]
|
|
171
175
|
# chunks = chunks[1:]
|
|
172
|
-
devices = {n: torch.device(f'cuda:{i}') for i, g in enumerate(chunks) for n in g}
|
|
176
|
+
devices = {n: torch.device(f'cuda:{mapping_gpu_ids[i]}') for i, g in enumerate(chunks) for n in g}
|
|
177
|
+
# devices = {n: torch.device(f'cuda:{i}') for i, g in enumerate(chunks) for n in g}
|
|
173
178
|
else:
|
|
174
179
|
devices = {n: torch.device('cpu') for n in itertools.chain.from_iterable(net_groups)}
|
|
175
180
|
|
|
176
181
|
if eager_mode:
|
|
177
|
-
return load_eager_models(
|
|
182
|
+
return load_eager_models(opt, devices)
|
|
178
183
|
|
|
179
184
|
return {
|
|
180
185
|
n: load_torchscript_model(os.path.join(model_dir, f'{n}.pt'), device=d)
|
|
@@ -190,14 +195,18 @@ def compute_overlap(img_size, tile_size):
|
|
|
190
195
|
return tile_size // 4
|
|
191
196
|
|
|
192
197
|
|
|
193
|
-
def run_torchserve(img, model_path=None, eager_mode=False):
|
|
198
|
+
def run_torchserve(img, model_path=None, eager_mode=False, opt=None):
|
|
194
199
|
"""
|
|
195
200
|
eager_mode: not used in this function; put in place to be consistent with run_dask
|
|
196
201
|
so that run_wrapper() could call either this function or run_dask with
|
|
197
202
|
same syntax
|
|
203
|
+
opt: same as eager_mode
|
|
198
204
|
"""
|
|
199
205
|
buffer = BytesIO()
|
|
200
|
-
|
|
206
|
+
if opt.model == 'DeepLIIFExt':
|
|
207
|
+
torch.save(transform(img.resize((1024, 1024))), buffer)
|
|
208
|
+
else:
|
|
209
|
+
torch.save(transform(img.resize((512, 512))), buffer)
|
|
201
210
|
|
|
202
211
|
torchserve_host = os.getenv('TORCHSERVE_HOST', 'http://localhost')
|
|
203
212
|
res = requests.post(
|
|
@@ -213,33 +222,55 @@ def run_torchserve(img, model_path=None, eager_mode=False):
|
|
|
213
222
|
return {k: tensor_to_pil(deserialize_tensor(v)) for k, v in res.json().items()}
|
|
214
223
|
|
|
215
224
|
|
|
216
|
-
def run_dask(img, model_path, eager_mode=False):
|
|
225
|
+
def run_dask(img, model_path, eager_mode=False, opt=None):
|
|
217
226
|
model_dir = os.getenv('DEEPLIIF_MODEL_DIR', model_path)
|
|
218
|
-
nets = init_nets(model_dir, eager_mode)
|
|
219
|
-
|
|
220
|
-
|
|
227
|
+
nets = init_nets(model_dir, eager_mode, opt)
|
|
228
|
+
|
|
229
|
+
if opt.model == 'DeepLIIFExt':
|
|
230
|
+
ts = transform(img.resize((1024, 1024)))
|
|
231
|
+
else:
|
|
232
|
+
ts = transform(img.resize((512, 512)))
|
|
221
233
|
|
|
222
234
|
@delayed
|
|
223
235
|
def forward(input, model):
|
|
224
236
|
with torch.no_grad():
|
|
225
237
|
return model(input.to(next(model.parameters()).device))
|
|
238
|
+
|
|
239
|
+
if opt.model == 'DeepLIIF':
|
|
240
|
+
seg_map = {'G1': 'G52', 'G2': 'G53', 'G3': 'G54', 'G4': 'G55'}
|
|
241
|
+
|
|
242
|
+
lazy_gens = {k: forward(ts, nets[k]) for k in seg_map}
|
|
243
|
+
gens = compute(lazy_gens)[0]
|
|
244
|
+
|
|
245
|
+
lazy_segs = {v: forward(gens[k], nets[v]).to(torch.device('cpu')) for k, v in seg_map.items()}
|
|
246
|
+
lazy_segs['G51'] = forward(ts, nets['G51']).to(torch.device('cpu'))
|
|
247
|
+
segs = compute(lazy_segs)[0]
|
|
248
|
+
|
|
249
|
+
seg_weights = [0.25, 0.25, 0.25, 0, 0.25]
|
|
250
|
+
seg = torch.stack([torch.mul(n, w) for n, w in zip(segs.values(), seg_weights)]).sum(dim=0)
|
|
251
|
+
|
|
252
|
+
res = {k: tensor_to_pil(v) for k, v in gens.items()}
|
|
253
|
+
res['G5'] = tensor_to_pil(seg)
|
|
254
|
+
|
|
255
|
+
return res
|
|
256
|
+
elif opt.model == 'DeepLIIFExt':
|
|
257
|
+
seg_map = {'G_' + str(i): 'GS_' + str(i) for i in range(1, opt.modalities_no + 1)}
|
|
258
|
+
|
|
259
|
+
lazy_gens = {k: forward(ts, nets[k]) for k in seg_map}
|
|
260
|
+
gens = compute(lazy_gens)[0]
|
|
261
|
+
|
|
262
|
+
res = {k: tensor_to_pil(v) for k, v in gens.items()}
|
|
263
|
+
|
|
264
|
+
if opt.seg_gen:
|
|
265
|
+
lazy_segs = {v: forward(torch.cat([ts.to(torch.device('cpu')), gens[next(iter(seg_map))].to(torch.device('cpu')), gens[k].to(torch.device('cpu'))], 1), nets[v]).to(torch.device('cpu')) for k, v in seg_map.items()}
|
|
266
|
+
segs = compute(lazy_segs)[0]
|
|
267
|
+
res.update({k: tensor_to_pil(v) for k, v in segs.items()})
|
|
268
|
+
|
|
269
|
+
return res
|
|
270
|
+
else:
|
|
271
|
+
raise Exception(f'run_dask() not implemented for {opt.model}')
|
|
226
272
|
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
lazy_gens = {k: forward(ts, nets[k]) for k in seg_map}
|
|
230
|
-
gens = compute(lazy_gens)[0]
|
|
231
|
-
|
|
232
|
-
lazy_segs = {v: forward(gens[k], nets[v]).to(torch.device('cpu')) for k, v in seg_map.items()}
|
|
233
|
-
lazy_segs['G51'] = forward(ts, nets['G51']).to(torch.device('cpu'))
|
|
234
|
-
segs = compute(lazy_segs)[0]
|
|
235
|
-
|
|
236
|
-
seg_weights = [0.25, 0.25, 0.25, 0, 0.25]
|
|
237
|
-
seg = torch.stack([torch.mul(n, w) for n, w in zip(segs.values(), seg_weights)]).sum(dim=0)
|
|
238
|
-
|
|
239
|
-
res = {k: tensor_to_pil(v) for k, v in gens.items()}
|
|
240
|
-
res['G5'] = tensor_to_pil(seg)
|
|
241
|
-
|
|
242
|
-
return res
|
|
273
|
+
|
|
243
274
|
|
|
244
275
|
|
|
245
276
|
def is_empty(tile):
|
|
@@ -247,17 +278,27 @@ def is_empty(tile):
|
|
|
247
278
|
return True if calculate_background_area(tile) > 98 else False
|
|
248
279
|
|
|
249
280
|
|
|
250
|
-
def run_wrapper(tile, run_fn, model_path, eager_mode=False):
|
|
251
|
-
if
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
281
|
+
def run_wrapper(tile, run_fn, model_path, eager_mode=False, opt=None):
|
|
282
|
+
if opt.model == 'DeepLIIF':
|
|
283
|
+
if is_empty(tile):
|
|
284
|
+
return {
|
|
285
|
+
'G1': Image.new(mode='RGB', size=(512, 512), color=(201, 211, 208)),
|
|
286
|
+
'G2': Image.new(mode='RGB', size=(512, 512), color=(10, 10, 10)),
|
|
287
|
+
'G3': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)),
|
|
288
|
+
'G4': Image.new(mode='RGB', size=(512, 512), color=(10, 10, 10)),
|
|
289
|
+
'G5': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0))
|
|
290
|
+
}
|
|
291
|
+
else:
|
|
292
|
+
return run_fn(tile, model_path, eager_mode, opt)
|
|
293
|
+
elif opt.model == 'DeepLIIFExt':
|
|
294
|
+
if is_empty(tile):
|
|
295
|
+
res = {'G_' + str(i): Image.new(mode='RGB', size=(512, 512)) for i in range(1, opt.modalities_no + 1)}
|
|
296
|
+
res.update({'GS_' + str(i): Image.new(mode='RGB', size=(512, 512)) for i in range(1, opt.modalities_no + 1)})
|
|
297
|
+
return res
|
|
298
|
+
else:
|
|
299
|
+
return run_fn(tile, model_path, eager_mode, opt)
|
|
259
300
|
else:
|
|
260
|
-
|
|
301
|
+
raise Exception(f'run_wrapper() not implemented for model {opt.model}')
|
|
261
302
|
|
|
262
303
|
|
|
263
304
|
def inference_old(img, tile_size, overlap_size, model_path, use_torchserve=False, eager_mode=False,
|
|
@@ -307,71 +348,115 @@ def inference_old(img, tile_size, overlap_size, model_path, use_torchserve=False
|
|
|
307
348
|
|
|
308
349
|
|
|
309
350
|
def inference(img, tile_size, overlap_size, model_path, use_torchserve=False, eager_mode=False,
|
|
310
|
-
color_dapi=False, color_marker=False):
|
|
311
|
-
|
|
312
|
-
|
|
313
|
-
|
|
314
|
-
|
|
315
|
-
|
|
316
|
-
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
|
|
340
|
-
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
|
|
364
|
-
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
|
|
368
|
-
|
|
351
|
+
color_dapi=False, color_marker=False, opt=None):
|
|
352
|
+
if not opt:
|
|
353
|
+
opt = get_opt(model_path)
|
|
354
|
+
print_options(opt)
|
|
355
|
+
|
|
356
|
+
if opt.model == 'DeepLIIF':
|
|
357
|
+
rescaled, rows, cols = format_image_for_tiling(img, tile_size, overlap_size)
|
|
358
|
+
|
|
359
|
+
run_fn = run_torchserve if use_torchserve else run_dask
|
|
360
|
+
|
|
361
|
+
images = {}
|
|
362
|
+
images['Hema'] = create_image_for_stitching(tile_size, rows, cols)
|
|
363
|
+
images['DAPI'] = create_image_for_stitching(tile_size, rows, cols)
|
|
364
|
+
images['Lap2'] = create_image_for_stitching(tile_size, rows, cols)
|
|
365
|
+
images['Marker'] = create_image_for_stitching(tile_size, rows, cols)
|
|
366
|
+
images['Seg'] = create_image_for_stitching(tile_size, rows, cols)
|
|
367
|
+
|
|
368
|
+
for i in range(cols):
|
|
369
|
+
for j in range(rows):
|
|
370
|
+
tile = extract_tile(rescaled, tile_size, overlap_size, i, j)
|
|
371
|
+
res = run_wrapper(tile, run_fn, model_path, eager_mode, opt)
|
|
372
|
+
|
|
373
|
+
stitch_tile(images['Hema'], res['G1'], tile_size, overlap_size, i, j)
|
|
374
|
+
stitch_tile(images['DAPI'], res['G2'], tile_size, overlap_size, i, j)
|
|
375
|
+
stitch_tile(images['Lap2'], res['G3'], tile_size, overlap_size, i, j)
|
|
376
|
+
stitch_tile(images['Marker'], res['G4'], tile_size, overlap_size, i, j)
|
|
377
|
+
stitch_tile(images['Seg'], res['G5'], tile_size, overlap_size, i, j)
|
|
378
|
+
|
|
379
|
+
images['Hema'] = images['Hema'].resize(img.size)
|
|
380
|
+
images['DAPI'] = images['DAPI'].resize(img.size)
|
|
381
|
+
images['Lap2'] = images['Lap2'].resize(img.size)
|
|
382
|
+
images['Marker'] = images['Marker'].resize(img.size)
|
|
383
|
+
images['Seg'] = images['Seg'].resize(img.size)
|
|
384
|
+
|
|
385
|
+
if color_dapi:
|
|
386
|
+
matrix = ( 0, 0, 0, 0,
|
|
387
|
+
299/1000, 587/1000, 114/1000, 0,
|
|
388
|
+
299/1000, 587/1000, 114/1000, 0)
|
|
389
|
+
images['DAPI'] = images['DAPI'].convert('RGB', matrix)
|
|
390
|
+
|
|
391
|
+
if color_marker:
|
|
392
|
+
matrix = (299/1000, 587/1000, 114/1000, 0,
|
|
393
|
+
299/1000, 587/1000, 114/1000, 0,
|
|
394
|
+
0, 0, 0, 0)
|
|
395
|
+
images['Marker'] = images['Marker'].convert('RGB', matrix)
|
|
396
|
+
|
|
397
|
+
return images
|
|
398
|
+
|
|
399
|
+
elif opt.model == 'DeepLIIFExt':
|
|
400
|
+
#param_dict = read_train_options(model_path)
|
|
401
|
+
#modalities_no = int(param_dict['modalities_no']) if param_dict else 4
|
|
402
|
+
#seg_gen = (param_dict['seg_gen'] == 'True') if param_dict else True
|
|
403
|
+
|
|
404
|
+
tiles = list(generate_tiles(img, tile_size, overlap_size))
|
|
405
|
+
|
|
406
|
+
run_fn = run_torchserve if use_torchserve else run_dask
|
|
407
|
+
res = [Tile(t.i, t.j, run_wrapper(t.img, run_fn, model_path, eager_mode, opt)) for t in tiles]
|
|
408
|
+
|
|
409
|
+
def get_net_tiles(n):
|
|
410
|
+
return [Tile(t.i, t.j, t.img[n]) for t in res]
|
|
411
|
+
|
|
412
|
+
images = {}
|
|
413
|
+
|
|
414
|
+
for i in range(1, opt.modalities_no + 1):
|
|
415
|
+
images['mod' + str(i)] = stitch(get_net_tiles('G_' + str(i)), tile_size, overlap_size).resize(img.size)
|
|
416
|
+
|
|
417
|
+
if opt.seg_gen:
|
|
418
|
+
for i in range(1, opt.modalities_no + 1):
|
|
419
|
+
images['Seg' + str(i)] = stitch(get_net_tiles('GS_' + str(i)), tile_size, overlap_size).resize(img.size)
|
|
420
|
+
|
|
421
|
+
return images
|
|
422
|
+
|
|
423
|
+
else:
|
|
424
|
+
raise Exception(f'inference() not implemented for model {opt.model}')
|
|
425
|
+
|
|
426
|
+
|
|
427
|
+
def postprocess(orig, images, tile_size, seg_thresh=150, size_thresh='default', marker_thresh='default', size_thresh_upper=None, opt=None):
|
|
428
|
+
if opt.model == 'DeepLIIF':
|
|
429
|
+
resolution = '40x' if tile_size > 384 else ('20x' if tile_size > 192 else '10x')
|
|
430
|
+
overlay, refined, scoring = compute_results(np.array(orig), np.array(images['Seg']),
|
|
431
|
+
np.array(images['Marker'].convert('L')), resolution,
|
|
432
|
+
seg_thresh, size_thresh, marker_thresh, size_thresh_upper)
|
|
433
|
+
processed_images = {}
|
|
434
|
+
processed_images['SegOverlaid'] = Image.fromarray(overlay)
|
|
435
|
+
processed_images['SegRefined'] = Image.fromarray(refined)
|
|
436
|
+
return processed_images, scoring
|
|
437
|
+
|
|
438
|
+
elif opt.model == 'DeepLIIFExt':
|
|
439
|
+
resolution = '40x' if tile_size > 768 else ('20x' if tile_size > 384 else '10x')
|
|
440
|
+
processed_images = {}
|
|
441
|
+
scoring = {}
|
|
442
|
+
for img_name in list(images.keys()):
|
|
443
|
+
if 'Seg' in img_name:
|
|
444
|
+
seg_img = images[img_name]
|
|
445
|
+
overlay, refined, score = compute_results(np.array(orig), np.array(images[img_name]),
|
|
446
|
+
None, resolution,
|
|
447
|
+
seg_thresh, size_thresh, marker_thresh, size_thresh_upper)
|
|
448
|
+
|
|
449
|
+
processed_images[img_name + '_Overlaid'] = Image.fromarray(overlay)
|
|
450
|
+
processed_images[img_name + '_Refined'] = Image.fromarray(refined)
|
|
451
|
+
scoring[img_name] = score
|
|
452
|
+
return processed_images, scoring
|
|
369
453
|
|
|
370
|
-
|
|
454
|
+
else:
|
|
455
|
+
raise Exception(f'postprocess() not implemented for model {opt.model}')
|
|
371
456
|
|
|
372
457
|
|
|
373
458
|
def infer_modalities(img, tile_size, model_dir, eager_mode=False,
|
|
374
|
-
color_dapi=False, color_marker=False):
|
|
459
|
+
color_dapi=False, color_marker=False, opt=None):
|
|
375
460
|
"""
|
|
376
461
|
This function is used to infer modalities for the given image using a trained model.
|
|
377
462
|
:param img: The input image.
|
|
@@ -379,6 +464,11 @@ def infer_modalities(img, tile_size, model_dir, eager_mode=False,
|
|
|
379
464
|
:param model_dir: The directory containing serialized model files.
|
|
380
465
|
:return: The inferred modalities and the segmentation mask.
|
|
381
466
|
"""
|
|
467
|
+
if opt is None:
|
|
468
|
+
opt = get_opt(model_dir)
|
|
469
|
+
opt.use_dp = False
|
|
470
|
+
print_options(opt)
|
|
471
|
+
|
|
382
472
|
if not tile_size:
|
|
383
473
|
tile_size = check_multi_scale(Image.open('./images/target.png').convert('L'),
|
|
384
474
|
img.convert('L'))
|
|
@@ -391,12 +481,16 @@ def infer_modalities(img, tile_size, model_dir, eager_mode=False,
|
|
|
391
481
|
model_path=model_dir,
|
|
392
482
|
eager_mode=eager_mode,
|
|
393
483
|
color_dapi=color_dapi,
|
|
394
|
-
color_marker=color_marker
|
|
484
|
+
color_marker=color_marker,
|
|
485
|
+
opt=opt
|
|
395
486
|
)
|
|
396
|
-
|
|
397
|
-
|
|
398
|
-
|
|
399
|
-
|
|
487
|
+
|
|
488
|
+
if not hasattr(opt,'seg_gen') or (hasattr(opt,'seg_gen') and opt.seg_gen): # the first condition accounts for old settings of deepliif; the second refers to deepliifext models
|
|
489
|
+
post_images, scoring = postprocess(img, images, tile_size, opt=opt)
|
|
490
|
+
images = {**images, **post_images}
|
|
491
|
+
return images, scoring
|
|
492
|
+
else:
|
|
493
|
+
return images, None
|
|
400
494
|
|
|
401
495
|
|
|
402
496
|
def infer_results_for_wsi(input_dir, filename, output_dir, model_dir, tile_size, region_size=20000):
|
deepliif/models/base_model.py
CHANGED
|
@@ -4,6 +4,8 @@ from collections import OrderedDict
|
|
|
4
4
|
from abc import ABC, abstractmethod
|
|
5
5
|
from . import networks
|
|
6
6
|
from ..util import disable_batchnorm_tracking_stats
|
|
7
|
+
from deepliif.util import *
|
|
8
|
+
import itertools
|
|
7
9
|
|
|
8
10
|
|
|
9
11
|
class BaseModel(ABC):
|
|
@@ -35,8 +37,9 @@ class BaseModel(ABC):
|
|
|
35
37
|
self.is_train = opt.is_train
|
|
36
38
|
self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu') # get device name: CPU or GPU
|
|
37
39
|
self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) # save all the checkpoints to save_dir
|
|
38
|
-
if opt.preprocess != 'scale_width':
|
|
40
|
+
if opt.phase == 'train' and opt.preprocess != 'scale_width': # with [scale_width], input images might have different sizes, which hurts the performance of cudnn.benchmark.
|
|
39
41
|
torch.backends.cudnn.benchmark = True
|
|
42
|
+
# especially for inference, cudnn benchmark can cause excessive usage of GPU memory for the first image in the sequence in order to find the best conv alg which is not necessary.
|
|
40
43
|
self.loss_names = []
|
|
41
44
|
self.model_names = []
|
|
42
45
|
self.visual_names = []
|
|
@@ -89,7 +92,10 @@ class BaseModel(ABC):
|
|
|
89
92
|
"""Make models eval mode during test time"""
|
|
90
93
|
for name in self.model_names:
|
|
91
94
|
if isinstance(name, str):
|
|
92
|
-
|
|
95
|
+
if '_' in name:
|
|
96
|
+
net = getattr(self, 'net' + name.split('_')[0])[int(name.split('_')[-1]) - 1]
|
|
97
|
+
else:
|
|
98
|
+
net = getattr(self, 'net' + name)
|
|
93
99
|
net.eval()
|
|
94
100
|
net = disable_batchnorm_tracking_stats(net)
|
|
95
101
|
|
|
@@ -127,7 +133,13 @@ class BaseModel(ABC):
|
|
|
127
133
|
visual_ret = OrderedDict()
|
|
128
134
|
for name in self.visual_names:
|
|
129
135
|
if isinstance(name, str):
|
|
130
|
-
|
|
136
|
+
if not hasattr(self, name):
|
|
137
|
+
if len(name.split('_')) == 2:
|
|
138
|
+
visual_ret[name] = getattr(self, name.split('_')[0])[int(name.split('_')[-1]) -1]
|
|
139
|
+
else:
|
|
140
|
+
visual_ret[name] = getattr(self, name.split('_')[0] + '_' + name.split('_')[1])[int(name.split('_')[-1]) - 1]
|
|
141
|
+
else:
|
|
142
|
+
visual_ret[name] = getattr(self, name)
|
|
131
143
|
return visual_ret
|
|
132
144
|
|
|
133
145
|
def get_current_losses(self):
|
|
@@ -135,7 +147,16 @@ class BaseModel(ABC):
|
|
|
135
147
|
errors_ret = OrderedDict()
|
|
136
148
|
for name in self.loss_names:
|
|
137
149
|
if isinstance(name, str):
|
|
138
|
-
|
|
150
|
+
if not hasattr(self, 'loss_'+name): # appears in DeepLIIFExt
|
|
151
|
+
if len(name.split('_')) == 2:
|
|
152
|
+
errors_ret[name] = float(getattr(self, 'loss_' + name.split('_')[0])[int(
|
|
153
|
+
name.split('_')[-1]) - 1]) # float(...) works for both scalar tensor and float number
|
|
154
|
+
else:
|
|
155
|
+
errors_ret[name] = float(getattr(self, 'loss_' + name.split('_')[0] + '_' + name.split('_')[1])[int(
|
|
156
|
+
name.split('_')[-1]) - 1]) # float(...) works for both scalar tensor and float number
|
|
157
|
+
else: # single numeric value
|
|
158
|
+
errors_ret[name] = float(getattr(self, 'loss_' + name))
|
|
159
|
+
|
|
139
160
|
return errors_ret
|
|
140
161
|
|
|
141
162
|
def save_networks(self, epoch, save_from_one_process=False):
|
|
@@ -151,7 +172,10 @@ class BaseModel(ABC):
|
|
|
151
172
|
if isinstance(name, str):
|
|
152
173
|
save_filename = '%s_net_%s.pth' % (epoch, name)
|
|
153
174
|
save_path = os.path.join(self.save_dir, save_filename)
|
|
154
|
-
|
|
175
|
+
if '_' in name:
|
|
176
|
+
net = getattr(self, 'net' + name.split('_')[0])[int(name.split('_')[-1]) - 1]
|
|
177
|
+
else:
|
|
178
|
+
net = getattr(self, 'net' + name)
|
|
155
179
|
|
|
156
180
|
if len(self.gpu_ids) > 0 and torch.cuda.is_available():
|
|
157
181
|
torch.save(net.module.cpu().state_dict(), save_path)
|
|
@@ -219,13 +243,32 @@ class BaseModel(ABC):
|
|
|
219
243
|
if isinstance(name, str):
|
|
220
244
|
load_filename = '%s_net_%s.pth' % (epoch, name)
|
|
221
245
|
load_path = os.path.join(self.save_dir, load_filename)
|
|
222
|
-
|
|
246
|
+
if '_' in name:
|
|
247
|
+
net = getattr(self, 'net' + name.split('_')[0])[int(name.split('_')[-1]) - 1]
|
|
248
|
+
else:
|
|
249
|
+
net = getattr(self, 'net' + name)
|
|
223
250
|
if isinstance(net, torch.nn.DataParallel):
|
|
224
251
|
net = net.module
|
|
252
|
+
|
|
253
|
+
self.set_requires_grad(net,self.opt.is_train)
|
|
254
|
+
# check if gradients are disabled
|
|
255
|
+
names_layer_requires_grad = []
|
|
256
|
+
for name, param in net.named_parameters():
|
|
257
|
+
if param.requires_grad:
|
|
258
|
+
names_layer_requires_grad.append(name)
|
|
259
|
+
|
|
225
260
|
print('loading the model from %s' % load_path)
|
|
226
261
|
# if you are using PyTorch newer than 0.4 (e.g., built from
|
|
227
262
|
# GitHub source), you can remove str() on self.device
|
|
228
|
-
|
|
263
|
+
|
|
264
|
+
if self.opt.is_train or self.opt.use_dp:
|
|
265
|
+
device = self.device
|
|
266
|
+
else:
|
|
267
|
+
device = torch.device('cpu') # load in cpu first; later in __inite__.py::init_nets we will move it to the specified device
|
|
268
|
+
|
|
269
|
+
net.to(device)
|
|
270
|
+
state_dict = torch.load(load_path, map_location=str(device))
|
|
271
|
+
|
|
229
272
|
if hasattr(state_dict, '_metadata'):
|
|
230
273
|
del state_dict._metadata
|
|
231
274
|
|
|
@@ -243,7 +286,10 @@ class BaseModel(ABC):
|
|
|
243
286
|
print('---------- Networks initialized -------------')
|
|
244
287
|
for name in self.model_names:
|
|
245
288
|
if isinstance(name, str):
|
|
246
|
-
|
|
289
|
+
if '_' in name:
|
|
290
|
+
net = getattr(self, 'net' + name.split('_')[0])[int(name.split('_')[-1]) - 1]
|
|
291
|
+
else:
|
|
292
|
+
net = getattr(self, 'net' + name)
|
|
247
293
|
num_params = 0
|
|
248
294
|
for param in net.parameters():
|
|
249
295
|
num_params += param.numel()
|