deepliif 1.1.11__py3-none-any.whl → 1.1.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -12,7 +12,7 @@ In the function <__init__>, you need to define four lists:
   -- self.loss_names (str list): specify the training losses that you want to plot and save.
   -- self.model_names (str list): define networks used in our training.
   -- self.visual_names (str list): specify the images that you want to display and save.
 - -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an usage.
 + -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See CycleGAN_model.py for an usage.
 
   Now you can use the model class by specifying flag '--model dummy'.
   See our template model class 'template_model.py' for more details.
@@ -23,17 +23,23 @@ import itertools
   import importlib
   from functools import lru_cache
   from io import BytesIO
 + import json
 + import math
 
   import requests
   import torch
   from PIL import Image
 + Image.MAX_IMAGE_PIXELS = None
 +
   import numpy as np
   from dask import delayed, compute
 + import openslide
 
   from deepliif.util import *
 - from deepliif.util.util import tensor_to_pil, check_multi_scale
 + from deepliif.util.util import tensor_to_pil
   from deepliif.data import transform
 - from deepliif.postprocessing import compute_results
 + from deepliif.postprocessing import compute_final_results, compute_cell_results
 + from deepliif.postprocessing import encode_cell_data_v4, decode_cell_data_v4
   from deepliif.options import Options, print_options
 
   from .base_model import BaseModel
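
Note: the new Image.MAX_IMAGE_PIXELS = None line disables Pillow's decompression-bomb guard, which by default refuses images above roughly 89 million pixels; whole-slide regions (the default region_size later in this diff is 20000 x 20000, i.e. 400 million pixels) would trip it. A minimal illustration of what that one line changes:

    from PIL import Image

    # With the default guard, Image.open() on an image larger than
    # Image.MAX_IMAGE_PIXELS raises DecompressionBombError (and warns
    # at half that size). Disabling it accepts arbitrarily large images:
    Image.MAX_IMAGE_PIXELS = None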
@@ -115,14 +121,19 @@ def load_torchscript_model(model_pt_path, device):
 
 
 
 - def load_eager_models(opt, devices):
 + def load_eager_models(opt, devices=None):
       # create a model given model and other options
       model = create_model(opt)
       # regular setup: load and print networks; create schedulers
       model.setup(opt)
 
       nets = {}
 -     for name in model.model_names:
 +     if devices:
 +         model_names = list(devices.keys())
 +     else:
 +         model_names = model.model_names
 +
 +     for name in model_names:  # model.model_names
           if isinstance(name, str):
               if '_' in name:
                   net = getattr(model, 'net' + name.split('_')[0])[int(name.split('_')[-1]) - 1]
@@ -138,7 +149,8 @@ def load_eager_models(opt, devices):
                   net = net.module
 
               nets[name] = net
 -             nets[name].to(devices[name])
 +             if devices:
 +                 nets[name].to(devices[name])
 
       return nets
 
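
Note: with the optional devices argument, callers can both restrict which networks load_eager_models() instantiates and pin each one to a device; passing nothing preserves the old behavior of loading every net in model.model_names. A minimal sketch (the option object, net names, and devices are hypothetical):

    import torch

    # Load only two nets, each pinned to an explicit device.
    devices = {'G1': torch.device('cuda:0'), 'G51': torch.device('cpu')}
    nets = load_eager_models(opt, devices=devices)  # keys taken from devices

    # devices=None: load all of model.model_names, leave placement to setup().
    nets_all = load_eager_models(opt)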
@@ -154,8 +166,7 @@ def init_nets(model_dir, eager_mode=False, opt=None, phase='test'):
       """
       if opt is None:
           opt = get_opt(model_dir, mode=phase)
 -     opt.use_dp = False
 -     #print_options(opt)
 +         opt.use_dp = False
 
       if opt.model == 'DeepLIIF':
           net_groups = [
@@ -170,12 +181,16 @@ def init_nets(model_dir, eager_mode=False, opt=None, phase='test'):
               net_groups = [(f'G_{i+1}', f'GS_{i+1}') for i in range(opt.modalities_no)]
           else:
               net_groups = [(f'G_{i+1}',) for i in range(opt.modalities_no)]
 +     elif opt.model == 'CycleGAN':
 +         if opt.BtoA:
 +             net_groups = [(f'GB_{i+1}',) for i in range(opt.modalities_no)]
 +         else:
 +             net_groups = [(f'GA_{i+1}',) for i in range(opt.modalities_no)]
       else:
           raise Exception(f'init_nets() not implemented for model {opt.model}')
 
       number_of_gpus_all = torch.cuda.device_count()
 -     number_of_gpus = len(opt.gpu_ids)
 -     #print(number_of_gpus)
 +     number_of_gpus = min(len(opt.gpu_ids), number_of_gpus_all)
 
       if number_of_gpus > 0:
           mapping_gpu_ids = {i: idx for i, idx in enumerate(opt.gpu_ids)}
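
Note: number_of_gpus is now clamped to the GPUs actually present, so a model trained with a gpu_ids list like [0, 1, 2, 3] no longer maps nets onto devices that do not exist on the inference machine. The effect in isolation:

    import torch

    gpu_ids = [0, 1, 2, 3]                         # from the training config
    available = torch.cuda.device_count()          # e.g. 1 at inference time
    number_of_gpus = min(len(gpu_ids), available)  # 1 instead of 4
    mapping_gpu_ids = {i: idx for i, idx in enumerate(gpu_ids)}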
@@ -203,12 +218,13 @@ def compute_overlap(img_size, tile_size):
       return tile_size // 4
 
 
 - def run_torchserve(img, model_path=None, eager_mode=False, opt=None):
 + def run_torchserve(img, model_path=None, eager_mode=False, opt=None, seg_only=False):
       """
       eager_mode: not used in this function; put in place to be consistent with run_dask
                   so that run_wrapper() could call either this function or run_dask with
                   same syntax
       opt: same as eager_mode
 +     seg_only: same as eager_mode
       """
       buffer = BytesIO()
       torch.save(transform(img.resize((opt.scale_size, opt.scale_size))), buffer)
@@ -227,9 +243,10 @@ def run_torchserve(img, model_path=None, eager_mode=False, opt=None):
       return {k: tensor_to_pil(deserialize_tensor(v)) for k, v in res.json().items()}
 
 
 - def run_dask(img, model_path, eager_mode=False, opt=None):
 + def run_dask(img, model_path, eager_mode=False, opt=None, seg_only=False):
       model_dir = os.getenv('DEEPLIIF_MODEL_DIR', model_path)
       nets = init_nets(model_dir, eager_mode, opt)
 +     use_dask = True if opt.norm != 'spectral' else False
 
       if opt.input_no > 1 or opt.model == 'SDG':
           l_ts = [transform(img_i.resize((opt.scale_size, opt.scale_size))) for img_i in img]
@@ -238,45 +255,69 @@ def run_dask(img, model_path, eager_mode=False, opt=None):
           ts = transform(img.resize((opt.scale_size, opt.scale_size)))
 
 
 -     @delayed
 -     def forward(input, model):
 -         with torch.no_grad():
 -             return model(input.to(next(model.parameters()).device))
 +     if use_dask:
 +         @delayed
 +         def forward(input, model):
 +             with torch.no_grad():
 +                 return model(input.to(next(model.parameters()).device))
 +     else:  # some train settings, such as spectral norm, are somehow not compatible with dask in inference mode
 +         def forward(input, model):
 +             with torch.no_grad():
 +                 return model(input.to(next(model.parameters()).device))
 
       if opt.model == 'DeepLIIF':
 +         weights = {
 +             'G51': 0.25,  # IHC
 +             'G52': 0.25,  # Hema
 +             'G53': 0.25,  # DAPI
 +             'G54': 0.00,  # Lap2
 +             'G55': 0.25,  # Marker
 +         }
 +
           seg_map = {'G1': 'G52', 'G2': 'G53', 'G3': 'G54', 'G4': 'G55'}
 +         if seg_only:
 +             seg_map = {k: v for k, v in seg_map.items() if weights[v] != 0}
 
           lazy_gens = {k: forward(ts, nets[k]) for k in seg_map}
 +         if 'G4' not in seg_map:
 +             lazy_gens['G4'] = forward(ts, nets['G4'])
           gens = compute(lazy_gens)[0]
 
           lazy_segs = {v: forward(gens[k], nets[v]).to(torch.device('cpu')) for k, v in seg_map.items()}
 -         lazy_segs['G51'] = forward(ts, nets['G51']).to(torch.device('cpu'))
 +         if not seg_only or weights['G51'] != 0:
 +             lazy_segs['G51'] = forward(ts, nets['G51']).to(torch.device('cpu'))
           segs = compute(lazy_segs)[0]
 
 -         weights = {
 -             'G51': 0.25,  # IHC
 -             'G52': 0.25,  # Hema
 -             'G53': 0.25,  # DAPI
 -             'G54': 0.00,  # Lap2
 -             'G55': 0.25,  # Marker
 -         }
           seg = torch.stack([torch.mul(segs[k], weights[k]) for k in segs.keys()]).sum(dim=0)
 
 -         res = {k: tensor_to_pil(v) for k, v in gens.items()}
 +         if seg_only:
 +             res = {'G4': tensor_to_pil(gens['G4'])} if 'G4' in gens else {}
 +         else:
 +             res = {k: tensor_to_pil(v) for k, v in gens.items()}
 +             res.update({k: tensor_to_pil(v) for k, v in segs.items()})
           res['G5'] = tensor_to_pil(seg)
 
           return res
 -     elif opt.model in ['DeepLIIFExt', 'SDG']:
 -         seg_map = {'G_' + str(i): 'GS_' + str(i) for i in range(1, opt.modalities_no + 1)}
 +     elif opt.model in ['DeepLIIFExt', 'SDG', 'CycleGAN']:
 +         if opt.model == 'CycleGAN':
 +             seg_map = {f'GB_{i+1}': None for i in range(opt.modalities_no)} if opt.BtoA else {f'GA_{i+1}': None for i in range(opt.modalities_no)}
 +         else:
 +             seg_map = {'G_' + str(i): 'GS_' + str(i) for i in range(1, opt.modalities_no + 1)}
 
 -         lazy_gens = {k: forward(ts, nets[k]) for k in seg_map}
 -         gens = compute(lazy_gens)[0]
 +         if use_dask:
 +             lazy_gens = {k: forward(ts, nets[k]) for k in seg_map}
 +             gens = compute(lazy_gens)[0]
 +         else:
 +             gens = {k: forward(ts, nets[k]) for k in seg_map}
 
           res = {k: tensor_to_pil(v) for k, v in gens.items()}
 
           if opt.seg_gen:
 -             lazy_segs = {v: forward(torch.cat([ts.to(torch.device('cpu')), gens[next(iter(seg_map))].to(torch.device('cpu')), gens[k].to(torch.device('cpu'))], 1), nets[v]).to(torch.device('cpu')) for k, v in seg_map.items()}
 -             segs = compute(lazy_segs)[0]
 +             if use_dask:
 +                 lazy_segs = {v: forward(torch.cat([ts.to(torch.device('cpu')), gens[next(iter(seg_map))].to(torch.device('cpu')), gens[k].to(torch.device('cpu'))], 1), nets[v]).to(torch.device('cpu')) for k, v in seg_map.items()}
 +                 segs = compute(lazy_segs)[0]
 +             else:
 +                 segs = {v: forward(torch.cat([ts.to(torch.device('cpu')), gens[next(iter(seg_map))].to(torch.device('cpu')), gens[k].to(torch.device('cpu'))], 1), nets[v]).to(torch.device('cpu')) for k, v in seg_map.items()}
               res.update({k: tensor_to_pil(v) for k, v in segs.items()})
 
           return res
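
Note: the use_dask switch keeps forward() identical in both branches; only the wrapping changes. With dask, each call builds a delayed task and nothing executes until compute(), which lets the per-network forwards be scheduled together; the eager branch (used for spectral-norm models) just calls the function. The two shapes, reduced to a toy example:

    from dask import delayed, compute

    def square(x):
        return x * x

    # Lazy: a graph of tasks, evaluated together in one compute() call.
    lazy = {k: delayed(square)(k) for k in (1, 2, 3)}
    results = compute(lazy)[0]   # {1: 1, 2: 4, 3: 9}

    # Eager: same call shape, each value computed immediately.
    results_eager = {k: square(k) for k in (1, 2, 3)}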
@@ -284,14 +325,6 @@ def run_dask(img, model_path, eager_mode=False, opt=None):
           raise Exception(f'run_dask() not fully implemented for {opt.model}')
 
 
 - def is_empty_old(tile):
 -     # return True if np.mean(np.array(tile) - np.array(mean_background_val)) < 40 else False
 -     if isinstance(tile, list):  # for pair of tiles, only mark it as empty / no need for prediction if ALL tiles are empty
 -         return all([True if calculate_background_area(t) > 98 else False for t in tile])
 -     else:
 -         return True if calculate_background_area(tile) > 98 else False
 -
 -
   def is_empty(tile):
       thresh = 15
       if isinstance(tile, list):  # for pair of tiles, only mark it as empty / no need for prediction if ALL tiles are empty
@@ -300,18 +333,29 @@ def is_empty(tile):
           return True if np.max(image_variance_rgb(tile)) < thresh else False
 
 
 - def run_wrapper(tile, run_fn, model_path, eager_mode=False, opt=None):
 + def run_wrapper(tile, run_fn, model_path, eager_mode=False, opt=None, seg_only=False):
       if opt.model == 'DeepLIIF':
           if is_empty(tile):
 -             return {
 -                 'G1': Image.new(mode='RGB', size=(512, 512), color=(201, 211, 208)),
 -                 'G2': Image.new(mode='RGB', size=(512, 512), color=(10, 10, 10)),
 -                 'G3': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)),
 -                 'G4': Image.new(mode='RGB', size=(512, 512), color=(10, 10, 10)),
 -                 'G5': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0))
 -             }
 +             if seg_only:
 +                 return {
 +                     'G4': Image.new(mode='RGB', size=(512, 512), color=(10, 10, 10)),
 +                     'G5': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)),
 +                 }
 +             else:
 +                 return {
 +                     'G1': Image.new(mode='RGB', size=(512, 512), color=(201, 211, 208)),
 +                     'G2': Image.new(mode='RGB', size=(512, 512), color=(10, 10, 10)),
 +                     'G3': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)),
 +                     'G4': Image.new(mode='RGB', size=(512, 512), color=(10, 10, 10)),
 +                     'G5': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)),
 +                     'G51': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)),
 +                     'G52': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)),
 +                     'G53': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)),
 +                     'G54': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)),
 +                     'G55': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)),
 +                 }
           else:
 -             return run_fn(tile, model_path, eager_mode, opt)
 +             return run_fn(tile, model_path, eager_mode, opt, seg_only)
       elif opt.model in ['DeepLIIFExt', 'SDG']:
           if is_empty(tile):
               res = {'G_' + str(i): Image.new(mode='RGB', size=(512, 512)) for i in range(1, opt.modalities_no + 1)}
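
Note: is_empty() treats a tile as blank when the largest per-channel intensity variance stays below 15, replacing the background-area percentage used by the deleted is_empty_old(). A standalone equivalent, assuming image_variance_rgb() returns per-channel variances:

    import numpy as np
    from PIL import Image

    def tile_looks_empty(tile, thresh=15):
        # Near-constant tiles (plain background) have low variance in
        # every channel; real tissue raises at least one channel above it.
        arr = np.asarray(tile.convert('RGB'), dtype=np.float64)
        channel_var = arr.reshape(-1, 3).var(axis=0)
        return bool(np.max(channel_var) < thresh)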
@@ -319,178 +363,20 @@ def run_wrapper(tile, run_fn, model_path, eager_mode=False, opt=None):
               return res
           else:
               return run_fn(tile, model_path, eager_mode, opt)
 +     elif opt.model in ['CycleGAN']:
 +         if is_empty(tile):
 +             net_names = [f'GB_{i+1}' for i in range(opt.modalities_no)] if opt.BtoA else [f'GA_{i+1}' for i in range(opt.modalities_no)]
 +             res = {net_name: Image.new(mode='RGB', size=(512, 512)) for net_name in net_names}
 +             return res
 +         else:
 +             return run_fn(tile, model_path, eager_mode, opt)
       else:
           raise Exception(f'run_wrapper() not implemented for model {opt.model}')
 
 
 - def inference_old(img, tile_size, overlap_size, model_path, use_torchserve=False, eager_mode=False,
 -                   color_dapi=False, color_marker=False):
 -
 -     tiles = list(generate_tiles(img, tile_size, overlap_size))
 -
 -     run_fn = run_torchserve if use_torchserve else run_dask
 -     # res = [Tile(t.i, t.j, run_fn(t.img, model_path)) for t in tiles]
 -     res = [Tile(t.i, t.j, run_wrapper(t.img, run_fn, model_path, eager_mode)) for t in tiles]
 -
 -     def get_net_tiles(n):
 -         return [Tile(t.i, t.j, t.img[n]) for t in res]
 -
 -     images = {}
 -
 -     images['Hema'] = stitch(get_net_tiles('G1'), tile_size, overlap_size).resize(img.size)
 -
 -     # images['DAPI'] = stitch(
 -     #     [Tile(t.i, t.j, adjust_background_tile(dt.img))
 -     #      for t, dt in zip(tiles, get_net_tiles('G2'))],
 -     #     tile_size, overlap_size).resize(img.size)
 -     # dapi_pix = np.array(images['DAPI'])
 -     # dapi_pix[:, :, 0] = 0
 -     # images['DAPI'] = Image.fromarray(dapi_pix)
 -
 -     images['DAPI'] = stitch(get_net_tiles('G2'), tile_size, overlap_size).resize(img.size)
 -     dapi_pix = np.array(images['DAPI'].convert('L').convert('RGB'))
 -     if color_dapi:
 -         dapi_pix[:, :, 0] = 0
 -         images['DAPI'] = Image.fromarray(dapi_pix)
 -     images['Lap2'] = stitch(get_net_tiles('G3'), tile_size, overlap_size).resize(img.size)
 -     images['Marker'] = stitch(get_net_tiles('G4'), tile_size, overlap_size).resize(img.size)
 -     marker_pix = np.array(images['Marker'].convert('L').convert('RGB'))
 -     if color_marker:
 -         marker_pix[:, :, 2] = 0
 -         images['Marker'] = Image.fromarray(marker_pix)
 -
 -     # images['Marker'] = stitch(
 -     #     [Tile(t.i, t.j, kt.img)
 -     #      for t, kt in zip(tiles, get_net_tiles('G4'))],
 -     #     tile_size, overlap_size).resize(img.size)
 -
 -     images['Seg'] = stitch(get_net_tiles('G5'), tile_size, overlap_size).resize(img.size)
 -
 -     return images
 -
 -
 - def inference_old2(img, tile_size, overlap_size, model_path, use_torchserve=False, eager_mode=False,
 -                    color_dapi=False, color_marker=False, opt=None):
 -     if not opt:
 -         opt = get_opt(model_path)
 -         #print_options(opt)
 -
 -     if opt.model == 'DeepLIIF':
 -         rescaled, rows, cols = format_image_for_tiling(img, tile_size, overlap_size)
 -
 -         run_fn = run_torchserve if use_torchserve else run_dask
 -
 -         images = {}
 -         d_modality2net = {'Hema': 'G1',
 -                           'DAPI': 'G2',
 -                           'Lap2': 'G3',
 -                           'Marker': 'G4',
 -                           'Seg': 'G5'}
 -
 -         for k in d_modality2net.keys():
 -             images[k] = create_image_for_stitching(tile_size, rows, cols)
 -
 -         for i in range(cols):
 -             for j in range(rows):
 -                 tile = extract_tile(rescaled, tile_size, overlap_size, i, j)
 -                 res = run_wrapper(tile, run_fn, model_path, eager_mode, opt)
 -
 -                 for modality_name, net_name in d_modality2net.items():
 -                     stitch_tile(images[modality_name], res[net_name], tile_size, overlap_size, i, j)
 -
 -         for modality_name, output_img in images.items():
 -             images[modality_name] = output_img.resize(img.size)
 -
 -         if color_dapi:
 -             matrix = (       0,        0,        0, 0,
 -                       299/1000, 587/1000, 114/1000, 0,
 -                       299/1000, 587/1000, 114/1000, 0)
 -             images['DAPI'] = images['DAPI'].convert('RGB', matrix)
 -
 -         if color_marker:
 -             matrix = (299/1000, 587/1000, 114/1000, 0,
 -                       299/1000, 587/1000, 114/1000, 0,
 -                              0,        0,        0, 0)
 -             images['Marker'] = images['Marker'].convert('RGB', matrix)
 -
 -         return images
 -
 -     elif opt.model == 'DeepLIIFExt':
 -         #param_dict = read_train_options(model_path)
 -         #modalities_no = int(param_dict['modalities_no']) if param_dict else 4
 -         #seg_gen = (param_dict['seg_gen'] == 'True') if param_dict else True
 -
 -
 -         rescaled, rows, cols = format_image_for_tiling(img, tile_size, overlap_size)
 -         run_fn = run_torchserve if use_torchserve else run_dask
 -
 -         def get_net_tiles(n):
 -             return [Tile(t.i, t.j, t.img[n]) for t in res]
 -
 -         images = {}
 -         d_modality2net = {f'mod{i}': f'G_{i}' for i in range(1, opt.modalities_no + 1)}
 -         if opt.seg_gen:
 -             d_modality2net.update({f'Seg{i}': f'GS_{i}' for i in range(1, opt.modalities_no + 1)})
 -
 -         for k in d_modality2net.keys():
 -             images[k] = create_image_for_stitching(tile_size, rows, cols)
 -
 -         for i in range(cols):
 -             for j in range(rows):
 -                 tile = extract_tile(rescaled, tile_size, overlap_size, i, j)
 -                 res = run_wrapper(tile, run_fn, model_path, eager_mode, opt)
 -
 -                 for modality_name, net_name in d_modality2net.items():
 -                     stitch_tile(images[modality_name], res[net_name], tile_size, overlap_size, i, j)
 -
 -         for modality_name, output_img in images.items():
 -             images[modality_name] = output_img.resize(img.size)
 -
 -         return images
 -
 -     elif opt.model == 'SDG':
 -         # SDG could have multiple input images / modalities
 -         # the input hence could be a rectangle
 -         # we split the input to get each modality image one by one
 -         # then create tiles for each of the modality images
 -         # tile_pair is a list that contains the tiles at the given location for each modality image
 -         # l_tile_pair is a list of tile_pair that covers all locations
 -         # for inference, each tile_pair is used to get the output at the given location
 -         w, h = img.size
 -         w2 = int(w / opt.input_no)
 -
 -         l_img = []
 -         for i in range(opt.input_no):
 -             img_i = img.crop((w2 * i, 0, w2 * (i+1), h))
 -             rescaled_img_i, rows, cols = format_image_for_tiling(img_i, tile_size, overlap_size)
 -             l_img.append(rescaled_img_i)
 -
 -         run_fn = run_torchserve if use_torchserve else run_dask
 -
 -         images = {}
 -         d_modality2net = {f'mod{i}': f'G_{i}' for i in range(1, opt.modalities_no + 1)}
 -         for k in d_modality2net.keys():
 -             images[k] = create_image_for_stitching(tile_size, rows, cols)
 -
 -         for i in range(cols):
 -             for j in range(rows):
 -                 tile_pair = [extract_tile(rescaled, tile_size, overlap_size, i, j) for rescaled in l_img]
 -                 res = run_wrapper(tile_pair, run_fn, model_path, eager_mode, opt)
 -
 -                 for modality_name, net_name in d_modality2net.items():
 -                     stitch_tile(images[modality_name], res[net_name], tile_size, overlap_size, i, j)
 -
 -         for modality_name, output_img in images.items():
 -             images[modality_name] = output_img.resize((w2, w2))
 -
 -         return images
 -
 -     else:
 -         raise Exception(f'inference() not implemented for model {opt.model}')
 -
 -
   def inference(img, tile_size, overlap_size, model_path, use_torchserve=False,
 -               eager_mode=False, color_dapi=False, color_marker=False, opt=None):
 +               eager_mode=False, color_dapi=False, color_marker=False, opt=None,
 +               return_seg_intermediate=False, seg_only=False):
       if not opt:
           opt = get_opt(model_path)
           #print_options(opt)
@@ -508,23 +394,36 @@ def inference(img, tile_size, overlap_size, model_path, use_torchserve=False,
 
       tiler = InferenceTiler(orig, tile_size, overlap_size)
       for tile in tiler:
 -         tiler.stitch(run_wrapper(tile, run_fn, model_path, eager_mode, opt))
 +         tiler.stitch(run_wrapper(tile, run_fn, model_path, eager_mode, opt, seg_only))
       results = tiler.results()
 
       if opt.model == 'DeepLIIF':
 -         images = {
 -             'Hema': results['G1'],
 -             'DAPI': results['G2'],
 -             'Lap2': results['G3'],
 -             'Marker': results['G4'],
 -             'Seg': results['G5'],
 -         }
 -         if color_dapi:
 +         if seg_only:
 +             images = {'Seg': results['G5']}
 +             if 'G4' in results:
 +                 images.update({'Marker': results['G4']})
 +         else:
 +             images = {
 +                 'Hema': results['G1'],
 +                 'DAPI': results['G2'],
 +                 'Lap2': results['G3'],
 +                 'Marker': results['G4'],
 +                 'Seg': results['G5'],
 +             }
 +
 +         if return_seg_intermediate and not seg_only:
 +             images.update({'IHC_s': results['G51'],
 +                            'Hema_s': results['G52'],
 +                            'DAPI_s': results['G53'],
 +                            'Lap2_s': results['G54'],
 +                            'Marker_s': results['G55']})
 +
 +         if color_dapi and not seg_only:
               matrix = (       0,        0,        0, 0,
                         299/1000, 587/1000, 114/1000, 0,
                         299/1000, 587/1000, 114/1000, 0)
               images['DAPI'] = images['DAPI'].convert('RGB', matrix)
 -         if color_marker:
 +         if color_marker and not seg_only:
               matrix = (299/1000, 587/1000, 114/1000, 0,
                         299/1000, 587/1000, 114/1000, 0,
                                0,        0,        0, 0)
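
Note: with the two new flags, a caller that only needs segmentation can skip most generators and the post-coloring entirely. A hypothetical call (paths and sizes are illustrative):

    from PIL import Image

    img = Image.open('example_ihc.png').convert('RGB')
    images = inference(img, tile_size=512, overlap_size=512 // 16,
                       model_path='./DeepLIIF_Latest_Model', seg_only=True)
    # -> {'Seg': ..., 'Marker': ...}
    # With return_seg_intermediate=True (and seg_only=False) the dict also
    # carries the per-modality masks IHC_s, Hema_s, DAPI_s, Lap2_s, Marker_s.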
@@ -546,27 +445,27 @@ def inference(img, tile_size, overlap_size, model_path, use_torchserve=False,
       return results  # return result images with default key names (i.e., net names)
 
 
 - def postprocess(orig, images, tile_size, model, seg_thresh=150, size_thresh='auto', marker_thresh='auto', size_thresh_upper=None):
 + def postprocess(orig, images, tile_size, model, seg_thresh=150, size_thresh='default', marker_thresh=None, size_thresh_upper=None):
       if model == 'DeepLIIF':
           resolution = '40x' if tile_size > 384 else ('20x' if tile_size > 192 else '10x')
 -         overlay, refined, scoring = compute_results(np.array(orig), np.array(images['Seg']),
 -                                                     np.array(images['Marker'].convert('L')) if 'Marker' in images else None,
 -                                                     resolution, seg_thresh, size_thresh, marker_thresh, size_thresh_upper)
 +         overlay, refined, scoring = compute_final_results(
 +             orig, images['Seg'], images.get('Marker'), resolution,
 +             size_thresh, marker_thresh, size_thresh_upper, seg_thresh)
           processed_images = {}
           processed_images['SegOverlaid'] = Image.fromarray(overlay)
           processed_images['SegRefined'] = Image.fromarray(refined)
           return processed_images, scoring
 
 -     elif model == 'DeepLIIFExt':
 +     elif model in ['DeepLIIFExt', 'SDG']:
           resolution = '40x' if tile_size > 768 else ('20x' if tile_size > 384 else '10x')
           processed_images = {}
           scoring = {}
           for img_name in list(images.keys()):
               if 'Seg' in img_name:
                   seg_img = images[img_name]
 -                 overlay, refined, score = compute_results(np.array(orig), np.array(images[img_name]),
 -                                                           None, resolution,
 -                                                           seg_thresh, size_thresh, marker_thresh, size_thresh_upper)
 +                 overlay, refined, score = compute_final_results(
 +                     orig, images[img_name], None, resolution,
 +                     size_thresh, marker_thresh, size_thresh_upper, seg_thresh)
 
                   processed_images[img_name + '_Overlaid'] = Image.fromarray(overlay)
                   processed_images[img_name + '_Refined'] = Image.fromarray(refined)
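
Note: besides switching to compute_final_results() (which takes PIL images and a reordered argument list), the defaults changed from size_thresh='auto', marker_thresh='auto' to size_thresh='default', marker_thresh=None. The resolution string is still derived from tile size, with different break points per model family:

    def deepliif_resolution(tile_size):
        # DeepLIIF: 512 -> '40x', 256 -> '20x', 128 -> '10x'
        return '40x' if tile_size > 384 else ('20x' if tile_size > 192 else '10x')

    def deepliifext_resolution(tile_size):
        # DeepLIIFExt / SDG: 1024 -> '40x', 512 -> '20x', 256 -> '10x'
        return '40x' if tile_size > 768 else ('20x' if tile_size > 384 else '10x')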
@@ -578,7 +477,8 @@ def postprocess(orig, images, tile_size, model, seg_thresh=150, size_thresh='aut
 
 
   def infer_modalities(img, tile_size, model_dir, eager_mode=False,
 -                      color_dapi=False, color_marker=False, opt=None):
 +                      color_dapi=False, color_marker=False, opt=None,
 +                      return_seg_intermediate=False, seg_only=False):
       """
       This function is used to infer modalities for the given image using a trained model.
       :param img: The input image.
@@ -591,11 +491,6 @@ def infer_modalities(img, tile_size, model_dir, eager_mode=False,
       opt.use_dp = False
       #print_options(opt)
 
 -     if not tile_size:
 -         tile_size = check_multi_scale(Image.open('./images/target.png').convert('L'),
 -                                       img.convert('L'))
 -         tile_size = int(tile_size)
 -
       # for those with multiple input modalities, find the correct size to calculate overlap_size
       input_no = opt.input_no if hasattr(opt, 'input_no') else 1
       img_size = (img.size[0] / input_no, img.size[1])  # (width, height)
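
Note: the removed block auto-derived tile_size by matching the input against a bundled target image (check_multi_scale); callers now pass tile_size explicitly, and for slides the new get_wsi_resolution() helper added later in this diff derives one from scanner magnification. The rule that helper applies, restated:

    def tile_size_for_magnification(mag):
        # 512 px at 40x, scaled linearly: 40 -> 512, 20 -> 256, 10 -> 128
        return round((float(mag) / 40) * 512)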
@@ -609,18 +504,24 @@ def infer_modalities(img, tile_size, model_dir, eager_mode=False,
           eager_mode=eager_mode,
           color_dapi=color_dapi,
           color_marker=color_marker,
 -         opt=opt
 +         opt=opt,
 +         return_seg_intermediate=return_seg_intermediate,
 +         seg_only=seg_only
       )
 -
 +
       if not hasattr(opt, 'seg_gen') or (hasattr(opt, 'seg_gen') and opt.seg_gen):  # the first condition accounts for old settings of deepliif; the second refers to deepliifext models
           post_images, scoring = postprocess(img, images, tile_size, opt.model)
           images = {**images, **post_images}
 +         if seg_only:
 +             delete_keys = [k for k in images.keys() if 'Seg' not in k]
 +             for name in delete_keys:
 +                 del images[name]
           return images, scoring
       else:
           return images, None
 
 
 - def infer_results_for_wsi(input_dir, filename, output_dir, model_dir, tile_size, region_size=20000):
 + def infer_results_for_wsi(input_dir, filename, output_dir, model_dir, tile_size, region_size=20000, color_dapi=False, color_marker=False, seg_intermediate=False, seg_only=False):
       """
       This function infers modalities and segmentation mask for the given WSI image. It
 
@@ -632,35 +533,197 @@ def infer_results_for_wsi(input_dir, filename, output_dir, model_dir, tile_size,
       :param region_size: The size of each individual region to be processed at once.
       :return:
       """
 -     results_dir = os.path.join(output_dir, filename)
 +     basename, _ = os.path.splitext(filename)
 +     results_dir = os.path.join(output_dir, basename)
       if not os.path.exists(results_dir):
           os.makedirs(results_dir)
       size_x, size_y, size_z, size_c, size_t, pixel_type = get_information(os.path.join(input_dir, filename))
 -     print(filename, size_x, size_y, size_z, size_c, size_t, pixel_type)
 +     rescale = (pixel_type != 'uint8')
 +     print(filename, size_x, size_y, size_z, size_c, size_t, pixel_type, flush=True)
 +
       results = {}
 -     start_x, start_y = 0, 0
 -     while start_x < size_x:
 -         while start_y < size_y:
 -             print(start_x, start_y)
 -             region_XYWH = (start_x, start_y, min(region_size, size_x - start_x), min(region_size, size_y - start_y))
 -             region = read_bioformats_image_with_reader(os.path.join(input_dir, filename), region=region_XYWH)
 -
 -             region_modalities, region_scoring = infer_modalities(Image.fromarray((region * 255).astype(np.uint8)), tile_size, model_dir)
 -
 -             for name, img in region_modalities.items():
 -                 if name not in results:
 -                     results[name] = np.zeros((size_y, size_x, 3), dtype=np.uint8)
 -                 results[name][region_XYWH[1]: region_XYWH[1] + region_XYWH[3],
 -                               region_XYWH[0]: region_XYWH[0] + region_XYWH[2]] = np.array(img)
 -             start_y += region_size
 -         start_y = 0
 -         start_x += region_size
 -
 -     write_results_to_pickle_file(os.path.join(results_dir, "results.pickle"), results)
 +     scoring = None
 +
 +     # javabridge already set up from previous call to get_information()
 +     with bioformats.ImageReader(os.path.join(input_dir, filename)) as reader:
 +         start_x, start_y = 0, 0
 +
 +         while start_x < size_x:
 +             while start_y < size_y:
 +                 print(start_x, start_y, flush=True)
 +                 region_XYWH = (start_x, start_y, min(region_size, size_x - start_x), min(region_size, size_y - start_y))
 +                 region = reader.read(XYWH=region_XYWH, rescale=rescale)
 +                 img = Image.fromarray((region * 255).astype(np.uint8)) if rescale else Image.fromarray(region)
 +
 +                 region_modalities, region_scoring = infer_modalities(img, tile_size, model_dir, color_dapi=color_dapi, color_marker=color_marker, return_seg_intermediate=seg_intermediate, seg_only=seg_only)
 +                 if region_scoring is not None:
 +                     if scoring is None:
 +                         scoring = {
 +                             'num_pos': region_scoring['num_pos'],
 +                             'num_neg': region_scoring['num_neg'],
 +                         }
 +                     else:
 +                         scoring['num_pos'] += region_scoring['num_pos']
 +                         scoring['num_neg'] += region_scoring['num_neg']
 +
 +                 for name, img in region_modalities.items():
 +                     if name not in results:
 +                         results[name] = np.zeros((size_y, size_x, 3), dtype=np.uint8)
 +                     results[name][region_XYWH[1]: region_XYWH[1] + region_XYWH[3],
 +                                   region_XYWH[0]: region_XYWH[0] + region_XYWH[2]] = np.array(img)
 +                 start_y += region_size
 +             start_y = 0
 +             start_x += region_size
 +
 +     # write_results_to_pickle_file(os.path.join(results_dir, "results.pickle"), results)
       # read_results_from_pickle_file(os.path.join(results_dir, "results.pickle"))
 
       for name, img in results.items():
 -         write_big_tiff_file(os.path.join(results_dir, filename.replace('.svs', '_' + name + '.ome.tiff')), img,
 -                             tile_size)
 +         write_big_tiff_file(os.path.join(results_dir, f'{basename}_{name}.ome.tiff'), img, tile_size)
 +
 +     if scoring is not None:
 +         scoring['num_total'] = scoring['num_pos'] + scoring['num_neg']
 +         scoring['percent_pos'] = round(scoring['num_pos'] / scoring['num_total'] * 100, 1) if scoring['num_pos'] > 0 else 0
 +         with open(os.path.join(results_dir, f'{basename}.json'), 'w') as f:
 +             json.dump(scoring, f, indent=2)
 +
 +     javabridge.kill_vm()
 +
 +
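
Note: per-region num_pos/num_neg counts are summed while regions are stitched, and the derived fields are computed once at the end, which avoids averaging percentages across regions of different sizes. The closing arithmetic, on hypothetical counts:

    scoring = {'num_pos': 1200, 'num_neg': 300}    # summed over all regions
    scoring['num_total'] = scoring['num_pos'] + scoring['num_neg']      # 1500
    scoring['percent_pos'] = round(scoring['num_pos'] / scoring['num_total'] * 100, 1)
    # -> 80.0, written to <basename>.json alongside the OME-TIFF outputs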
 + def get_wsi_resolution(filename):
 +     """
 +     Use OpenSlide to get the resolution (magnification) of the slide
 +     and the corresponding tile size to use by default for DeepLIIF.
 +     If it cannot be found, return (None, None) instead.
 +
 +     Parameters
 +     ----------
 +     filename : str
 +         Full path to the file.
 +
 +     Returns
 +     -------
 +     str :
 +         Magnification (objective power) as found by OpenSlide.
 +     int :
 +         Corresponding tile size for DeepLIIF.
 +     """
 +     try:
 +         image = openslide.OpenSlide(filename)
 +         mag = image.properties.get(openslide.PROPERTY_NAME_OBJECTIVE_POWER)
 +         tile_size = round((float(mag) / 40) * 512)
 +         return mag, tile_size
 +     except Exception as e:
 +         return None, None
 +
 +
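
Note: OpenSlide reports objective power as a string property, so mag comes back as e.g. '20' and the tile size is derived from it; any failure (missing property, unreadable file) yields (None, None). Hypothetical usage:

    mag, tile_size = get_wsi_resolution('/slides/example.svs')
    if mag is None:
        tile_size = 512   # caller-chosen fallback, e.g. assume a 40x scan
    # A slide scanned at 20x gives mag='20' and tile_size=256.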
 + def infer_cells_for_wsi(filename, model_dir, tile_size, region_size=20000, version=3, print_log=False):
 +     """
 +     Perform inference on a slide and get the resulting individual cell data.
 +
 +     Parameters
 +     ----------
 +     filename : str
 +         Full path to the file.
 +     model_dir : str
 +         Full path to the directory with the DeepLIIF model files.
 +     tile_size : int
 +         Size of tiles to extract and perform inference on.
 +     region_size : int
 +         Maximum size to split the slide for processing.
 +     version : int
 +         Version of cell data to return (3 or 4).
 +     print_log : bool
 +         Whether or not to print updates while processing.
 +
 +     Returns
 +     -------
 +     dict :
 +         Individual cell data and associated values.
 +     """
 +
 +     def print_info(*args):
 +         if print_log:
 +             print(*args, flush=True)
 +
 +     resolution = '40x' if tile_size > 384 else ('20x' if tile_size > 192 else '10x')
 +
 +     size_x, size_y, size_z, size_c, size_t, pixel_type = get_information(filename)
 +     rescale = (pixel_type != 'uint8')
 +     print_info('Info:', size_x, size_y, size_z, size_c, size_t, pixel_type)
 +
 +     num_regions_x = math.ceil(size_x / region_size)
 +     num_regions_y = math.ceil(size_y / region_size)
 +     stride_x = math.ceil(size_x / num_regions_x)
 +     stride_y = math.ceil(size_y / num_regions_y)
 +     print_info('Strides:', stride_x, stride_y)
 +
 +     data = None
 +     default_marker_thresh, count_marker_thresh = 0, 0
 +     default_size_thresh, count_size_thresh = 0, 0
 +
 +     # javabridge already set up from previous call to get_information()
 +     with bioformats.ImageReader(filename) as reader:
 +         start_x, start_y = 0, 0
 +
 +         while start_y < size_y:
 +             while start_x < size_x:
 +                 region_XYWH = (start_x, start_y, min(stride_x, size_x - start_x), min(stride_y, size_y - start_y))
 +                 print_info('Region:', region_XYWH)
 +
 +                 region = reader.read(XYWH=region_XYWH, rescale=rescale)
 +                 print_info(region.shape, region.dtype)
 +                 img = Image.fromarray((region * 255).astype(np.uint8)) if rescale else Image.fromarray(region)
 +                 print_info(img.size, img.mode)
 +
 +                 images = inference(
 +                     img,
 +                     tile_size=tile_size,
 +                     overlap_size=tile_size//16,
 +                     model_path=model_dir,
 +                     eager_mode=False,
 +                     color_dapi=False,
 +                     color_marker=False,
 +                     opt=None,
 +                     return_seg_intermediate=False,
 +                     seg_only=True,
 +                 )
 +                 region_data = compute_cell_results(images['Seg'], images.get('Marker'), resolution, version=version)
 +
 +                 if start_x != 0 or start_y != 0:
 +                     for i in range(len(region_data['cells'])):
 +                         cell = decode_cell_data_v4(region_data['cells'][i]) if version == 4 else region_data['cells'][i]
 +                         for j in range(2):
 +                             cell['bbox'][j] = (cell['bbox'][j][0] + start_x, cell['bbox'][j][1] + start_y)
 +                         cell['centroid'] = (cell['centroid'][0] + start_x, cell['centroid'][1] + start_y)
 +                         for j in range(len(cell['boundary'])):
 +                             cell['boundary'][j] = (cell['boundary'][j][0] + start_x, cell['boundary'][j][1] + start_y)
 +                         region_data['cells'][i] = encode_cell_data_v4(cell) if version == 4 else cell
 +
 +                 if data is None:
 +                     data = region_data
 +                 else:
 +                     data['cells'] += region_data['cells']
 +
 +                 if region_data['settings']['default_marker_thresh'] is not None and region_data['settings']['default_marker_thresh'] != 0:
 +                     default_marker_thresh += region_data['settings']['default_marker_thresh']
 +                     count_marker_thresh += 1
 +                 if region_data['settings']['default_size_thresh'] != 0:
 +                     default_size_thresh += region_data['settings']['default_size_thresh']
 +                     count_size_thresh += 1
 +
 +                 start_x += stride_x
 +
 +             start_x = 0
 +             start_y += stride_y
 
       javabridge.kill_vm()
 +
 +     if count_marker_thresh == 0:
 +         count_marker_thresh = 1
 +     if count_size_thresh == 0:
 +         count_size_thresh = 1
 +     data['settings']['default_marker_thresh'] = round(default_marker_thresh / count_marker_thresh)
 +     data['settings']['default_size_thresh'] = round(default_size_thresh / count_size_thresh)
 +
 +     return data
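
Note: each region is segmented in its own local pixel coordinates, so every cell from a non-origin region is shifted by the region's (start_x, start_y) before merging; version-4 records are decoded, shifted, and re-encoded around that step. The shift, factored out for clarity:

    def shift_cell(cell, start_x, start_y):
        # Translate one decoded cell record from region-local to
        # slide-global coordinates.
        cell['bbox'] = [(x + start_x, y + start_y) for x, y in cell['bbox']]
        cell['centroid'] = (cell['centroid'][0] + start_x,
                            cell['centroid'][1] + start_y)
        cell['boundary'] = [(x + start_x, y + start_y) for x, y in cell['boundary']]
        return cell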