deepliif 1.1.11__py3-none-any.whl → 1.1.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -12,7 +12,7 @@ In the function <__init__>, you need to define four lists:
     -- self.loss_names (str list): specify the training losses that you want to plot and save.
     -- self.model_names (str list): define networks used in our training.
     -- self.visual_names (str list): specify the images that you want to display and save.
-    -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example usage.
+    -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See CycleGAN_model.py for an example usage.
 
     Now you can use the model class by specifying flag '--model dummy'.
     See our template model class 'template_model.py' for more details.
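
The four lists above are the contract a model subclass must satisfy. A minimal sketch of a hypothetical DummyModel (the placeholder networks and the `opt.lr` option are assumptions; the BaseModel import mirrors the one used later in this module and only resolves inside the package):

    import itertools
    import torch
    from .base_model import BaseModel

    class DummyModel(BaseModel):  # hypothetical '--model dummy' class
        def __init__(self, opt):
            BaseModel.__init__(self, opt)
            self.loss_names = ['G']               # losses to plot and save
            self.model_names = ['G1', 'G2']       # networks used in training
            self.visual_names = ['real', 'fake']  # images to display and save
            self.netG1 = torch.nn.Linear(8, 8)    # placeholder networks for illustration
            self.netG2 = torch.nn.Linear(8, 8)
            # one optimizer updating both networks at once via itertools.chain
            self.optimizers = [torch.optim.Adam(
                itertools.chain(self.netG1.parameters(), self.netG2.parameters()),
                lr=opt.lr)]
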
@@ -23,17 +23,22 @@ import itertools
 import importlib
 from functools import lru_cache
 from io import BytesIO
+import json
+import math
 
 import requests
 import torch
 from PIL import Image
+Image.MAX_IMAGE_PIXELS = None
+
 import numpy as np
 from dask import delayed, compute
 
 from deepliif.util import *
-from deepliif.util.util import tensor_to_pil, check_multi_scale
+from deepliif.util.util import tensor_to_pil
 from deepliif.data import transform
-from deepliif.postprocessing import compute_results
+from deepliif.postprocessing import compute_final_results, compute_cell_results
+from deepliif.postprocessing import encode_cell_data_v4, decode_cell_data_v4
 from deepliif.options import Options, print_options
 
 from .base_model import BaseModel
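
Note on the added `Image.MAX_IMAGE_PIXELS = None`: Pillow refuses to decode images above a default limit of about 178 million pixels (raising `DecompressionBombError`), a limit whole-slide regions routinely exceed; setting the attribute to `None` removes the check process-wide. A minimal sketch (the file name is hypothetical):

    from PIL import Image

    Image.MAX_IMAGE_PIXELS = None  # disable the decompression-bomb check in Image.open
    region = Image.open('slide_region.png')  # would raise at the default limit if very large
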
@@ -115,14 +120,19 @@ def load_torchscript_model(model_pt_path, device):
 
 
 
-def load_eager_models(opt, devices):
+def load_eager_models(opt, devices=None):
     # create a model given model and other options
     model = create_model(opt)
     # regular setup: load and print networks; create schedulers
     model.setup(opt)
 
     nets = {}
-    for name in model.model_names:
+    if devices:
+        model_names = list(devices.keys())
+    else:
+        model_names = model.model_names
+
+    for name in model_names:  #model.model_names:
         if isinstance(name, str):
             if '_' in name:
                 net = getattr(model, 'net' + name.split('_')[0])[int(name.split('_')[-1]) - 1]
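
The new `devices=None` default makes the device mapping optional: when given, its keys select which generators to load and its values pin each one to a device; when omitted, all of the model's nets are loaded and left wherever `model.setup(opt)` put them. A sketch of the intended call shapes (the net names and devices here are hypothetical):

    import torch

    devices = {'G1': torch.device('cuda:0'), 'G4': torch.device('cpu')}
    nets = load_eager_models(opt, devices)  # only 'G1' and 'G4', moved explicitly
    all_nets = load_eager_models(opt)       # devices=None: every net, no .to() calls
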
@@ -138,7 +148,8 @@ def load_eager_models(opt, devices):
             net = net.module
 
         nets[name] = net
-        nets[name].to(devices[name])
+        if devices:
+            nets[name].to(devices[name])
 
     return nets
 
@@ -154,8 +165,7 @@ def init_nets(model_dir, eager_mode=False, opt=None, phase='test'):
     """
     if opt is None:
         opt = get_opt(model_dir, mode=phase)
-    opt.use_dp = False
-    #print_options(opt)
+    opt.use_dp = False
 
     if opt.model == 'DeepLIIF':
         net_groups = [
@@ -170,12 +180,16 @@ def init_nets(model_dir, eager_mode=False, opt=None, phase='test'):
             net_groups = [(f'G_{i+1}',f'GS_{i+1}') for i in range(opt.modalities_no)]
         else:
             net_groups = [(f'G_{i+1}',) for i in range(opt.modalities_no)]
+    elif opt.model == 'CycleGAN':
+        if opt.BtoA:
+            net_groups = [(f'GB_{i+1}',) for i in range(opt.modalities_no)]
+        else:
+            net_groups = [(f'GA_{i+1}',) for i in range(opt.modalities_no)]
     else:
         raise Exception(f'init_nets() not implemented for model {opt.model}')
 
     number_of_gpus_all = torch.cuda.device_count()
-    number_of_gpus = len(opt.gpu_ids)
-    #print(number_of_gpus)
+    number_of_gpus = min(len(opt.gpu_ids),number_of_gpus_all)
 
     if number_of_gpus > 0:
         mapping_gpu_ids = {i:idx for i,idx in enumerate(opt.gpu_ids)}
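
The switch to `min(len(opt.gpu_ids), number_of_gpus_all)` guards against saved options that list more GPUs than the host actually exposes. The arithmetic in isolation:

    import torch

    gpu_ids = [0, 1, 2, 3]  # hypothetical value carried in the saved train options
    number_of_gpus_all = torch.cuda.device_count()  # e.g. 1 on a single-GPU host
    number_of_gpus = min(len(gpu_ids), number_of_gpus_all)  # never exceeds the hardware
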
@@ -203,12 +217,13 @@ def compute_overlap(img_size, tile_size):
     return tile_size // 4
 
 
-def run_torchserve(img, model_path=None, eager_mode=False, opt=None):
+def run_torchserve(img, model_path=None, eager_mode=False, opt=None, seg_only=False):
     """
     eager_mode: not used in this function; put in place to be consistent with run_dask
                 so that run_wrapper() could call either this function or run_dask with
                 same syntax
     opt: same as eager_mode
+    seg_only: same as eager_mode
     """
     buffer = BytesIO()
     torch.save(transform(img.resize((opt.scale_size, opt.scale_size))), buffer)
@@ -227,9 +242,10 @@ def run_torchserve(img, model_path=None, eager_mode=False, opt=None):
     return {k: tensor_to_pil(deserialize_tensor(v)) for k, v in res.json().items()}
 
 
-def run_dask(img, model_path, eager_mode=False, opt=None):
+def run_dask(img, model_path, eager_mode=False, opt=None, seg_only=False):
     model_dir = os.getenv('DEEPLIIF_MODEL_DIR', model_path)
     nets = init_nets(model_dir, eager_mode, opt)
+    use_dask = True if opt.norm != 'spectral' else False
 
     if opt.input_no > 1 or opt.model == 'SDG':
         l_ts = [transform(img_i.resize((opt.scale_size,opt.scale_size))) for img_i in img]
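
The new `use_dask` flag keeps the lazy dask graph for most models but falls back to plain eager calls when the nets were trained with spectral norm (see the comment in the next hunk). The two execution styles, sketched with a toy module:

    import torch
    from dask import delayed, compute

    model = torch.nn.Linear(4, 2)
    x = torch.randn(1, 4)

    lazy = delayed(lambda t: model(t))(x)  # builds a task; nothing runs yet
    y_lazy = compute(lazy)[0]              # the whole graph executes here

    y_eager = model(x)                     # fallback path: ordinary call, no graph
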
@@ -238,45 +254,69 @@ def run_dask(img, model_path, eager_mode=False, opt=None):
         ts = transform(img.resize((opt.scale_size, opt.scale_size)))
 
 
-    @delayed
-    def forward(input, model):
-        with torch.no_grad():
-            return model(input.to(next(model.parameters()).device))
+    if use_dask:
+        @delayed
+        def forward(input, model):
+            with torch.no_grad():
+                return model(input.to(next(model.parameters()).device))
+    else:  # some train settings, like spectral norm, are somehow not compatible with dask in inference mode
+        def forward(input, model):
+            with torch.no_grad():
+                return model(input.to(next(model.parameters()).device))
 
     if opt.model == 'DeepLIIF':
+        weights = {
+            'G51': 0.25,  # IHC
+            'G52': 0.25,  # Hema
+            'G53': 0.25,  # DAPI
+            'G54': 0.00,  # Lap2
+            'G55': 0.25,  # Marker
+        }
+
         seg_map = {'G1': 'G52', 'G2': 'G53', 'G3': 'G54', 'G4': 'G55'}
+        if seg_only:
+            seg_map = {k: v for k, v in seg_map.items() if weights[v] != 0}
 
         lazy_gens = {k: forward(ts, nets[k]) for k in seg_map}
+        if 'G4' not in seg_map:
+            lazy_gens['G4'] = forward(ts, nets['G4'])
         gens = compute(lazy_gens)[0]
 
         lazy_segs = {v: forward(gens[k], nets[v]).to(torch.device('cpu')) for k, v in seg_map.items()}
-        lazy_segs['G51'] = forward(ts, nets['G51']).to(torch.device('cpu'))
+        if not seg_only or weights['G51'] != 0:
+            lazy_segs['G51'] = forward(ts, nets['G51']).to(torch.device('cpu'))
         segs = compute(lazy_segs)[0]
 
-        weights = {
-            'G51': 0.25,  # IHC
-            'G52': 0.25,  # Hema
-            'G53': 0.25,  # DAPI
-            'G54': 0.00,  # Lap2
-            'G55': 0.25,  # Marker
-        }
         seg = torch.stack([torch.mul(segs[k], weights[k]) for k in segs.keys()]).sum(dim=0)
 
-        res = {k: tensor_to_pil(v) for k, v in gens.items()}
+        if seg_only:
+            res = {'G4': tensor_to_pil(gens['G4'])} if 'G4' in gens else {}
+        else:
+            res = {k: tensor_to_pil(v) for k, v in gens.items()}
+            res.update({k: tensor_to_pil(v) for k, v in segs.items()})
         res['G5'] = tensor_to_pil(seg)
 
         return res
-    elif opt.model in ['DeepLIIFExt','SDG']:
-        seg_map = {'G_' + str(i): 'GS_' + str(i) for i in range(1, opt.modalities_no + 1)}
+    elif opt.model in ['DeepLIIFExt','SDG','CycleGAN']:
+        if opt.model == 'CycleGAN':
+            seg_map = {f'GB_{i+1}':None for i in range(opt.modalities_no)} if opt.BtoA else {f'GA_{i+1}':None for i in range(opt.modalities_no)}
+        else:
+            seg_map = {'G_' + str(i): 'GS_' + str(i) for i in range(1, opt.modalities_no + 1)}
 
-        lazy_gens = {k: forward(ts, nets[k]) for k in seg_map}
-        gens = compute(lazy_gens)[0]
+        if use_dask:
+            lazy_gens = {k: forward(ts, nets[k]) for k in seg_map}
+            gens = compute(lazy_gens)[0]
+        else:
+            gens = {k: forward(ts, nets[k]) for k in seg_map}
 
         res = {k: tensor_to_pil(v) for k, v in gens.items()}
 
         if opt.seg_gen:
-            lazy_segs = {v: forward(torch.cat([ts.to(torch.device('cpu')), gens[next(iter(seg_map))].to(torch.device('cpu')), gens[k].to(torch.device('cpu'))], 1), nets[v]).to(torch.device('cpu')) for k, v in seg_map.items()}
-            segs = compute(lazy_segs)[0]
+            if use_dask:
+                lazy_segs = {v: forward(torch.cat([ts.to(torch.device('cpu')), gens[next(iter(seg_map))].to(torch.device('cpu')), gens[k].to(torch.device('cpu'))], 1), nets[v]).to(torch.device('cpu')) for k, v in seg_map.items()}
+                segs = compute(lazy_segs)[0]
+            else:
+                segs = {v: forward(torch.cat([ts.to(torch.device('cpu')), gens[next(iter(seg_map))].to(torch.device('cpu')), gens[k].to(torch.device('cpu'))], 1), nets[v]).to(torch.device('cpu')) for k, v in seg_map.items()}
             res.update({k: tensor_to_pil(v) for k, v in segs.items()})
 
         return res
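
In the DeepLIIF branch the final mask is a fixed-weight sum of the per-modality segmentation outputs; Lap2 ('G54') carries weight 0.00, which is why `seg_only` can prune it from `seg_map` entirely. The fusion step in isolation, with toy tensors:

    import torch

    segs = {'G51': torch.rand(1, 3, 16, 16),  # stand-ins for the per-modality outputs
            'G52': torch.rand(1, 3, 16, 16),
            'G55': torch.rand(1, 3, 16, 16)}
    weights = {'G51': 0.25, 'G52': 0.25, 'G55': 0.25}
    seg = torch.stack([torch.mul(segs[k], weights[k]) for k in segs]).sum(dim=0)
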
@@ -284,14 +324,6 @@ def run_dask(img, model_path, eager_mode=False, opt=None):
         raise Exception(f'run_dask() not fully implemented for {opt.model}')
 
 
-def is_empty_old(tile):
-    # return True if np.mean(np.array(tile) - np.array(mean_background_val)) < 40 else False
-    if isinstance(tile, list):  # for pair of tiles, only mark it as empty / no need for prediction if ALL tiles are empty
-        return all([True if calculate_background_area(t) > 98 else False for t in tile])
-    else:
-        return True if calculate_background_area(tile) > 98 else False
-
-
 def is_empty(tile):
     thresh = 15
     if isinstance(tile, list):  # for pair of tiles, only mark it as empty / no need for prediction if ALL tiles are empty
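
`is_empty` skips inference for tiles whose per-channel intensity variance never reaches the threshold (the check appears at the top of the next hunk), replacing the removed background-area heuristic. A sketch of that check, assuming `image_variance_rgb` returns one variance per RGB channel:

    import numpy as np
    from PIL import Image

    def image_variance_rgb(tile):  # assumed behavior of the helper used above
        arr = np.asarray(tile.convert('RGB'), dtype=np.float64)
        return arr.reshape(-1, 3).var(axis=0)

    blank = Image.new('RGB', (512, 512), color=(240, 240, 240))
    print(np.max(image_variance_rgb(blank)) < 15)  # True: tile treated as empty
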
@@ -300,18 +332,29 @@ def is_empty(tile):
         return True if np.max(image_variance_rgb(tile)) < thresh else False
 
 
-def run_wrapper(tile, run_fn, model_path, eager_mode=False, opt=None):
+def run_wrapper(tile, run_fn, model_path, eager_mode=False, opt=None, seg_only=False):
     if opt.model == 'DeepLIIF':
         if is_empty(tile):
-            return {
-                'G1': Image.new(mode='RGB', size=(512, 512), color=(201, 211, 208)),
-                'G2': Image.new(mode='RGB', size=(512, 512), color=(10, 10, 10)),
-                'G3': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)),
-                'G4': Image.new(mode='RGB', size=(512, 512), color=(10, 10, 10)),
-                'G5': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0))
-            }
+            if seg_only:
+                return {
+                    'G4': Image.new(mode='RGB', size=(512, 512), color=(10, 10, 10)),
+                    'G5': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)),
+                }
+            else:
+                return {
+                    'G1': Image.new(mode='RGB', size=(512, 512), color=(201, 211, 208)),
+                    'G2': Image.new(mode='RGB', size=(512, 512), color=(10, 10, 10)),
+                    'G3': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)),
+                    'G4': Image.new(mode='RGB', size=(512, 512), color=(10, 10, 10)),
+                    'G5': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)),
+                    'G51': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)),
+                    'G52': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)),
+                    'G53': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)),
+                    'G54': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)),
+                    'G55': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)),
+                }
         else:
-            return run_fn(tile, model_path, eager_mode, opt)
+            return run_fn(tile, model_path, eager_mode, opt, seg_only)
     elif opt.model in ['DeepLIIFExt', 'SDG']:
         if is_empty(tile):
             res = {'G_' + str(i): Image.new(mode='RGB', size=(512, 512)) for i in range(1, opt.modalities_no + 1)}
@@ -319,178 +362,20 @@ def run_wrapper(tile, run_fn, model_path, eager_mode=False, opt=None):
             return res
         else:
             return run_fn(tile, model_path, eager_mode, opt)
+    elif opt.model in ['CycleGAN']:
+        if is_empty(tile):
+            net_names = [f'GB_{i+1}' for i in range(opt.modalities_no)] if opt.BtoA else [f'GA_{i+1}' for i in range(opt.modalities_no)]
+            res = {net_name: Image.new(mode='RGB', size=(512, 512)) for net_name in net_names}
+            return res
+        else:
+            return run_fn(tile, model_path, eager_mode, opt)
     else:
         raise Exception(f'run_wrapper() not implemented for model {opt.model}')
 
 
-def inference_old(img, tile_size, overlap_size, model_path, use_torchserve=False, eager_mode=False,
-                  color_dapi=False, color_marker=False):
-
-    tiles = list(generate_tiles(img, tile_size, overlap_size))
-
-    run_fn = run_torchserve if use_torchserve else run_dask
-    # res = [Tile(t.i, t.j, run_fn(t.img, model_path)) for t in tiles]
-    res = [Tile(t.i, t.j, run_wrapper(t.img, run_fn, model_path, eager_mode)) for t in tiles]
-
-    def get_net_tiles(n):
-        return [Tile(t.i, t.j, t.img[n]) for t in res]
-
-    images = {}
-
-    images['Hema'] = stitch(get_net_tiles('G1'), tile_size, overlap_size).resize(img.size)
-
-    # images['DAPI'] = stitch(
-    #     [Tile(t.i, t.j, adjust_background_tile(dt.img))
-    #      for t, dt in zip(tiles, get_net_tiles('G2'))],
-    #     tile_size, overlap_size).resize(img.size)
-    # dapi_pix = np.array(images['DAPI'])
-    # dapi_pix[:, :, 0] = 0
-    # images['DAPI'] = Image.fromarray(dapi_pix)
-
-    images['DAPI'] = stitch(get_net_tiles('G2'), tile_size, overlap_size).resize(img.size)
-    dapi_pix = np.array(images['DAPI'].convert('L').convert('RGB'))
-    if color_dapi:
-        dapi_pix[:, :, 0] = 0
-        images['DAPI'] = Image.fromarray(dapi_pix)
-    images['Lap2'] = stitch(get_net_tiles('G3'), tile_size, overlap_size).resize(img.size)
-    images['Marker'] = stitch(get_net_tiles('G4'), tile_size, overlap_size).resize(img.size)
-    marker_pix = np.array(images['Marker'].convert('L').convert('RGB'))
-    if color_marker:
-        marker_pix[:, :, 2] = 0
-        images['Marker'] = Image.fromarray(marker_pix)
-
-    # images['Marker'] = stitch(
-    #     [Tile(t.i, t.j, kt.img)
-    #      for t, kt in zip(tiles, get_net_tiles('G4'))],
-    #     tile_size, overlap_size).resize(img.size)
-
-    images['Seg'] = stitch(get_net_tiles('G5'), tile_size, overlap_size).resize(img.size)
-
-    return images
-
-
-def inference_old2(img, tile_size, overlap_size, model_path, use_torchserve=False, eager_mode=False,
-                   color_dapi=False, color_marker=False, opt=None):
-    if not opt:
-        opt = get_opt(model_path)
-        #print_options(opt)
-
-    if opt.model == 'DeepLIIF':
-        rescaled, rows, cols = format_image_for_tiling(img, tile_size, overlap_size)
-
-        run_fn = run_torchserve if use_torchserve else run_dask
-
-        images = {}
-        d_modality2net = {'Hema':'G1',
-                          'DAPI':'G2',
-                          'Lap2':'G3',
-                          'Marker':'G4',
-                          'Seg':'G5'}
-
-        for k in d_modality2net.keys():
-            images[k] = create_image_for_stitching(tile_size, rows, cols)
-
-        for i in range(cols):
-            for j in range(rows):
-                tile = extract_tile(rescaled, tile_size, overlap_size, i, j)
-                res = run_wrapper(tile, run_fn, model_path, eager_mode, opt)
-
-                for modality_name, net_name in d_modality2net.items():
-                    stitch_tile(images[modality_name], res[net_name], tile_size, overlap_size, i, j)
-
-        for modality_name, output_img in images.items():
-            images[modality_name] = output_img.resize(img.size)
-
-        if color_dapi:
-            matrix = ( 0, 0, 0, 0,
-                      299/1000, 587/1000, 114/1000, 0,
-                      299/1000, 587/1000, 114/1000, 0)
-            images['DAPI'] = images['DAPI'].convert('RGB', matrix)
-
-        if color_marker:
-            matrix = (299/1000, 587/1000, 114/1000, 0,
-                      299/1000, 587/1000, 114/1000, 0,
-                      0, 0, 0, 0)
-            images['Marker'] = images['Marker'].convert('RGB', matrix)
-
-        return images
-
-    elif opt.model == 'DeepLIIFExt':
-        #param_dict = read_train_options(model_path)
-        #modalities_no = int(param_dict['modalities_no']) if param_dict else 4
-        #seg_gen = (param_dict['seg_gen'] == 'True') if param_dict else True
-
-
-        rescaled, rows, cols = format_image_for_tiling(img, tile_size, overlap_size)
-        run_fn = run_torchserve if use_torchserve else run_dask
-
-        def get_net_tiles(n):
-            return [Tile(t.i, t.j, t.img[n]) for t in res]
-
-        images = {}
-        d_modality2net = {f'mod{i}':f'G_{i}' for i in range(1, opt.modalities_no + 1)}
-        if opt.seg_gen:
-            d_modality2net.update({f'Seg{i}':f'GS_{i}' for i in range(1, opt.modalities_no + 1)})
-
-        for k in d_modality2net.keys():
-            images[k] = create_image_for_stitching(tile_size, rows, cols)
-
-        for i in range(cols):
-            for j in range(rows):
-                tile = extract_tile(rescaled, tile_size, overlap_size, i, j)
-                res = run_wrapper(tile, run_fn, model_path, eager_mode, opt)
-
-                for modality_name, net_name in d_modality2net.items():
-                    stitch_tile(images[modality_name], res[net_name], tile_size, overlap_size, i, j)
-
-        for modality_name, output_img in images.items():
-            images[modality_name] = output_img.resize(img.size)
-
-        return images
-
-    elif opt.model == 'SDG':
-        # SDG could have multiple input images / modalities
-        # the input hence could be a rectangle
-        # we split the input to get each modality image one by one
-        # then create tiles for each of the modality images
-        # tile_pair is a list that contains the tiles at the given location for each modality image
-        # l_tile_pair is a list of tile_pair that covers all locations
-        # for inference, each tile_pair is used to get the output at the given location
-        w, h = img.size
-        w2 = int(w / opt.input_no)
-
-        l_img = []
-        for i in range(opt.input_no):
-            img_i = img.crop((w2 * i, 0, w2 * (i+1), h))
-            rescaled_img_i, rows, cols = format_image_for_tiling(img_i, tile_size, overlap_size)
-            l_img.append(rescaled_img_i)
-
-        run_fn = run_torchserve if use_torchserve else run_dask
-
-        images = {}
-        d_modality2net = {f'mod{i}':f'G_{i}' for i in range(1, opt.modalities_no + 1)}
-        for k in d_modality2net.keys():
-            images[k] = create_image_for_stitching(tile_size, rows, cols)
-
-        for i in range(cols):
-            for j in range(rows):
-                tile_pair = [extract_tile(rescaled, tile_size, overlap_size, i, j) for rescaled in l_img]
-                res = run_wrapper(tile_pair, run_fn, model_path, eager_mode, opt)
-
-                for modality_name, net_name in d_modality2net.items():
-                    stitch_tile(images[modality_name], res[net_name], tile_size, overlap_size, i, j)
-
-        for modality_name, output_img in images.items():
-            images[modality_name] = output_img.resize((w2,w2))
-
-        return images
-
-    else:
-        raise Exception(f'inference() not implemented for model {opt.model}')
-
-
 def inference(img, tile_size, overlap_size, model_path, use_torchserve=False,
-              eager_mode=False, color_dapi=False, color_marker=False, opt=None):
+              eager_mode=False, color_dapi=False, color_marker=False, opt=None,
+              return_seg_intermediate=False, seg_only=False):
     if not opt:
         opt = get_opt(model_path)
         #print_options(opt)
@@ -508,23 +393,36 @@ def inference(img, tile_size, overlap_size, model_path, use_torchserve=False,
 
     tiler = InferenceTiler(orig, tile_size, overlap_size)
    for tile in tiler:
-        tiler.stitch(run_wrapper(tile, run_fn, model_path, eager_mode, opt))
+        tiler.stitch(run_wrapper(tile, run_fn, model_path, eager_mode, opt, seg_only))
     results = tiler.results()
 
     if opt.model == 'DeepLIIF':
-        images = {
-            'Hema': results['G1'],
-            'DAPI': results['G2'],
-            'Lap2': results['G3'],
-            'Marker': results['G4'],
-            'Seg': results['G5'],
-        }
-        if color_dapi:
+        if seg_only:
+            images = {'Seg': results['G5']}
+            if 'G4' in results:
+                images.update({'Marker': results['G4']})
+        else:
+            images = {
+                'Hema': results['G1'],
+                'DAPI': results['G2'],
+                'Lap2': results['G3'],
+                'Marker': results['G4'],
+                'Seg': results['G5'],
+            }
+
+        if return_seg_intermediate and not seg_only:
+            images.update({'IHC_s':results['G51'],
+                           'Hema_s':results['G52'],
+                           'DAPI_s':results['G53'],
+                           'Lap2_s':results['G54'],
+                           'Marker_s':results['G55'],})
+
+        if color_dapi and not seg_only:
             matrix = ( 0, 0, 0, 0,
                       299/1000, 587/1000, 114/1000, 0,
                       299/1000, 587/1000, 114/1000, 0)
             images['DAPI'] = images['DAPI'].convert('RGB', matrix)
-        if color_marker:
+        if color_marker and not seg_only:
             matrix = (299/1000, 587/1000, 114/1000, 0,
                       299/1000, 587/1000, 114/1000, 0,
                       0, 0, 0, 0)
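
The DAPI/marker tinting relies on Pillow's 12-tuple matrix form of `Image.convert`: each output channel is a linear combination of the input channels, so the matrices above copy the ITU-R 601 luminance into two channels and zero the third. Standalone:

    from PIL import Image

    img = Image.new('RGB', (64, 64), color=(180, 120, 60))
    matrix = (0, 0, 0, 0,                        # R := 0 (red zeroed, as for DAPI above)
              299/1000, 587/1000, 114/1000, 0,   # G := luminance
              299/1000, 587/1000, 114/1000, 0)   # B := luminance
    tinted = img.convert('RGB', matrix)
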
@@ -546,27 +444,27 @@ def inference(img, tile_size, overlap_size, model_path, use_torchserve=False,
         return results  # return result images with default key names (i.e., net names)
 
 
-def postprocess(orig, images, tile_size, model, seg_thresh=150, size_thresh='auto', marker_thresh='auto', size_thresh_upper=None):
+def postprocess(orig, images, tile_size, model, seg_thresh=150, size_thresh='default', marker_thresh=None, size_thresh_upper=None):
     if model == 'DeepLIIF':
         resolution = '40x' if tile_size > 384 else ('20x' if tile_size > 192 else '10x')
-        overlay, refined, scoring = compute_results(np.array(orig), np.array(images['Seg']),
-                                                    np.array(images['Marker'].convert('L')) if 'Marker' in images else None,
-                                                    resolution, seg_thresh, size_thresh, marker_thresh, size_thresh_upper)
+        overlay, refined, scoring = compute_final_results(
+            orig, images['Seg'], images.get('Marker'), resolution,
+            size_thresh, marker_thresh, size_thresh_upper, seg_thresh)
         processed_images = {}
         processed_images['SegOverlaid'] = Image.fromarray(overlay)
         processed_images['SegRefined'] = Image.fromarray(refined)
         return processed_images, scoring
 
-    elif model == 'DeepLIIFExt':
+    elif model in ['DeepLIIFExt','SDG']:
         resolution = '40x' if tile_size > 768 else ('20x' if tile_size > 384 else '10x')
         processed_images = {}
         scoring = {}
         for img_name in list(images.keys()):
             if 'Seg' in img_name:
                 seg_img = images[img_name]
-                overlay, refined, score = compute_results(np.array(orig), np.array(images[img_name]),
-                                                          None, resolution,
-                                                          seg_thresh, size_thresh, marker_thresh, size_thresh_upper)
+                overlay, refined, score = compute_final_results(
+                    orig, images[img_name], None, resolution,
+                    size_thresh, marker_thresh, size_thresh_upper, seg_thresh)
 
                 processed_images[img_name + '_Overlaid'] = Image.fromarray(overlay)
                 processed_images[img_name + '_Refined'] = Image.fromarray(refined)
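
Both branches derive the physical resolution from the tile size alone; only the breakpoints differ between DeepLIIF (384/192) and the Ext/SDG models (768/384). The DeepLIIF ladder as a tiny helper (the helper name is hypothetical):

    def tile_size_to_resolution(tile_size):
        return '40x' if tile_size > 384 else ('20x' if tile_size > 192 else '10x')

    assert tile_size_to_resolution(512) == '40x'
    assert tile_size_to_resolution(256) == '20x'
    assert tile_size_to_resolution(128) == '10x'
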
@@ -578,7 +476,8 @@ def postprocess(orig, images, tile_size, model, seg_thresh=150, size_thresh='aut
 
 
 def infer_modalities(img, tile_size, model_dir, eager_mode=False,
-                     color_dapi=False, color_marker=False, opt=None):
+                     color_dapi=False, color_marker=False, opt=None,
+                     return_seg_intermediate=False, seg_only=False):
     """
     This function is used to infer modalities for the given image using a trained model.
     :param img: The input image.
@@ -591,11 +490,6 @@ def infer_modalities(img, tile_size, model_dir, eager_mode=False,
     opt.use_dp = False
     #print_options(opt)
 
-    if not tile_size:
-        tile_size = check_multi_scale(Image.open('./images/target.png').convert('L'),
-                                      img.convert('L'))
-        tile_size = int(tile_size)
-
     # for those with multiple input modalities, find the correct size to calculate overlap_size
     input_no = opt.input_no if hasattr(opt, 'input_no') else 1
     img_size = (img.size[0] / input_no, img.size[1])  # (width, height)
@@ -609,18 +503,24 @@ def infer_modalities(img, tile_size, model_dir, eager_mode=False,
         eager_mode=eager_mode,
         color_dapi=color_dapi,
         color_marker=color_marker,
-        opt=opt
+        opt=opt,
+        return_seg_intermediate=return_seg_intermediate,
+        seg_only=seg_only
     )
-
+
     if not hasattr(opt,'seg_gen') or (hasattr(opt,'seg_gen') and opt.seg_gen):  # the first condition accounts for old settings of deepliif; the second refers to deepliifext models
         post_images, scoring = postprocess(img, images, tile_size, opt.model)
         images = {**images, **post_images}
+        if seg_only:
+            delete_keys = [k for k in images.keys() if 'Seg' not in k]
+            for name in delete_keys:
+                del images[name]
         return images, scoring
     else:
         return images, None
 
 
-def infer_results_for_wsi(input_dir, filename, output_dir, model_dir, tile_size, region_size=20000):
+def infer_results_for_wsi(input_dir, filename, output_dir, model_dir, tile_size, region_size=20000, color_dapi=False, color_marker=False, seg_intermediate=False, seg_only=False):
     """
     This function infers modalities and segmentation mask for the given WSI image. It
 
@@ -632,35 +532,229 @@ def infer_results_for_wsi(input_dir, filename, output_dir, model_dir, tile_size,
     :param region_size: The size of each individual region to be processed at once.
     :return:
     """
-    results_dir = os.path.join(output_dir, filename)
+    basename, _ = os.path.splitext(filename)
+    results_dir = os.path.join(output_dir, basename)
     if not os.path.exists(results_dir):
         os.makedirs(results_dir)
     size_x, size_y, size_z, size_c, size_t, pixel_type = get_information(os.path.join(input_dir, filename))
-    print(filename, size_x, size_y, size_z, size_c, size_t, pixel_type)
+    rescale = (pixel_type != 'uint8')
+    print(filename, size_x, size_y, size_z, size_c, size_t, pixel_type, flush=True)
+
     results = {}
-    start_x, start_y = 0, 0
-    while start_x < size_x:
-        while start_y < size_y:
-            print(start_x, start_y)
-            region_XYWH = (start_x, start_y, min(region_size, size_x - start_x), min(region_size, size_y - start_y))
-            region = read_bioformats_image_with_reader(os.path.join(input_dir, filename), region=region_XYWH)
-
-            region_modalities, region_scoring = infer_modalities(Image.fromarray((region * 255).astype(np.uint8)), tile_size, model_dir)
-
-            for name, img in region_modalities.items():
-                if name not in results:
-                    results[name] = np.zeros((size_y, size_x, 3), dtype=np.uint8)
-                results[name][region_XYWH[1]: region_XYWH[1] + region_XYWH[3],
-                              region_XYWH[0]: region_XYWH[0] + region_XYWH[2]] = np.array(img)
-            start_y += region_size
-        start_y = 0
-        start_x += region_size
-
-    write_results_to_pickle_file(os.path.join(results_dir, "results.pickle"), results)
+    scoring = None
+
+    # javabridge already set up from previous call to get_information()
+    with bioformats.ImageReader(os.path.join(input_dir, filename)) as reader:
+        start_x, start_y = 0, 0
+
+        while start_x < size_x:
+            while start_y < size_y:
+                print(start_x, start_y, flush=True)
+                region_XYWH = (start_x, start_y, min(region_size, size_x - start_x), min(region_size, size_y - start_y))
+                region = reader.read(XYWH=region_XYWH, rescale=rescale)
+                img = Image.fromarray((region * 255).astype(np.uint8)) if rescale else Image.fromarray(region)
+
+                region_modalities, region_scoring = infer_modalities(img, tile_size, model_dir, color_dapi=color_dapi, color_marker=color_marker, return_seg_intermediate=seg_intermediate, seg_only=seg_only)
+                if region_scoring is not None:
+                    if scoring is None:
+                        scoring = {
+                            'num_pos': region_scoring['num_pos'],
+                            'num_neg': region_scoring['num_neg'],
+                        }
+                    else:
+                        scoring['num_pos'] += region_scoring['num_pos']
+                        scoring['num_neg'] += region_scoring['num_neg']
+
+                for name, img in region_modalities.items():
+                    if name not in results:
+                        results[name] = np.zeros((size_y, size_x, 3), dtype=np.uint8)
+                    results[name][region_XYWH[1]: region_XYWH[1] + region_XYWH[3],
+                                  region_XYWH[0]: region_XYWH[0] + region_XYWH[2]] = np.array(img)
+                start_y += region_size
+            start_y = 0
+            start_x += region_size
+
+    # write_results_to_pickle_file(os.path.join(results_dir, "results.pickle"), results)
     # read_results_from_pickle_file(os.path.join(results_dir, "results.pickle"))
 
     for name, img in results.items():
-        write_big_tiff_file(os.path.join(results_dir, filename.replace('.svs', '_' + name + '.ome.tiff')), img,
-                            tile_size)
+        write_big_tiff_file(os.path.join(results_dir, f'{basename}_{name}.ome.tiff'), img, tile_size)
+
+    if scoring is not None:
+        scoring['num_total'] = scoring['num_pos'] + scoring['num_neg']
+        scoring['percent_pos'] = round(scoring['num_pos'] / scoring['num_total'] * 100, 1) if scoring['num_pos'] > 0 else 0
+        with open(os.path.join(results_dir, f'{basename}.json'), 'w') as f:
+            json.dump(scoring, f, indent=2)
 
     javabridge.kill_vm()
+
+
+def get_wsi_resolution(filename):
+    """
+    Try to get the resolution (magnification) of the slide and
+    the corresponding tile size to use by default for DeepLIIF.
+    If it cannot be found, return (None, None) instead.
+
+    Note: This will start the javabridge VM, but not kill it.
+    It must be killed elsewhere.
+
+    Parameters
+    ----------
+    filename : str
+        Full path to the file.
+
+    Returns
+    -------
+    str :
+        Magnification (objective power) from image metadata.
+    int :
+        Corresponding tile size for DeepLIIF.
+    """
+
+    # make sure javabridge is already set up with the call to get_information()
+    size_x, size_y, size_z, size_c, size_t, pixel_type = get_information(filename)
+
+    mag = None
+    metadata = bioformats.get_omexml_metadata(filename)
+    try:
+        omexml = bioformats.OMEXML(metadata)
+        mag = omexml.instrument().Objective.NominalMagnification
+    except Exception as e:
+        fields = ['AppMag', 'NominalMagnification']
+        try:
+            for field in fields:
+                idx = metadata.find(field)
+                if idx >= 0:
+                    for i in range(idx, len(metadata)):
+                        if metadata[i].isdigit() or metadata[i] == '.':
+                            break
+                    for j in range(i, len(metadata)):
+                        if not metadata[j].isdigit() and metadata[j] != '.':
+                            break
+                    if i == j:
+                        continue
+                    mag = metadata[i:j]
+                    break
+        except Exception as e:
+            pass
+
+    if mag is None:
+        return None, None
+
+    try:
+        tile_size = round((float(mag) / 40) * 512)
+        return mag, tile_size
+    except Exception as e:
+        return None, None
+
+
+def infer_cells_for_wsi(filename, model_dir, tile_size, region_size=20000, version=3, print_log=False):
+    """
+    Perform inference on a slide and get the resulting individual cell data.
+
+    Parameters
+    ----------
+    filename : str
+        Full path to the file.
+    model_dir : str
+        Full path to the directory with the DeepLIIF model files.
+    tile_size : int
+        Size of tiles to extract and perform inference on.
+    region_size : int
+        Maximum size to split the slide for processing.
+    version : int
+        Version of cell data to return (3 or 4).
+    print_log : bool
+        Whether or not to print updates while processing.
+
+    Returns
+    -------
+    dict :
+        Individual cell data and associated values.
+    """
+
+    def print_info(*args):
+        if print_log:
+            print(*args, flush=True)
+
+    resolution = '40x' if tile_size > 384 else ('20x' if tile_size > 192 else '10x')
+
+    size_x, size_y, size_z, size_c, size_t, pixel_type = get_information(filename)
+    rescale = (pixel_type != 'uint8')
+    print_info('Info:', size_x, size_y, size_z, size_c, size_t, pixel_type)
+
+    num_regions_x = math.ceil(size_x / region_size)
+    num_regions_y = math.ceil(size_y / region_size)
+    stride_x = math.ceil(size_x / num_regions_x)
+    stride_y = math.ceil(size_y / num_regions_y)
+    print_info('Strides:', stride_x, stride_y)
+
+    data = None
+    default_marker_thresh, count_marker_thresh = 0, 0
+    default_size_thresh, count_size_thresh = 0, 0
+
+    # javabridge already set up from previous call to get_information()
+    with bioformats.ImageReader(filename) as reader:
+        start_x, start_y = 0, 0
+
+        while start_y < size_y:
+            while start_x < size_x:
+                region_XYWH = (start_x, start_y, min(stride_x, size_x-start_x), min(stride_y, size_y-start_y))
+                print_info('Region:', region_XYWH)
+
+                region = reader.read(XYWH=region_XYWH, rescale=rescale)
+                print_info(region.shape, region.dtype)
+                img = Image.fromarray((region * 255).astype(np.uint8)) if rescale else Image.fromarray(region)
+                print_info(img.size, img.mode)
+
+                images = inference(
+                    img,
+                    tile_size=tile_size,
+                    overlap_size=tile_size//16,
+                    model_path=model_dir,
+                    eager_mode=False,
+                    color_dapi=False,
+                    color_marker=False,
+                    opt=None,
+                    return_seg_intermediate=False,
+                    seg_only=True,
+                )
+                region_data = compute_cell_results(images['Seg'], images.get('Marker'), resolution, version=version)
+
+                if start_x != 0 or start_y != 0:
+                    for i in range(len(region_data['cells'])):
+                        cell = decode_cell_data_v4(region_data['cells'][i]) if version == 4 else region_data['cells'][i]
+                        for j in range(2):
+                            cell['bbox'][j] = (cell['bbox'][j][0] + start_x, cell['bbox'][j][1] + start_y)
+                        cell['centroid'] = (cell['centroid'][0] + start_x, cell['centroid'][1] + start_y)
+                        for j in range(len(cell['boundary'])):
+                            cell['boundary'][j] = (cell['boundary'][j][0] + start_x, cell['boundary'][j][1] + start_y)
+                        region_data['cells'][i] = encode_cell_data_v4(cell) if version == 4 else cell
+
+                if data is None:
+                    data = region_data
+                else:
+                    data['cells'] += region_data['cells']
+
+                if region_data['settings']['default_marker_thresh'] is not None and region_data['settings']['default_marker_thresh'] != 0:
+                    default_marker_thresh += region_data['settings']['default_marker_thresh']
+                    count_marker_thresh += 1
+                if region_data['settings']['default_size_thresh'] != 0:
+                    default_size_thresh += region_data['settings']['default_size_thresh']
+                    count_size_thresh += 1
+
+                start_x += stride_x
+
+            start_x = 0
+            start_y += stride_y
+
+    javabridge.kill_vm()
+
+    if count_marker_thresh == 0:
+        count_marker_thresh = 1
+    if count_size_thresh == 0:
+        count_size_thresh = 1
+    data['settings']['default_marker_thresh'] = round(default_marker_thresh / count_marker_thresh)
+    data['settings']['default_size_thresh'] = round(default_size_thresh / count_size_thresh)
+
+    return data
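
The region striding in `infer_cells_for_wsi` uses `math.ceil` twice so that regions never exceed `region_size` yet stay near-equal in size, avoiding a thin remainder strip at the slide edge. The arithmetic with hypothetical numbers:

    import math

    size_x, region_size = 45000, 20000
    num_regions_x = math.ceil(size_x / region_size)  # 3 regions
    stride_x = math.ceil(size_x / num_regions_x)     # 15000 px each (not 20000+20000+5000)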