deepliif 1.1.11__py3-none-any.whl → 1.1.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cli.py +354 -67
- deepliif/data/__init__.py +7 -7
- deepliif/data/aligned_dataset.py +2 -3
- deepliif/data/unaligned_dataset.py +38 -19
- deepliif/models/CycleGAN_model.py +282 -0
- deepliif/models/DeepLIIFExt_model.py +47 -25
- deepliif/models/DeepLIIF_model.py +69 -19
- deepliif/models/SDG_model.py +57 -26
- deepliif/models/__init__ - different weighted.py +762 -0
- deepliif/models/__init__ - run_dask_multi dev.py +943 -0
- deepliif/models/__init__ - time gens.py +792 -0
- deepliif/models/__init__ - timings.py +764 -0
- deepliif/models/__init__.py +359 -265
- deepliif/models/att_unet.py +199 -0
- deepliif/models/base_model.py +32 -8
- deepliif/models/networks.py +108 -34
- deepliif/options/__init__.py +49 -5
- deepliif/postprocessing.py +1034 -227
- deepliif/postprocessing__OLD__DELETE.py +440 -0
- deepliif/util/__init__.py +86 -65
- deepliif/util/visualizer.py +106 -19
- {deepliif-1.1.11.dist-info → deepliif-1.1.13.dist-info}/METADATA +75 -24
- deepliif-1.1.13.dist-info/RECORD +42 -0
- deepliif-1.1.11.dist-info/RECORD +0 -35
- {deepliif-1.1.11.dist-info → deepliif-1.1.13.dist-info}/LICENSE.md +0 -0
- {deepliif-1.1.11.dist-info → deepliif-1.1.13.dist-info}/WHEEL +0 -0
- {deepliif-1.1.11.dist-info → deepliif-1.1.13.dist-info}/entry_points.txt +0 -0
- {deepliif-1.1.11.dist-info → deepliif-1.1.13.dist-info}/top_level.txt +0 -0
deepliif/models/__init__.py
CHANGED
|
@@ -12,7 +12,7 @@ In the function <__init__>, you need to define four lists:
|
|
|
12
12
|
-- self.loss_names (str list): specify the training losses that you want to plot and save.
|
|
13
13
|
-- self.model_names (str list): define networks used in our training.
|
|
14
14
|
-- self.visual_names (str list): specify the images that you want to display and save.
|
|
15
|
-
-- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See
|
|
15
|
+
-- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See CycleGAN_model.py for a usage example.
|
|
16
16
|
|
|
17
17
|
Now you can use the model class by specifying flag '--model dummy'.
|
|
18
18
|
See our template model class 'template_model.py' for more details.
|
|
@@ -23,17 +23,22 @@ import itertools
|
|
|
23
23
|
import importlib
|
|
24
24
|
from functools import lru_cache
|
|
25
25
|
from io import BytesIO
|
|
26
|
+
import json
|
|
27
|
+
import math
|
|
26
28
|
|
|
27
29
|
import requests
|
|
28
30
|
import torch
|
|
29
31
|
from PIL import Image
|
|
32
|
+
Image.MAX_IMAGE_PIXELS = None
|
|
33
|
+
|
|
30
34
|
import numpy as np
|
|
31
35
|
from dask import delayed, compute
|
|
32
36
|
|
|
33
37
|
from deepliif.util import *
|
|
34
|
-
from deepliif.util.util import tensor_to_pil
|
|
38
|
+
from deepliif.util.util import tensor_to_pil
|
|
35
39
|
from deepliif.data import transform
|
|
36
|
-
from deepliif.postprocessing import
|
|
40
|
+
from deepliif.postprocessing import compute_final_results, compute_cell_results
|
|
41
|
+
from deepliif.postprocessing import encode_cell_data_v4, decode_cell_data_v4
|
|
37
42
|
from deepliif.options import Options, print_options
|
|
38
43
|
|
|
39
44
|
from .base_model import BaseModel
|
|
@@ -115,14 +120,19 @@ def load_torchscript_model(model_pt_path, device):
|
|
|
115
120
|
|
|
116
121
|
|
|
117
122
|
|
|
118
|
-
def load_eager_models(opt, devices):
|
|
123
|
+
def load_eager_models(opt, devices=None):
|
|
119
124
|
# create a model given model and other options
|
|
120
125
|
model = create_model(opt)
|
|
121
126
|
# regular setup: load and print networks; create schedulers
|
|
122
127
|
model.setup(opt)
|
|
123
128
|
|
|
124
129
|
nets = {}
|
|
125
|
-
|
|
130
|
+
if devices:
|
|
131
|
+
model_names = list(devices.keys())
|
|
132
|
+
else:
|
|
133
|
+
model_names = model.model_names
|
|
134
|
+
|
|
135
|
+
for name in model_names:#model.model_names:
|
|
126
136
|
if isinstance(name, str):
|
|
127
137
|
if '_' in name:
|
|
128
138
|
net = getattr(model, 'net' + name.split('_')[0])[int(name.split('_')[-1]) - 1]
|
|
@@ -138,7 +148,8 @@ def load_eager_models(opt, devices):
|
|
|
138
148
|
net = net.module
|
|
139
149
|
|
|
140
150
|
nets[name] = net
|
|
141
|
-
|
|
151
|
+
if devices:
|
|
152
|
+
nets[name].to(devices[name])
|
|
142
153
|
|
|
143
154
|
return nets
|
|
144
155
|
|
|
@@ -154,8 +165,7 @@ def init_nets(model_dir, eager_mode=False, opt=None, phase='test'):
|
|
|
154
165
|
"""
|
|
155
166
|
if opt is None:
|
|
156
167
|
opt = get_opt(model_dir, mode=phase)
|
|
157
|
-
|
|
158
|
-
#print_options(opt)
|
|
168
|
+
opt.use_dp = False
|
|
159
169
|
|
|
160
170
|
if opt.model == 'DeepLIIF':
|
|
161
171
|
net_groups = [
|
|
@@ -170,12 +180,16 @@ def init_nets(model_dir, eager_mode=False, opt=None, phase='test'):
|
|
|
170
180
|
net_groups = [(f'G_{i+1}',f'GS_{i+1}') for i in range(opt.modalities_no)]
|
|
171
181
|
else:
|
|
172
182
|
net_groups = [(f'G_{i+1}',) for i in range(opt.modalities_no)]
|
|
183
|
+
elif opt.model == 'CycleGAN':
|
|
184
|
+
if opt.BtoA:
|
|
185
|
+
net_groups = [(f'GB_{i+1}',) for i in range(opt.modalities_no)]
|
|
186
|
+
else:
|
|
187
|
+
net_groups = [(f'GA_{i+1}',) for i in range(opt.modalities_no)]
|
|
173
188
|
else:
|
|
174
189
|
raise Exception(f'init_nets() not implemented for model {opt.model}')
|
|
175
190
|
|
|
176
191
|
number_of_gpus_all = torch.cuda.device_count()
|
|
177
|
-
number_of_gpus = len(opt.gpu_ids)
|
|
178
|
-
#print(number_of_gpus)
|
|
192
|
+
number_of_gpus = min(len(opt.gpu_ids),number_of_gpus_all)
|
|
179
193
|
|
|
180
194
|
if number_of_gpus > 0:
|
|
181
195
|
mapping_gpu_ids = {i:idx for i,idx in enumerate(opt.gpu_ids)}
|
|
@@ -203,12 +217,13 @@ def compute_overlap(img_size, tile_size):
|
|
|
203
217
|
return tile_size // 4
|
|
204
218
|
|
|
205
219
|
|
|
206
|
-
def run_torchserve(img, model_path=None, eager_mode=False, opt=None):
|
|
220
|
+
def run_torchserve(img, model_path=None, eager_mode=False, opt=None, seg_only=False):
|
|
207
221
|
"""
|
|
208
222
|
eager_mode: not used in this function; put in place to be consistent with run_dask
|
|
209
223
|
so that run_wrapper() could call either this function or run_dask with
|
|
210
224
|
same syntax
|
|
211
225
|
opt: same as eager_mode
|
|
226
|
+
seg_only: same as eager_mode
|
|
212
227
|
"""
|
|
213
228
|
buffer = BytesIO()
|
|
214
229
|
torch.save(transform(img.resize((opt.scale_size, opt.scale_size))), buffer)
|
|
@@ -227,9 +242,10 @@ def run_torchserve(img, model_path=None, eager_mode=False, opt=None):
|
|
|
227
242
|
return {k: tensor_to_pil(deserialize_tensor(v)) for k, v in res.json().items()}
|
|
228
243
|
|
|
229
244
|
|
|
230
|
-
def run_dask(img, model_path, eager_mode=False, opt=None):
|
|
245
|
+
def run_dask(img, model_path, eager_mode=False, opt=None, seg_only=False):
|
|
231
246
|
model_dir = os.getenv('DEEPLIIF_MODEL_DIR', model_path)
|
|
232
247
|
nets = init_nets(model_dir, eager_mode, opt)
|
|
248
|
+
use_dask = True if opt.norm != 'spectral' else False
|
|
233
249
|
|
|
234
250
|
if opt.input_no > 1 or opt.model == 'SDG':
|
|
235
251
|
l_ts = [transform(img_i.resize((opt.scale_size,opt.scale_size))) for img_i in img]
|
|
@@ -238,45 +254,69 @@ def run_dask(img, model_path, eager_mode=False, opt=None):
|
|
|
238
254
|
ts = transform(img.resize((opt.scale_size, opt.scale_size)))
|
|
239
255
|
|
|
240
256
|
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
257
|
+
if use_dask:
|
|
258
|
+
@delayed
|
|
259
|
+
def forward(input, model):
|
|
260
|
+
with torch.no_grad():
|
|
261
|
+
return model(input.to(next(model.parameters()).device))
|
|
262
|
+
else: # some training settings (e.g. spectral norm) are not compatible with dask in inference mode, so forward() is defined without @delayed here
|
|
263
|
+
def forward(input, model):
|
|
264
|
+
with torch.no_grad():
|
|
265
|
+
return model(input.to(next(model.parameters()).device))
|
|
245
266
|
|
|
246
267
|
if opt.model == 'DeepLIIF':
|
|
268
|
+
weights = {
|
|
269
|
+
'G51': 0.25, # IHC
|
|
270
|
+
'G52': 0.25, # Hema
|
|
271
|
+
'G53': 0.25, # DAPI
|
|
272
|
+
'G54': 0.00, # Lap2
|
|
273
|
+
'G55': 0.25, # Marker
|
|
274
|
+
}
|
|
275
|
+
|
|
247
276
|
seg_map = {'G1': 'G52', 'G2': 'G53', 'G3': 'G54', 'G4': 'G55'}
|
|
277
|
+
if seg_only:
|
|
278
|
+
seg_map = {k: v for k, v in seg_map.items() if weights[v] != 0}
|
|
248
279
|
|
|
249
280
|
lazy_gens = {k: forward(ts, nets[k]) for k in seg_map}
|
|
281
|
+
if 'G4' not in seg_map:
|
|
282
|
+
lazy_gens['G4'] = forward(ts, nets['G4'])
|
|
250
283
|
gens = compute(lazy_gens)[0]
|
|
251
284
|
|
|
252
285
|
lazy_segs = {v: forward(gens[k], nets[v]).to(torch.device('cpu')) for k, v in seg_map.items()}
|
|
253
|
-
|
|
286
|
+
if not seg_only or weights['G51'] != 0:
|
|
287
|
+
lazy_segs['G51'] = forward(ts, nets['G51']).to(torch.device('cpu'))
|
|
254
288
|
segs = compute(lazy_segs)[0]
|
|
255
289
|
|
|
256
|
-
weights = {
|
|
257
|
-
'G51': 0.25, # IHC
|
|
258
|
-
'G52': 0.25, # Hema
|
|
259
|
-
'G53': 0.25, # DAPI
|
|
260
|
-
'G54': 0.00, # Lap2
|
|
261
|
-
'G55': 0.25, # Marker
|
|
262
|
-
}
|
|
263
290
|
seg = torch.stack([torch.mul(segs[k], weights[k]) for k in segs.keys()]).sum(dim=0)
|
|
264
291
|
|
|
265
|
-
|
|
292
|
+
if seg_only:
|
|
293
|
+
res = {'G4': tensor_to_pil(gens['G4'])} if 'G4' in gens else {}
|
|
294
|
+
else:
|
|
295
|
+
res = {k: tensor_to_pil(v) for k, v in gens.items()}
|
|
296
|
+
res.update({k: tensor_to_pil(v) for k, v in segs.items()})
|
|
266
297
|
res['G5'] = tensor_to_pil(seg)
|
|
267
298
|
|
|
268
299
|
return res
|
|
269
|
-
elif opt.model in ['DeepLIIFExt','SDG']:
|
|
270
|
-
|
|
300
|
+
elif opt.model in ['DeepLIIFExt','SDG','CycleGAN']:
|
|
301
|
+
if opt.model == 'CycleGAN':
|
|
302
|
+
seg_map = {f'GB_{i+1}':None for i in range(opt.modalities_no)} if opt.BtoA else {f'GA_{i+1}':None for i in range(opt.modalities_no)}
|
|
303
|
+
else:
|
|
304
|
+
seg_map = {'G_' + str(i): 'GS_' + str(i) for i in range(1, opt.modalities_no + 1)}
|
|
271
305
|
|
|
272
|
-
|
|
273
|
-
|
|
306
|
+
if use_dask:
|
|
307
|
+
lazy_gens = {k: forward(ts, nets[k]) for k in seg_map}
|
|
308
|
+
gens = compute(lazy_gens)[0]
|
|
309
|
+
else:
|
|
310
|
+
gens = {k: forward(ts, nets[k]) for k in seg_map}
|
|
274
311
|
|
|
275
312
|
res = {k: tensor_to_pil(v) for k, v in gens.items()}
|
|
276
313
|
|
|
277
314
|
if opt.seg_gen:
|
|
278
|
-
|
|
279
|
-
|
|
315
|
+
if use_dask:
|
|
316
|
+
lazy_segs = {v: forward(torch.cat([ts.to(torch.device('cpu')), gens[next(iter(seg_map))].to(torch.device('cpu')), gens[k].to(torch.device('cpu'))], 1), nets[v]).to(torch.device('cpu')) for k, v in seg_map.items()}
|
|
317
|
+
segs = compute(lazy_segs)[0]
|
|
318
|
+
else:
|
|
319
|
+
segs = {v: forward(torch.cat([ts.to(torch.device('cpu')), gens[next(iter(seg_map))].to(torch.device('cpu')), gens[k].to(torch.device('cpu'))], 1), nets[v]).to(torch.device('cpu')) for k, v in seg_map.items()}
|
|
280
320
|
res.update({k: tensor_to_pil(v) for k, v in segs.items()})
|
|
281
321
|
|
|
282
322
|
return res
|
|
@@ -284,14 +324,6 @@ def run_dask(img, model_path, eager_mode=False, opt=None):
|
|
|
284
324
|
raise Exception(f'run_dask() not fully implemented for {opt.model}')
|
|
285
325
|
|
|
286
326
|
|
|
287
|
-
def is_empty_old(tile):
|
|
288
|
-
# return True if np.mean(np.array(tile) - np.array(mean_background_val)) < 40 else False
|
|
289
|
-
if isinstance(tile, list): # for pair of tiles, only mark it as empty / no need for prediction if ALL tiles are empty
|
|
290
|
-
return all([True if calculate_background_area(t) > 98 else False for t in tile])
|
|
291
|
-
else:
|
|
292
|
-
return True if calculate_background_area(tile) > 98 else False
|
|
293
|
-
|
|
294
|
-
|
|
295
327
|
def is_empty(tile):
|
|
296
328
|
thresh = 15
|
|
297
329
|
if isinstance(tile, list): # for pair of tiles, only mark it as empty / no need for prediction if ALL tiles are empty
|
|
@@ -300,18 +332,29 @@ def is_empty(tile):
|
|
|
300
332
|
return True if np.max(image_variance_rgb(tile)) < thresh else False
|
|
301
333
|
|
|
302
334
|
|
|
303
|
-
def run_wrapper(tile, run_fn, model_path, eager_mode=False, opt=None):
|
|
335
|
+
def run_wrapper(tile, run_fn, model_path, eager_mode=False, opt=None, seg_only=False):
|
|
304
336
|
if opt.model == 'DeepLIIF':
|
|
305
337
|
if is_empty(tile):
|
|
306
|
-
|
|
307
|
-
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
|
|
311
|
-
|
|
312
|
-
|
|
338
|
+
if seg_only:
|
|
339
|
+
return {
|
|
340
|
+
'G4': Image.new(mode='RGB', size=(512, 512), color=(10, 10, 10)),
|
|
341
|
+
'G5': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)),
|
|
342
|
+
}
|
|
343
|
+
else :
|
|
344
|
+
return {
|
|
345
|
+
'G1': Image.new(mode='RGB', size=(512, 512), color=(201, 211, 208)),
|
|
346
|
+
'G2': Image.new(mode='RGB', size=(512, 512), color=(10, 10, 10)),
|
|
347
|
+
'G3': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)),
|
|
348
|
+
'G4': Image.new(mode='RGB', size=(512, 512), color=(10, 10, 10)),
|
|
349
|
+
'G5': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)),
|
|
350
|
+
'G51': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)),
|
|
351
|
+
'G52': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)),
|
|
352
|
+
'G53': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)),
|
|
353
|
+
'G54': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)),
|
|
354
|
+
'G55': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)),
|
|
355
|
+
}
|
|
313
356
|
else:
|
|
314
|
-
return run_fn(tile, model_path, eager_mode, opt)
|
|
357
|
+
return run_fn(tile, model_path, eager_mode, opt, seg_only)
|
|
315
358
|
elif opt.model in ['DeepLIIFExt', 'SDG']:
|
|
316
359
|
if is_empty(tile):
|
|
317
360
|
res = {'G_' + str(i): Image.new(mode='RGB', size=(512, 512)) for i in range(1, opt.modalities_no + 1)}
|
|
@@ -319,178 +362,20 @@ def run_wrapper(tile, run_fn, model_path, eager_mode=False, opt=None):
|
|
|
319
362
|
return res
|
|
320
363
|
else:
|
|
321
364
|
return run_fn(tile, model_path, eager_mode, opt)
|
|
365
|
+
elif opt.model in ['CycleGAN']:
|
|
366
|
+
if is_empty(tile):
|
|
367
|
+
net_names = ['GB_{i+1}' for i in range(opt.modalities_no)] if opt.BtoA else [f'GA_{i+1}' for i in range(opt.modalities_no)]
|
|
368
|
+
res = {net_name: Image.new(mode='RGB', size=(512, 512)) for net_name in net_names}
|
|
369
|
+
return res
|
|
370
|
+
else:
|
|
371
|
+
return run_fn(tile, model_path, eager_mode, opt)
|
|
322
372
|
else:
|
|
323
373
|
raise Exception(f'run_wrapper() not implemented for model {opt.model}')
|
|
324
374
|
|
|
325
375
|
|
|
326
|
-
def inference_old(img, tile_size, overlap_size, model_path, use_torchserve=False, eager_mode=False,
|
|
327
|
-
color_dapi=False, color_marker=False):
|
|
328
|
-
|
|
329
|
-
tiles = list(generate_tiles(img, tile_size, overlap_size))
|
|
330
|
-
|
|
331
|
-
run_fn = run_torchserve if use_torchserve else run_dask
|
|
332
|
-
# res = [Tile(t.i, t.j, run_fn(t.img, model_path)) for t in tiles]
|
|
333
|
-
res = [Tile(t.i, t.j, run_wrapper(t.img, run_fn, model_path, eager_mode)) for t in tiles]
|
|
334
|
-
|
|
335
|
-
def get_net_tiles(n):
|
|
336
|
-
return [Tile(t.i, t.j, t.img[n]) for t in res]
|
|
337
|
-
|
|
338
|
-
images = {}
|
|
339
|
-
|
|
340
|
-
images['Hema'] = stitch(get_net_tiles('G1'), tile_size, overlap_size).resize(img.size)
|
|
341
|
-
|
|
342
|
-
# images['DAPI'] = stitch(
|
|
343
|
-
# [Tile(t.i, t.j, adjust_background_tile(dt.img))
|
|
344
|
-
# for t, dt in zip(tiles, get_net_tiles('G2'))],
|
|
345
|
-
# tile_size, overlap_size).resize(img.size)
|
|
346
|
-
# dapi_pix = np.array(images['DAPI'])
|
|
347
|
-
# dapi_pix[:, :, 0] = 0
|
|
348
|
-
# images['DAPI'] = Image.fromarray(dapi_pix)
|
|
349
|
-
|
|
350
|
-
images['DAPI'] = stitch(get_net_tiles('G2'), tile_size, overlap_size).resize(img.size)
|
|
351
|
-
dapi_pix = np.array(images['DAPI'].convert('L').convert('RGB'))
|
|
352
|
-
if color_dapi:
|
|
353
|
-
dapi_pix[:, :, 0] = 0
|
|
354
|
-
images['DAPI'] = Image.fromarray(dapi_pix)
|
|
355
|
-
images['Lap2'] = stitch(get_net_tiles('G3'), tile_size, overlap_size).resize(img.size)
|
|
356
|
-
images['Marker'] = stitch(get_net_tiles('G4'), tile_size, overlap_size).resize(img.size)
|
|
357
|
-
marker_pix = np.array(images['Marker'].convert('L').convert('RGB'))
|
|
358
|
-
if color_marker:
|
|
359
|
-
marker_pix[:, :, 2] = 0
|
|
360
|
-
images['Marker'] = Image.fromarray(marker_pix)
|
|
361
|
-
|
|
362
|
-
# images['Marker'] = stitch(
|
|
363
|
-
# [Tile(t.i, t.j, kt.img)
|
|
364
|
-
# for t, kt in zip(tiles, get_net_tiles('G4'))],
|
|
365
|
-
# tile_size, overlap_size).resize(img.size)
|
|
366
|
-
|
|
367
|
-
images['Seg'] = stitch(get_net_tiles('G5'), tile_size, overlap_size).resize(img.size)
|
|
368
|
-
|
|
369
|
-
return images
|
|
370
|
-
|
|
371
|
-
|
|
372
|
-
def inference_old2(img, tile_size, overlap_size, model_path, use_torchserve=False, eager_mode=False,
|
|
373
|
-
color_dapi=False, color_marker=False, opt=None):
|
|
374
|
-
if not opt:
|
|
375
|
-
opt = get_opt(model_path)
|
|
376
|
-
#print_options(opt)
|
|
377
|
-
|
|
378
|
-
if opt.model == 'DeepLIIF':
|
|
379
|
-
rescaled, rows, cols = format_image_for_tiling(img, tile_size, overlap_size)
|
|
380
|
-
|
|
381
|
-
run_fn = run_torchserve if use_torchserve else run_dask
|
|
382
|
-
|
|
383
|
-
images = {}
|
|
384
|
-
d_modality2net = {'Hema':'G1',
|
|
385
|
-
'DAPI':'G2',
|
|
386
|
-
'Lap2':'G3',
|
|
387
|
-
'Marker':'G4',
|
|
388
|
-
'Seg':'G5'}
|
|
389
|
-
|
|
390
|
-
for k in d_modality2net.keys():
|
|
391
|
-
images[k] = create_image_for_stitching(tile_size, rows, cols)
|
|
392
|
-
|
|
393
|
-
for i in range(cols):
|
|
394
|
-
for j in range(rows):
|
|
395
|
-
tile = extract_tile(rescaled, tile_size, overlap_size, i, j)
|
|
396
|
-
res = run_wrapper(tile, run_fn, model_path, eager_mode, opt)
|
|
397
|
-
|
|
398
|
-
for modality_name, net_name in d_modality2net.items():
|
|
399
|
-
stitch_tile(images[modality_name], res[net_name], tile_size, overlap_size, i, j)
|
|
400
|
-
|
|
401
|
-
for modality_name, output_img in images.items():
|
|
402
|
-
images[modality_name] = output_img.resize(img.size)
|
|
403
|
-
|
|
404
|
-
if color_dapi:
|
|
405
|
-
matrix = ( 0, 0, 0, 0,
|
|
406
|
-
299/1000, 587/1000, 114/1000, 0,
|
|
407
|
-
299/1000, 587/1000, 114/1000, 0)
|
|
408
|
-
images['DAPI'] = images['DAPI'].convert('RGB', matrix)
|
|
409
|
-
|
|
410
|
-
if color_marker:
|
|
411
|
-
matrix = (299/1000, 587/1000, 114/1000, 0,
|
|
412
|
-
299/1000, 587/1000, 114/1000, 0,
|
|
413
|
-
0, 0, 0, 0)
|
|
414
|
-
images['Marker'] = images['Marker'].convert('RGB', matrix)
|
|
415
|
-
|
|
416
|
-
return images
|
|
417
|
-
|
|
418
|
-
elif opt.model == 'DeepLIIFExt':
|
|
419
|
-
#param_dict = read_train_options(model_path)
|
|
420
|
-
#modalities_no = int(param_dict['modalities_no']) if param_dict else 4
|
|
421
|
-
#seg_gen = (param_dict['seg_gen'] == 'True') if param_dict else True
|
|
422
|
-
|
|
423
|
-
|
|
424
|
-
rescaled, rows, cols = format_image_for_tiling(img, tile_size, overlap_size)
|
|
425
|
-
run_fn = run_torchserve if use_torchserve else run_dask
|
|
426
|
-
|
|
427
|
-
def get_net_tiles(n):
|
|
428
|
-
return [Tile(t.i, t.j, t.img[n]) for t in res]
|
|
429
|
-
|
|
430
|
-
images = {}
|
|
431
|
-
d_modality2net = {f'mod{i}':f'G_{i}' for i in range(1, opt.modalities_no + 1)}
|
|
432
|
-
if opt.seg_gen:
|
|
433
|
-
d_modality2net.update({f'Seg{i}':f'GS_{i}' for i in range(1, opt.modalities_no + 1)})
|
|
434
|
-
|
|
435
|
-
for k in d_modality2net.keys():
|
|
436
|
-
images[k] = create_image_for_stitching(tile_size, rows, cols)
|
|
437
|
-
|
|
438
|
-
for i in range(cols):
|
|
439
|
-
for j in range(rows):
|
|
440
|
-
tile = extract_tile(rescaled, tile_size, overlap_size, i, j)
|
|
441
|
-
res = run_wrapper(tile, run_fn, model_path, eager_mode, opt)
|
|
442
|
-
|
|
443
|
-
for modality_name, net_name in d_modality2net.items():
|
|
444
|
-
stitch_tile(images[modality_name], res[net_name], tile_size, overlap_size, i, j)
|
|
445
|
-
|
|
446
|
-
for modality_name, output_img in images.items():
|
|
447
|
-
images[modality_name] = output_img.resize(img.size)
|
|
448
|
-
|
|
449
|
-
return images
|
|
450
|
-
|
|
451
|
-
elif opt.model == 'SDG':
|
|
452
|
-
# SDG could have multiple input images / modalities
|
|
453
|
-
# the input hence could be a rectangle
|
|
454
|
-
# we split the input to get each modality image one by one
|
|
455
|
-
# then create tiles for each of the modality images
|
|
456
|
-
# tile_pair is a list that contains the tiles at the given location for each modality image
|
|
457
|
-
# l_tile_pair is a list of tile_pair that covers all locations
|
|
458
|
-
# for inference, each tile_pair is used to get the output at the given location
|
|
459
|
-
w, h = img.size
|
|
460
|
-
w2 = int(w / opt.input_no)
|
|
461
|
-
|
|
462
|
-
l_img = []
|
|
463
|
-
for i in range(opt.input_no):
|
|
464
|
-
img_i = img.crop((w2 * i, 0, w2 * (i+1), h))
|
|
465
|
-
rescaled_img_i, rows, cols = format_image_for_tiling(img_i, tile_size, overlap_size)
|
|
466
|
-
l_img.append(rescaled_img_i)
|
|
467
|
-
|
|
468
|
-
run_fn = run_torchserve if use_torchserve else run_dask
|
|
469
|
-
|
|
470
|
-
images = {}
|
|
471
|
-
d_modality2net = {f'mod{i}':f'G_{i}' for i in range(1, opt.modalities_no + 1)}
|
|
472
|
-
for k in d_modality2net.keys():
|
|
473
|
-
images[k] = create_image_for_stitching(tile_size, rows, cols)
|
|
474
|
-
|
|
475
|
-
for i in range(cols):
|
|
476
|
-
for j in range(rows):
|
|
477
|
-
tile_pair = [extract_tile(rescaled, tile_size, overlap_size, i, j) for rescaled in l_img]
|
|
478
|
-
res = run_wrapper(tile_pair, run_fn, model_path, eager_mode, opt)
|
|
479
|
-
|
|
480
|
-
for modality_name, net_name in d_modality2net.items():
|
|
481
|
-
stitch_tile(images[modality_name], res[net_name], tile_size, overlap_size, i, j)
|
|
482
|
-
|
|
483
|
-
for modality_name, output_img in images.items():
|
|
484
|
-
images[modality_name] = output_img.resize((w2,w2))
|
|
485
|
-
|
|
486
|
-
return images
|
|
487
|
-
|
|
488
|
-
else:
|
|
489
|
-
raise Exception(f'inference() not implemented for model {opt.model}')
|
|
490
|
-
|
|
491
|
-
|
|
492
376
|
def inference(img, tile_size, overlap_size, model_path, use_torchserve=False,
|
|
493
|
-
eager_mode=False, color_dapi=False, color_marker=False, opt=None
|
|
377
|
+
eager_mode=False, color_dapi=False, color_marker=False, opt=None,
|
|
378
|
+
return_seg_intermediate=False, seg_only=False):
|
|
494
379
|
if not opt:
|
|
495
380
|
opt = get_opt(model_path)
|
|
496
381
|
#print_options(opt)
|
|
@@ -508,23 +393,36 @@ def inference(img, tile_size, overlap_size, model_path, use_torchserve=False,
|
|
|
508
393
|
|
|
509
394
|
tiler = InferenceTiler(orig, tile_size, overlap_size)
|
|
510
395
|
for tile in tiler:
|
|
511
|
-
tiler.stitch(run_wrapper(tile, run_fn, model_path, eager_mode, opt))
|
|
396
|
+
tiler.stitch(run_wrapper(tile, run_fn, model_path, eager_mode, opt, seg_only))
|
|
512
397
|
results = tiler.results()
|
|
513
398
|
|
|
514
399
|
if opt.model == 'DeepLIIF':
|
|
515
|
-
|
|
516
|
-
'
|
|
517
|
-
'
|
|
518
|
-
|
|
519
|
-
|
|
520
|
-
|
|
521
|
-
|
|
522
|
-
|
|
400
|
+
if seg_only:
|
|
401
|
+
images = {'Seg': results['G5']}
|
|
402
|
+
if 'G4' in results:
|
|
403
|
+
images.update({'Marker': results['G4']})
|
|
404
|
+
else:
|
|
405
|
+
images = {
|
|
406
|
+
'Hema': results['G1'],
|
|
407
|
+
'DAPI': results['G2'],
|
|
408
|
+
'Lap2': results['G3'],
|
|
409
|
+
'Marker': results['G4'],
|
|
410
|
+
'Seg': results['G5'],
|
|
411
|
+
}
|
|
412
|
+
|
|
413
|
+
if return_seg_intermediate and not seg_only:
|
|
414
|
+
images.update({'IHC_s':results['G51'],
|
|
415
|
+
'Hema_s':results['G52'],
|
|
416
|
+
'DAPI_s':results['G53'],
|
|
417
|
+
'Lap2_s':results['G54'],
|
|
418
|
+
'Marker_s':results['G55'],})
|
|
419
|
+
|
|
420
|
+
if color_dapi and not seg_only:
|
|
523
421
|
matrix = ( 0, 0, 0, 0,
|
|
524
422
|
299/1000, 587/1000, 114/1000, 0,
|
|
525
423
|
299/1000, 587/1000, 114/1000, 0)
|
|
526
424
|
images['DAPI'] = images['DAPI'].convert('RGB', matrix)
|
|
527
|
-
if color_marker:
|
|
425
|
+
if color_marker and not seg_only:
|
|
528
426
|
matrix = (299/1000, 587/1000, 114/1000, 0,
|
|
529
427
|
299/1000, 587/1000, 114/1000, 0,
|
|
530
428
|
0, 0, 0, 0)
|
|
@@ -546,27 +444,27 @@ def inference(img, tile_size, overlap_size, model_path, use_torchserve=False,
|
|
|
546
444
|
return results # return result images with default key names (i.e., net names)
|
|
547
445
|
|
|
548
446
|
|
|
549
|
-
def postprocess(orig, images, tile_size, model, seg_thresh=150, size_thresh='
|
|
447
|
+
def postprocess(orig, images, tile_size, model, seg_thresh=150, size_thresh='default', marker_thresh=None, size_thresh_upper=None):
|
|
550
448
|
if model == 'DeepLIIF':
|
|
551
449
|
resolution = '40x' if tile_size > 384 else ('20x' if tile_size > 192 else '10x')
|
|
552
|
-
overlay, refined, scoring =
|
|
553
|
-
|
|
554
|
-
|
|
450
|
+
overlay, refined, scoring = compute_final_results(
|
|
451
|
+
orig, images['Seg'], images.get('Marker'), resolution,
|
|
452
|
+
size_thresh, marker_thresh, size_thresh_upper, seg_thresh)
|
|
555
453
|
processed_images = {}
|
|
556
454
|
processed_images['SegOverlaid'] = Image.fromarray(overlay)
|
|
557
455
|
processed_images['SegRefined'] = Image.fromarray(refined)
|
|
558
456
|
return processed_images, scoring
|
|
559
457
|
|
|
560
|
-
elif model
|
|
458
|
+
elif model in ['DeepLIIFExt','SDG']:
|
|
561
459
|
resolution = '40x' if tile_size > 768 else ('20x' if tile_size > 384 else '10x')
|
|
562
460
|
processed_images = {}
|
|
563
461
|
scoring = {}
|
|
564
462
|
for img_name in list(images.keys()):
|
|
565
463
|
if 'Seg' in img_name:
|
|
566
464
|
seg_img = images[img_name]
|
|
567
|
-
overlay, refined, score =
|
|
568
|
-
|
|
569
|
-
|
|
465
|
+
overlay, refined, score = compute_final_results(
|
|
466
|
+
orig, images[img_name], None, resolution,
|
|
467
|
+
size_thresh, marker_thresh, size_thresh_upper, seg_thresh)
|
|
570
468
|
|
|
571
469
|
processed_images[img_name + '_Overlaid'] = Image.fromarray(overlay)
|
|
572
470
|
processed_images[img_name + '_Refined'] = Image.fromarray(refined)
|
|
@@ -578,7 +476,8 @@ def postprocess(orig, images, tile_size, model, seg_thresh=150, size_thresh='aut
|
|
|
578
476
|
|
|
579
477
|
|
|
580
478
|
def infer_modalities(img, tile_size, model_dir, eager_mode=False,
|
|
581
|
-
color_dapi=False, color_marker=False, opt=None
|
|
479
|
+
color_dapi=False, color_marker=False, opt=None,
|
|
480
|
+
return_seg_intermediate=False, seg_only=False):
|
|
582
481
|
"""
|
|
583
482
|
This function is used to infer modalities for the given image using a trained model.
|
|
584
483
|
:param img: The input image.
|
|
@@ -591,11 +490,6 @@ def infer_modalities(img, tile_size, model_dir, eager_mode=False,
|
|
|
591
490
|
opt.use_dp = False
|
|
592
491
|
#print_options(opt)
|
|
593
492
|
|
|
594
|
-
if not tile_size:
|
|
595
|
-
tile_size = check_multi_scale(Image.open('./images/target.png').convert('L'),
|
|
596
|
-
img.convert('L'))
|
|
597
|
-
tile_size = int(tile_size)
|
|
598
|
-
|
|
599
493
|
# for those with multiple input modalities, find the correct size to calculate overlap_size
|
|
600
494
|
input_no = opt.input_no if hasattr(opt, 'input_no') else 1
|
|
601
495
|
img_size = (img.size[0] / input_no, img.size[1]) # (width, height)
|
|
@@ -609,18 +503,24 @@ def infer_modalities(img, tile_size, model_dir, eager_mode=False,
|
|
|
609
503
|
eager_mode=eager_mode,
|
|
610
504
|
color_dapi=color_dapi,
|
|
611
505
|
color_marker=color_marker,
|
|
612
|
-
opt=opt
|
|
506
|
+
opt=opt,
|
|
507
|
+
return_seg_intermediate=return_seg_intermediate,
|
|
508
|
+
seg_only=seg_only
|
|
613
509
|
)
|
|
614
|
-
|
|
510
|
+
|
|
615
511
|
if not hasattr(opt,'seg_gen') or (hasattr(opt,'seg_gen') and opt.seg_gen): # the first condition accounts for old settings of deepliif; the second refers to deepliifext models
|
|
616
512
|
post_images, scoring = postprocess(img, images, tile_size, opt.model)
|
|
617
513
|
images = {**images, **post_images}
|
|
514
|
+
if seg_only:
|
|
515
|
+
delete_keys = [k for k in images.keys() if 'Seg' not in k]
|
|
516
|
+
for name in delete_keys:
|
|
517
|
+
del images[name]
|
|
618
518
|
return images, scoring
|
|
619
519
|
else:
|
|
620
520
|
return images, None
|
|
621
521
|
|
|
622
522
|
|
|
623
|
-
def infer_results_for_wsi(input_dir, filename, output_dir, model_dir, tile_size, region_size=20000):
|
|
523
|
+
def infer_results_for_wsi(input_dir, filename, output_dir, model_dir, tile_size, region_size=20000, color_dapi=False, color_marker=False, seg_intermediate=False, seg_only=False):
|
|
624
524
|
"""
|
|
625
525
|
This function infers modalities and segmentation mask for the given WSI image. It
|
|
626
526
|
|
|
@@ -632,35 +532,229 @@ def infer_results_for_wsi(input_dir, filename, output_dir, model_dir, tile_size,
|
|
|
632
532
|
:param region_size: The size of each individual region to be processed at once.
|
|
633
533
|
:return:
|
|
634
534
|
"""
|
|
635
|
-
|
|
535
|
+
basename, _ = os.path.splitext(filename)
|
|
536
|
+
results_dir = os.path.join(output_dir, basename)
|
|
636
537
|
if not os.path.exists(results_dir):
|
|
637
538
|
os.makedirs(results_dir)
|
|
638
539
|
size_x, size_y, size_z, size_c, size_t, pixel_type = get_information(os.path.join(input_dir, filename))
|
|
639
|
-
|
|
540
|
+
rescale = (pixel_type != 'uint8')
|
|
541
|
+
print(filename, size_x, size_y, size_z, size_c, size_t, pixel_type, flush=True)
|
|
542
|
+
|
|
640
543
|
results = {}
|
|
641
|
-
|
|
642
|
-
|
|
643
|
-
|
|
644
|
-
|
|
645
|
-
|
|
646
|
-
|
|
647
|
-
|
|
648
|
-
|
|
649
|
-
|
|
650
|
-
|
|
651
|
-
|
|
652
|
-
|
|
653
|
-
|
|
654
|
-
|
|
655
|
-
|
|
656
|
-
|
|
657
|
-
|
|
658
|
-
|
|
659
|
-
|
|
544
|
+
scoring = None
|
|
545
|
+
|
|
546
|
+
# javabridge already set up from previous call to get_information()
|
|
547
|
+
with bioformats.ImageReader(os.path.join(input_dir, filename)) as reader:
|
|
548
|
+
start_x, start_y = 0, 0
|
|
549
|
+
|
|
550
|
+
while start_x < size_x:
|
|
551
|
+
while start_y < size_y:
|
|
552
|
+
print(start_x, start_y, flush=True)
|
|
553
|
+
region_XYWH = (start_x, start_y, min(region_size, size_x - start_x), min(region_size, size_y - start_y))
|
|
554
|
+
region = reader.read(XYWH=region_XYWH, rescale=rescale)
|
|
555
|
+
img = Image.fromarray((region * 255).astype(np.uint8)) if rescale else Image.fromarray(region)
|
|
556
|
+
|
|
557
|
+
region_modalities, region_scoring = infer_modalities(img, tile_size, model_dir, color_dapi=color_dapi, color_marker=color_marker, return_seg_intermediate=seg_intermediate, seg_only=seg_only)
|
|
558
|
+
if region_scoring is not None:
|
|
559
|
+
if scoring is None:
|
|
560
|
+
scoring = {
|
|
561
|
+
'num_pos': region_scoring['num_pos'],
|
|
562
|
+
'num_neg': region_scoring['num_neg'],
|
|
563
|
+
}
|
|
564
|
+
else:
|
|
565
|
+
scoring['num_pos'] += region_scoring['num_pos']
|
|
566
|
+
scoring['num_neg'] += region_scoring['num_neg']
|
|
567
|
+
|
|
568
|
+
for name, img in region_modalities.items():
|
|
569
|
+
if name not in results:
|
|
570
|
+
results[name] = np.zeros((size_y, size_x, 3), dtype=np.uint8)
|
|
571
|
+
results[name][region_XYWH[1]: region_XYWH[1] + region_XYWH[3],
|
|
572
|
+
region_XYWH[0]: region_XYWH[0] + region_XYWH[2]] = np.array(img)
|
|
573
|
+
start_y += region_size
|
|
574
|
+
start_y = 0
|
|
575
|
+
start_x += region_size
|
|
576
|
+
|
|
577
|
+
# write_results_to_pickle_file(os.path.join(results_dir, "results.pickle"), results)
|
|
660
578
|
# read_results_from_pickle_file(os.path.join(results_dir, "results.pickle"))
|
|
661
579
|
|
|
662
580
|
for name, img in results.items():
|
|
663
|
-
write_big_tiff_file(os.path.join(results_dir,
|
|
664
|
-
|
|
581
|
+
write_big_tiff_file(os.path.join(results_dir, f'{basename}_{name}.ome.tiff'), img, tile_size)
|
|
582
|
+
|
|
583
|
+
if scoring is not None:
|
|
584
|
+
scoring['num_total'] = scoring['num_pos'] + scoring['num_neg']
|
|
585
|
+
scoring['percent_pos'] = round(scoring['num_pos'] / scoring['num_total'] * 100, 1) if scoring['num_pos'] > 0 else 0
|
|
586
|
+
with open(os.path.join(results_dir, f'{basename}.json'), 'w') as f:
|
|
587
|
+
json.dump(scoring, f, indent=2)
|
|
665
588
|
|
|
666
589
|
javabridge.kill_vm()
|
|
590
|
+
|
|
591
|
+
|
|
592
|
+
def get_wsi_resolution(filename):
    """
    Try to get the resolution (magnification) of the slide and
    the corresponding tile size to use by default for DeepLIIF.
    If it cannot be found, return (None, None) instead.

    The magnification is first read from the parsed OME-XML metadata
    (instrument -> objective -> nominal magnification).  If that lookup
    raises or yields no value, the raw metadata text is searched for the
    first run of digits / dots that follows the first occurrence of
    'AppMag' or 'NominalMagnification' (tried in that order).

    Note: This will start the javabridge VM, but not kill it.
          It must be killed elsewhere.

    Parameters
    ----------
    filename : str
        Full path to the file.

    Returns
    -------
    str or float :
        Magnification (objective power) from image metadata.  The text
        fallback yields the matched substring; the OME-XML lookup returns
        whatever value bioformats provides.
    int :
        Corresponding tile size for DeepLIIF, computed as
        round(magnification / 40 * 512).
    """
    import re

    # make sure javabridge is already set up via the call to get_information();
    # the returned image dimensions are not needed here
    get_information(filename)

    metadata = bioformats.get_omexml_metadata(filename)

    mag = None
    try:
        mag = bioformats.OMEXML(metadata).instrument().Objective.NominalMagnification
    except Exception:
        # Best effort: any failure of the structured lookup (missing
        # instrument/objective, unparsable XML, ...) falls through to the
        # text search below.
        mag = None

    if mag is None:
        # Also reached when the structured lookup succeeded but returned no
        # value, so that e.g. an 'AppMag' entry elsewhere in the metadata
        # text is still considered.
        number_run = re.compile(r'[\d.]+')
        try:
            for field in ('AppMag', 'NominalMagnification'):
                idx = metadata.find(field)
                if idx < 0:
                    continue
                # first run of digits/dots at or after the field name;
                # a run that extends to the end of the text is kept whole
                found = number_run.search(metadata, idx)
                if found is None:
                    continue
                mag = found.group()
                break
        except (AttributeError, TypeError):
            # metadata is not searchable text; treat as "not found"
            pass

    if mag is None:
        return None, None

    try:
        tile_size = round((float(mag) / 40) * 512)
    except (TypeError, ValueError, OverflowError):
        # value found but not usable as a number (e.g. '.' or '1.2.3')
        return None, None
    return mag, tile_size
|
|
649
|
+
|
|
650
|
+
|
|
651
|
+
def infer_cells_for_wsi(filename, model_dir, tile_size, region_size=20000, version=3, print_log=False):
    """
    Perform inference on a slide and get the results individual cell data.

    The slide is split into a grid of roughly equal regions (the number of
    regions per axis is ceil(size / region_size)).  Each region is read with
    bioformats, run through inference() in seg_only mode, and converted to
    cell data with compute_cell_results().  Cell coordinates (bbox, centroid,
    boundary) of every region other than the one at the origin are shifted by
    the region's top-left corner so they refer to the whole slide.

    The 'default_marker_thresh' and 'default_size_thresh' settings of the
    returned data are the rounded averages over the regions that reported a
    non-zero (and, for the marker threshold, non-None) value; 0 if no region
    did.

    Note: The javabridge VM (started by get_information()) is killed before
          this function returns, whether it succeeds or raises.

    Parameters
    ----------
    filename : str
        Full path to the file.
    model_dir : str
        Full path to the directory with the DeepLIIF model files.
    tile_size : int
        Size of tiles to extract and perform inference on.
    region_size : int
        Maximum size to split the slide for processing.
    version : int
        Version of cell data to return (3 or 4).
    print_log : bool
        Whether or not to print updates while processing.

    Returns
    -------
    dict :
        Individual cell data and associated values.
    """

    def print_info(*args):
        if print_log:
            print(*args, flush=True)

    resolution = '40x' if tile_size > 384 else ('20x' if tile_size > 192 else '10x')

    size_x, size_y, size_z, size_c, size_t, pixel_type = get_information(filename)

    data = None
    default_marker_thresh, count_marker_thresh = 0, 0
    default_size_thresh, count_size_thresh = 0, 0

    # From here on the VM is assumed to be running (see get_information());
    # the finally clause guarantees it is shut down even if reading or
    # inference fails part-way through the slide.
    try:
        rescale = (pixel_type != 'uint8')
        print_info('Info:', size_x, size_y, size_z, size_c, size_t, pixel_type)

        num_regions_x = math.ceil(size_x / region_size)
        num_regions_y = math.ceil(size_y / region_size)
        stride_x = math.ceil(size_x / num_regions_x)
        stride_y = math.ceil(size_y / num_regions_y)
        print_info('Strides:', stride_x, stride_y)

        with bioformats.ImageReader(filename) as reader:
            start_x, start_y = 0, 0

            while start_y < size_y:
                while start_x < size_x:
                    # clip the last region of each row/column to the slide bounds
                    region_XYWH = (start_x, start_y, min(stride_x, size_x-start_x), min(stride_y, size_y-start_y))
                    print_info('Region:', region_XYWH)

                    region = reader.read(XYWH=region_XYWH, rescale=rescale)
                    print_info(region.shape, region.dtype)
                    # non-uint8 pixels are read rescaled and mapped onto 0..255 here
                    img = Image.fromarray((region * 255).astype(np.uint8)) if rescale else Image.fromarray(region)
                    print_info(img.size, img.mode)

                    images = inference(
                        img,
                        tile_size=tile_size,
                        overlap_size=tile_size//16,
                        model_path=model_dir,
                        eager_mode=False,
                        color_dapi=False,
                        color_marker=False,
                        opt=None,
                        return_seg_intermediate=False,
                        seg_only=True,
                    )
                    region_data = compute_cell_results(images['Seg'], images.get('Marker'), resolution, version=version)

                    if start_x != 0 or start_y != 0:
                        # translate region-local coordinates to slide coordinates;
                        # version 4 cells are decoded for the update and re-encoded after
                        cells = region_data['cells']
                        for i, raw_cell in enumerate(cells):
                            cell = decode_cell_data_v4(raw_cell) if version == 4 else raw_cell
                            for j in range(2):
                                cell['bbox'][j] = (cell['bbox'][j][0] + start_x, cell['bbox'][j][1] + start_y)
                            cell['centroid'] = (cell['centroid'][0] + start_x, cell['centroid'][1] + start_y)
                            for j, point in enumerate(cell['boundary']):
                                cell['boundary'][j] = (point[0] + start_x, point[1] + start_y)
                            cells[i] = encode_cell_data_v4(cell) if version == 4 else cell

                    if data is None:
                        # the first region's result is the container for everything else
                        data = region_data
                    else:
                        data['cells'] += region_data['cells']

                    settings = region_data['settings']
                    if settings['default_marker_thresh'] is not None and settings['default_marker_thresh'] != 0:
                        default_marker_thresh += settings['default_marker_thresh']
                        count_marker_thresh += 1
                    if settings['default_size_thresh'] != 0:
                        default_size_thresh += settings['default_size_thresh']
                        count_size_thresh += 1

                    start_x += stride_x

                start_x = 0
                start_y += stride_y
    finally:
        javabridge.kill_vm()

    # avoid dividing by zero when no region contributed a threshold
    if count_marker_thresh == 0:
        count_marker_thresh = 1
    if count_size_thresh == 0:
        count_size_thresh = 1
    data['settings']['default_marker_thresh'] = round(default_marker_thresh / count_marker_thresh)
    data['settings']['default_size_thresh'] = round(default_size_thresh / count_size_thresh)

    return data
|