deepliif 1.1.10__py3-none-any.whl → 1.1.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -12,7 +12,7 @@ In the function <__init__>, you need to define four lists:
12
12
  -- self.loss_names (str list): specify the training losses that you want to plot and save.
13
13
  -- self.model_names (str list): define networks used in our training.
14
14
  -- self.visual_names (str list): specify the images that you want to display and save.
15
- -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an usage.
15
+ -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See CycleGAN_model.py for an usage.
16
16
 
17
17
  Now you can use the model class by specifying flag '--model dummy'.
18
18
  See our template model class 'template_model.py' for more details.
@@ -23,17 +23,23 @@ import itertools
23
23
  import importlib
24
24
  from functools import lru_cache
25
25
  from io import BytesIO
26
+ import json
27
+ import math
26
28
 
27
29
  import requests
28
30
  import torch
29
31
  from PIL import Image
32
+ Image.MAX_IMAGE_PIXELS = None
33
+
30
34
  import numpy as np
31
35
  from dask import delayed, compute
36
+ import openslide
32
37
 
33
38
  from deepliif.util import *
34
- from deepliif.util.util import tensor_to_pil, check_multi_scale
39
+ from deepliif.util.util import tensor_to_pil
35
40
  from deepliif.data import transform
36
- from deepliif.postprocessing import compute_results
41
+ from deepliif.postprocessing import compute_final_results, compute_cell_results
42
+ from deepliif.postprocessing import encode_cell_data_v4, decode_cell_data_v4
37
43
  from deepliif.options import Options, print_options
38
44
 
39
45
  from .base_model import BaseModel
@@ -115,14 +121,19 @@ def load_torchscript_model(model_pt_path, device):
115
121
 
116
122
 
117
123
 
118
- def load_eager_models(opt, devices):
124
+ def load_eager_models(opt, devices=None):
119
125
  # create a model given model and other options
120
126
  model = create_model(opt)
121
127
  # regular setup: load and print networks; create schedulers
122
128
  model.setup(opt)
123
129
 
124
130
  nets = {}
125
- for name in model.model_names:
131
+ if devices:
132
+ model_names = list(devices.keys())
133
+ else:
134
+ model_names = model.model_names
135
+
136
+ for name in model_names:#model.model_names:
126
137
  if isinstance(name, str):
127
138
  if '_' in name:
128
139
  net = getattr(model, 'net' + name.split('_')[0])[int(name.split('_')[-1]) - 1]
@@ -138,7 +149,8 @@ def load_eager_models(opt, devices):
138
149
  net = net.module
139
150
 
140
151
  nets[name] = net
141
- nets[name].to(devices[name])
152
+ if devices:
153
+ nets[name].to(devices[name])
142
154
 
143
155
  return nets
144
156
 
@@ -154,8 +166,7 @@ def init_nets(model_dir, eager_mode=False, opt=None, phase='test'):
154
166
  """
155
167
  if opt is None:
156
168
  opt = get_opt(model_dir, mode=phase)
157
- opt.use_dp = False
158
- #print_options(opt)
169
+ opt.use_dp = False
159
170
 
160
171
  if opt.model == 'DeepLIIF':
161
172
  net_groups = [
@@ -170,12 +181,16 @@ def init_nets(model_dir, eager_mode=False, opt=None, phase='test'):
170
181
  net_groups = [(f'G_{i+1}',f'GS_{i+1}') for i in range(opt.modalities_no)]
171
182
  else:
172
183
  net_groups = [(f'G_{i+1}',) for i in range(opt.modalities_no)]
184
+ elif opt.model == 'CycleGAN':
185
+ if opt.BtoA:
186
+ net_groups = [(f'GB_{i+1}',) for i in range(opt.modalities_no)]
187
+ else:
188
+ net_groups = [(f'GA_{i+1}',) for i in range(opt.modalities_no)]
173
189
  else:
174
190
  raise Exception(f'init_nets() not implemented for model {opt.model}')
175
191
 
176
192
  number_of_gpus_all = torch.cuda.device_count()
177
- number_of_gpus = len(opt.gpu_ids)
178
- #print(number_of_gpus)
193
+ number_of_gpus = min(len(opt.gpu_ids),number_of_gpus_all)
179
194
 
180
195
  if number_of_gpus > 0:
181
196
  mapping_gpu_ids = {i:idx for i,idx in enumerate(opt.gpu_ids)}
@@ -203,12 +218,13 @@ def compute_overlap(img_size, tile_size):
203
218
  return tile_size // 4
204
219
 
205
220
 
206
- def run_torchserve(img, model_path=None, eager_mode=False, opt=None):
221
+ def run_torchserve(img, model_path=None, eager_mode=False, opt=None, seg_only=False):
207
222
  """
208
223
  eager_mode: not used in this function; put in place to be consistent with run_dask
209
224
  so that run_wrapper() could call either this function or run_dask with
210
225
  same syntax
211
226
  opt: same as eager_mode
227
+ seg_only: same as eager_mode
212
228
  """
213
229
  buffer = BytesIO()
214
230
  torch.save(transform(img.resize((opt.scale_size, opt.scale_size))), buffer)
@@ -227,9 +243,10 @@ def run_torchserve(img, model_path=None, eager_mode=False, opt=None):
227
243
  return {k: tensor_to_pil(deserialize_tensor(v)) for k, v in res.json().items()}
228
244
 
229
245
 
230
- def run_dask(img, model_path, eager_mode=False, opt=None):
246
+ def run_dask(img, model_path, eager_mode=False, opt=None, seg_only=False):
231
247
  model_dir = os.getenv('DEEPLIIF_MODEL_DIR', model_path)
232
248
  nets = init_nets(model_dir, eager_mode, opt)
249
+ use_dask = True if opt.norm != 'spectral' else False
233
250
 
234
251
  if opt.input_no > 1 or opt.model == 'SDG':
235
252
  l_ts = [transform(img_i.resize((opt.scale_size,opt.scale_size))) for img_i in img]
@@ -238,45 +255,69 @@ def run_dask(img, model_path, eager_mode=False, opt=None):
238
255
  ts = transform(img.resize((opt.scale_size, opt.scale_size)))
239
256
 
240
257
 
241
- @delayed
242
- def forward(input, model):
243
- with torch.no_grad():
244
- return model(input.to(next(model.parameters()).device))
258
+ if use_dask:
259
+ @delayed
260
+ def forward(input, model):
261
+ with torch.no_grad():
262
+ return model(input.to(next(model.parameters()).device))
263
+ else: # some train settings like spectral norm some how in inference mode is not compatible with dask
264
+ def forward(input, model):
265
+ with torch.no_grad():
266
+ return model(input.to(next(model.parameters()).device))
245
267
 
246
268
  if opt.model == 'DeepLIIF':
269
+ weights = {
270
+ 'G51': 0.25, # IHC
271
+ 'G52': 0.25, # Hema
272
+ 'G53': 0.25, # DAPI
273
+ 'G54': 0.00, # Lap2
274
+ 'G55': 0.25, # Marker
275
+ }
276
+
247
277
  seg_map = {'G1': 'G52', 'G2': 'G53', 'G3': 'G54', 'G4': 'G55'}
278
+ if seg_only:
279
+ seg_map = {k: v for k, v in seg_map.items() if weights[v] != 0}
248
280
 
249
281
  lazy_gens = {k: forward(ts, nets[k]) for k in seg_map}
282
+ if 'G4' not in seg_map:
283
+ lazy_gens['G4'] = forward(ts, nets['G4'])
250
284
  gens = compute(lazy_gens)[0]
251
285
 
252
286
  lazy_segs = {v: forward(gens[k], nets[v]).to(torch.device('cpu')) for k, v in seg_map.items()}
253
- lazy_segs['G51'] = forward(ts, nets['G51']).to(torch.device('cpu'))
287
+ if not seg_only or weights['G51'] != 0:
288
+ lazy_segs['G51'] = forward(ts, nets['G51']).to(torch.device('cpu'))
254
289
  segs = compute(lazy_segs)[0]
255
290
 
256
- weights = {
257
- 'G51': 0.25, # IHC
258
- 'G52': 0.25, # Hema
259
- 'G53': 0.25, # DAPI
260
- 'G54': 0.00, # Lap2
261
- 'G55': 0.25, # Marker
262
- }
263
291
  seg = torch.stack([torch.mul(segs[k], weights[k]) for k in segs.keys()]).sum(dim=0)
264
292
 
265
- res = {k: tensor_to_pil(v) for k, v in gens.items()}
293
+ if seg_only:
294
+ res = {'G4': tensor_to_pil(gens['G4'])} if 'G4' in gens else {}
295
+ else:
296
+ res = {k: tensor_to_pil(v) for k, v in gens.items()}
297
+ res.update({k: tensor_to_pil(v) for k, v in segs.items()})
266
298
  res['G5'] = tensor_to_pil(seg)
267
299
 
268
300
  return res
269
- elif opt.model in ['DeepLIIFExt','SDG']:
270
- seg_map = {'G_' + str(i): 'GS_' + str(i) for i in range(1, opt.modalities_no + 1)}
301
+ elif opt.model in ['DeepLIIFExt','SDG','CycleGAN']:
302
+ if opt.model == 'CycleGAN':
303
+ seg_map = {f'GB_{i+1}':None for i in range(opt.modalities_no)} if opt.BtoA else {f'GA_{i+1}':None for i in range(opt.modalities_no)}
304
+ else:
305
+ seg_map = {'G_' + str(i): 'GS_' + str(i) for i in range(1, opt.modalities_no + 1)}
271
306
 
272
- lazy_gens = {k: forward(ts, nets[k]) for k in seg_map}
273
- gens = compute(lazy_gens)[0]
307
+ if use_dask:
308
+ lazy_gens = {k: forward(ts, nets[k]) for k in seg_map}
309
+ gens = compute(lazy_gens)[0]
310
+ else:
311
+ gens = {k: forward(ts, nets[k]) for k in seg_map}
274
312
 
275
313
  res = {k: tensor_to_pil(v) for k, v in gens.items()}
276
314
 
277
315
  if opt.seg_gen:
278
- lazy_segs = {v: forward(torch.cat([ts.to(torch.device('cpu')), gens[next(iter(seg_map))].to(torch.device('cpu')), gens[k].to(torch.device('cpu'))], 1), nets[v]).to(torch.device('cpu')) for k, v in seg_map.items()}
279
- segs = compute(lazy_segs)[0]
316
+ if use_dask:
317
+ lazy_segs = {v: forward(torch.cat([ts.to(torch.device('cpu')), gens[next(iter(seg_map))].to(torch.device('cpu')), gens[k].to(torch.device('cpu'))], 1), nets[v]).to(torch.device('cpu')) for k, v in seg_map.items()}
318
+ segs = compute(lazy_segs)[0]
319
+ else:
320
+ segs = {v: forward(torch.cat([ts.to(torch.device('cpu')), gens[next(iter(seg_map))].to(torch.device('cpu')), gens[k].to(torch.device('cpu'))], 1), nets[v]).to(torch.device('cpu')) for k, v in seg_map.items()}
280
321
  res.update({k: tensor_to_pil(v) for k, v in segs.items()})
281
322
 
282
323
  return res
@@ -284,33 +325,37 @@ def run_dask(img, model_path, eager_mode=False, opt=None):
284
325
  raise Exception(f'run_dask() not fully implemented for {opt.model}')
285
326
 
286
327
 
287
- def is_empty_old(tile):
288
- # return True if np.mean(np.array(tile) - np.array(mean_background_val)) < 40 else False
289
- if isinstance(tile, list): # for pair of tiles, only mark it as empty / no need for prediction if ALL tiles are empty
290
- return all([True if calculate_background_area(t) > 98 else False for t in tile])
291
- else:
292
- return True if calculate_background_area(tile) > 98 else False
293
-
294
-
295
328
  def is_empty(tile):
329
+ thresh = 15
296
330
  if isinstance(tile, list): # for pair of tiles, only mark it as empty / no need for prediction if ALL tiles are empty
297
- return all([True if np.max(image_variance_rgb(tile)) < 15 else False for t in tile])
331
+ return all([True if np.max(image_variance_rgb(t)) < thresh else False for t in tile])
298
332
  else:
299
- return True if np.max(image_variance_rgb(tile)) < 15 else False
333
+ return True if np.max(image_variance_rgb(tile)) < thresh else False
300
334
 
301
335
 
302
- def run_wrapper(tile, run_fn, model_path, eager_mode=False, opt=None):
336
+ def run_wrapper(tile, run_fn, model_path, eager_mode=False, opt=None, seg_only=False):
303
337
  if opt.model == 'DeepLIIF':
304
338
  if is_empty(tile):
305
- return {
306
- 'G1': Image.new(mode='RGB', size=(512, 512), color=(201, 211, 208)),
307
- 'G2': Image.new(mode='RGB', size=(512, 512), color=(10, 10, 10)),
308
- 'G3': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)),
309
- 'G4': Image.new(mode='RGB', size=(512, 512), color=(10, 10, 10)),
310
- 'G5': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0))
311
- }
339
+ if seg_only:
340
+ return {
341
+ 'G4': Image.new(mode='RGB', size=(512, 512), color=(10, 10, 10)),
342
+ 'G5': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)),
343
+ }
344
+ else :
345
+ return {
346
+ 'G1': Image.new(mode='RGB', size=(512, 512), color=(201, 211, 208)),
347
+ 'G2': Image.new(mode='RGB', size=(512, 512), color=(10, 10, 10)),
348
+ 'G3': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)),
349
+ 'G4': Image.new(mode='RGB', size=(512, 512), color=(10, 10, 10)),
350
+ 'G5': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)),
351
+ 'G51': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)),
352
+ 'G52': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)),
353
+ 'G53': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)),
354
+ 'G54': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)),
355
+ 'G55': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)),
356
+ }
312
357
  else:
313
- return run_fn(tile, model_path, eager_mode, opt)
358
+ return run_fn(tile, model_path, eager_mode, opt, seg_only)
314
359
  elif opt.model in ['DeepLIIFExt', 'SDG']:
315
360
  if is_empty(tile):
316
361
  res = {'G_' + str(i): Image.new(mode='RGB', size=(512, 512)) for i in range(1, opt.modalities_no + 1)}
@@ -318,197 +363,109 @@ def run_wrapper(tile, run_fn, model_path, eager_mode=False, opt=None):
318
363
  return res
319
364
  else:
320
365
  return run_fn(tile, model_path, eager_mode, opt)
366
+ elif opt.model in ['CycleGAN']:
367
+ if is_empty(tile):
368
+ net_names = ['GB_{i+1}' for i in range(opt.modalities_no)] if opt.BtoA else [f'GA_{i+1}' for i in range(opt.modalities_no)]
369
+ res = {net_name: Image.new(mode='RGB', size=(512, 512)) for net_name in net_names}
370
+ return res
371
+ else:
372
+ return run_fn(tile, model_path, eager_mode, opt)
321
373
  else:
322
374
  raise Exception(f'run_wrapper() not implemented for model {opt.model}')
323
375
 
324
376
 
325
- def inference_old(img, tile_size, overlap_size, model_path, use_torchserve=False, eager_mode=False,
326
- color_dapi=False, color_marker=False):
327
-
328
- tiles = list(generate_tiles(img, tile_size, overlap_size))
377
+ def inference(img, tile_size, overlap_size, model_path, use_torchserve=False,
378
+ eager_mode=False, color_dapi=False, color_marker=False, opt=None,
379
+ return_seg_intermediate=False, seg_only=False):
380
+ if not opt:
381
+ opt = get_opt(model_path)
382
+ #print_options(opt)
329
383
 
330
384
  run_fn = run_torchserve if use_torchserve else run_dask
331
- # res = [Tile(t.i, t.j, run_fn(t.img, model_path)) for t in tiles]
332
- res = [Tile(t.i, t.j, run_wrapper(t.img, run_fn, model_path, eager_mode)) for t in tiles]
333
-
334
- def get_net_tiles(n):
335
- return [Tile(t.i, t.j, t.img[n]) for t in res]
336
-
337
- images = {}
338
-
339
- images['Hema'] = stitch(get_net_tiles('G1'), tile_size, overlap_size).resize(img.size)
340
-
341
- # images['DAPI'] = stitch(
342
- # [Tile(t.i, t.j, adjust_background_tile(dt.img))
343
- # for t, dt in zip(tiles, get_net_tiles('G2'))],
344
- # tile_size, overlap_size).resize(img.size)
345
- # dapi_pix = np.array(images['DAPI'])
346
- # dapi_pix[:, :, 0] = 0
347
- # images['DAPI'] = Image.fromarray(dapi_pix)
348
-
349
- images['DAPI'] = stitch(get_net_tiles('G2'), tile_size, overlap_size).resize(img.size)
350
- dapi_pix = np.array(images['DAPI'].convert('L').convert('RGB'))
351
- if color_dapi:
352
- dapi_pix[:, :, 0] = 0
353
- images['DAPI'] = Image.fromarray(dapi_pix)
354
- images['Lap2'] = stitch(get_net_tiles('G3'), tile_size, overlap_size).resize(img.size)
355
- images['Marker'] = stitch(get_net_tiles('G4'), tile_size, overlap_size).resize(img.size)
356
- marker_pix = np.array(images['Marker'].convert('L').convert('RGB'))
357
- if color_marker:
358
- marker_pix[:, :, 2] = 0
359
- images['Marker'] = Image.fromarray(marker_pix)
360
385
 
361
- # images['Marker'] = stitch(
362
- # [Tile(t.i, t.j, kt.img)
363
- # for t, kt in zip(tiles, get_net_tiles('G4'))],
364
- # tile_size, overlap_size).resize(img.size)
365
-
366
- images['Seg'] = stitch(get_net_tiles('G5'), tile_size, overlap_size).resize(img.size)
367
-
368
- return images
386
+ if opt.model == 'SDG':
387
+ # SDG could have multiple input images/modalities, hence the input could be a rectangle.
388
+ # We split the input to get each modality image then create tiles for each set of input images.
389
+ w, h = int(img.width / opt.input_no), img.height
390
+ orig = [img.crop((w * i, 0, w * (i+1), h)) for i in range(opt.input_no)]
391
+ else:
392
+ # Otherwise expect a single input image, which is used directly.
393
+ orig = img
369
394
 
395
+ tiler = InferenceTiler(orig, tile_size, overlap_size)
396
+ for tile in tiler:
397
+ tiler.stitch(run_wrapper(tile, run_fn, model_path, eager_mode, opt, seg_only))
398
+ results = tiler.results()
370
399
 
371
- def inference(img, tile_size, overlap_size, model_path, use_torchserve=False, eager_mode=False,
372
- color_dapi=False, color_marker=False, opt=None):
373
- if not opt:
374
- opt = get_opt(model_path)
375
- #print_options(opt)
376
-
377
400
  if opt.model == 'DeepLIIF':
378
- rescaled, rows, cols = format_image_for_tiling(img, tile_size, overlap_size)
379
-
380
- run_fn = run_torchserve if use_torchserve else run_dask
381
-
382
- images = {}
383
- d_modality2net = {'Hema':'G1',
384
- 'DAPI':'G2',
385
- 'Lap2':'G3',
386
- 'Marker':'G4',
387
- 'Seg':'G5'}
401
+ if seg_only:
402
+ images = {'Seg': results['G5']}
403
+ if 'G4' in results:
404
+ images.update({'Marker': results['G4']})
405
+ else:
406
+ images = {
407
+ 'Hema': results['G1'],
408
+ 'DAPI': results['G2'],
409
+ 'Lap2': results['G3'],
410
+ 'Marker': results['G4'],
411
+ 'Seg': results['G5'],
412
+ }
388
413
 
389
- for k in d_modality2net.keys():
390
- images[k] = create_image_for_stitching(tile_size, rows, cols)
391
-
392
- for i in range(cols):
393
- for j in range(rows):
394
- tile = extract_tile(rescaled, tile_size, overlap_size, i, j)
395
- res = run_wrapper(tile, run_fn, model_path, eager_mode, opt)
396
-
397
- for modality_name, net_name in d_modality2net.items():
398
- stitch_tile(images[modality_name], res[net_name], tile_size, overlap_size, i, j)
414
+ if return_seg_intermediate and not seg_only:
415
+ images.update({'IHC_s':results['G51'],
416
+ 'Hema_s':results['G52'],
417
+ 'DAPI_s':results['G53'],
418
+ 'Lap2_s':results['G54'],
419
+ 'Marker_s':results['G55'],})
399
420
 
400
- for modality_name, output_img in images.items():
401
- images[modality_name] = output_img.resize(img.size)
402
-
403
- if color_dapi:
421
+ if color_dapi and not seg_only:
404
422
  matrix = ( 0, 0, 0, 0,
405
423
  299/1000, 587/1000, 114/1000, 0,
406
424
  299/1000, 587/1000, 114/1000, 0)
407
425
  images['DAPI'] = images['DAPI'].convert('RGB', matrix)
408
-
409
- if color_marker:
426
+ if color_marker and not seg_only:
410
427
  matrix = (299/1000, 587/1000, 114/1000, 0,
411
428
  299/1000, 587/1000, 114/1000, 0,
412
429
  0, 0, 0, 0)
413
430
  images['Marker'] = images['Marker'].convert('RGB', matrix)
414
-
415
431
  return images
416
-
432
+
417
433
  elif opt.model == 'DeepLIIFExt':
418
- #param_dict = read_train_options(model_path)
419
- #modalities_no = int(param_dict['modalities_no']) if param_dict else 4
420
- #seg_gen = (param_dict['seg_gen'] == 'True') if param_dict else True
421
-
422
-
423
- rescaled, rows, cols = format_image_for_tiling(img, tile_size, overlap_size)
424
- run_fn = run_torchserve if use_torchserve else run_dask
425
-
426
- def get_net_tiles(n):
427
- return [Tile(t.i, t.j, t.img[n]) for t in res]
428
-
429
- images = {}
430
- d_modality2net = {f'mod{i}':f'G_{i}' for i in range(1, opt.modalities_no + 1)}
434
+ images = {f'mod{i}': results[f'G_{i}'] for i in range(1, opt.modalities_no + 1)}
431
435
  if opt.seg_gen:
432
- d_modality2net.update({f'Seg{i}':f'GS_{i}' for i in range(1, opt.modalities_no + 1)})
433
-
434
- for k in d_modality2net.keys():
435
- images[k] = create_image_for_stitching(tile_size, rows, cols)
436
-
437
- for i in range(cols):
438
- for j in range(rows):
439
- tile = extract_tile(rescaled, tile_size, overlap_size, i, j)
440
- res = run_wrapper(tile, run_fn, model_path, eager_mode, opt)
441
-
442
- for modality_name, net_name in d_modality2net.items():
443
- stitch_tile(images[modality_name], res[net_name], tile_size, overlap_size, i, j)
444
-
445
- for modality_name, output_img in images.items():
446
- images[modality_name] = output_img.resize(img.size)
447
-
436
+ images.update({f'Seg{i}': results[f'GS_{i}'] for i in range(1, opt.modalities_no + 1)})
448
437
  return images
449
-
438
+
450
439
  elif opt.model == 'SDG':
451
- # SDG could have multiple input images / modalities
452
- # the input hence could be a rectangle
453
- # we split the input to get each modality image one by one
454
- # then create tiles for each of the modality images
455
- # tile_pair is a list that contains the tiles at the given location for each modality image
456
- # l_tile_pair is a list of tile_pair that covers all locations
457
- # for inference, each tile_pair is used to get the output at the given location
458
- w, h = img.size
459
- w2 = int(w / opt.input_no)
460
-
461
- l_img = []
462
- for i in range(opt.input_no):
463
- img_i = img.crop((w2 * i, 0, w2 * (i+1), h))
464
- rescaled_img_i, rows, cols = format_image_for_tiling(img_i, tile_size, overlap_size)
465
- l_img.append(rescaled_img_i)
466
-
467
- run_fn = run_torchserve if use_torchserve else run_dask
468
-
469
- images = {}
470
- d_modality2net = {f'mod{i}':f'G_{i}' for i in range(1, opt.modalities_no + 1)}
471
- for k in d_modality2net.keys():
472
- images[k] = create_image_for_stitching(tile_size, rows, cols)
473
-
474
- for i in range(cols):
475
- for j in range(rows):
476
- tile_pair = [extract_tile(rescaled, tile_size, overlap_size, i, j) for rescaled in l_img]
477
- res = run_wrapper(tile_pair, run_fn, model_path, eager_mode, opt)
478
-
479
- for modality_name, net_name in d_modality2net.items():
480
- stitch_tile(images[modality_name], res[net_name], tile_size, overlap_size, i, j)
481
-
482
- for modality_name, output_img in images.items():
483
- images[modality_name] = output_img.resize((w2,w2))
484
-
440
+ images = {f'mod{i}': results[f'G_{i}'] for i in range(1, opt.modalities_no + 1)}
485
441
  return images
486
-
442
+
487
443
  else:
488
- raise Exception(f'inference() not implemented for model {opt.model}')
444
+ #raise Exception(f'inference() not implemented for model {opt.model}')
445
+ return results # return result images with default key names (i.e., net names)
489
446
 
490
447
 
491
- def postprocess(orig, images, tile_size, model, seg_thresh=150, size_thresh='auto', marker_thresh='auto', size_thresh_upper=None):
448
+ def postprocess(orig, images, tile_size, model, seg_thresh=150, size_thresh='default', marker_thresh=None, size_thresh_upper=None):
492
449
  if model == 'DeepLIIF':
493
450
  resolution = '40x' if tile_size > 384 else ('20x' if tile_size > 192 else '10x')
494
- overlay, refined, scoring = compute_results(np.array(orig), np.array(images['Seg']),
495
- np.array(images['Marker'].convert('L')) if 'Marker' in images else None,
496
- resolution, seg_thresh, size_thresh, marker_thresh, size_thresh_upper)
451
+ overlay, refined, scoring = compute_final_results(
452
+ orig, images['Seg'], images.get('Marker'), resolution,
453
+ size_thresh, marker_thresh, size_thresh_upper, seg_thresh)
497
454
  processed_images = {}
498
455
  processed_images['SegOverlaid'] = Image.fromarray(overlay)
499
456
  processed_images['SegRefined'] = Image.fromarray(refined)
500
457
  return processed_images, scoring
501
458
 
502
- elif model == 'DeepLIIFExt':
459
+ elif model in ['DeepLIIFExt','SDG']:
503
460
  resolution = '40x' if tile_size > 768 else ('20x' if tile_size > 384 else '10x')
504
461
  processed_images = {}
505
462
  scoring = {}
506
463
  for img_name in list(images.keys()):
507
464
  if 'Seg' in img_name:
508
465
  seg_img = images[img_name]
509
- overlay, refined, score = compute_results(np.array(orig), np.array(images[img_name]),
510
- None, resolution,
511
- seg_thresh, size_thresh, marker_thresh, size_thresh_upper)
466
+ overlay, refined, score = compute_final_results(
467
+ orig, images[img_name], None, resolution,
468
+ size_thresh, marker_thresh, size_thresh_upper, seg_thresh)
512
469
 
513
470
  processed_images[img_name + '_Overlaid'] = Image.fromarray(overlay)
514
471
  processed_images[img_name + '_Refined'] = Image.fromarray(refined)
@@ -520,7 +477,8 @@ def postprocess(orig, images, tile_size, model, seg_thresh=150, size_thresh='aut
520
477
 
521
478
 
522
479
  def infer_modalities(img, tile_size, model_dir, eager_mode=False,
523
- color_dapi=False, color_marker=False, opt=None):
480
+ color_dapi=False, color_marker=False, opt=None,
481
+ return_seg_intermediate=False, seg_only=False):
524
482
  """
525
483
  This function is used to infer modalities for the given image using a trained model.
526
484
  :param img: The input image.
@@ -533,11 +491,6 @@ def infer_modalities(img, tile_size, model_dir, eager_mode=False,
533
491
  opt.use_dp = False
534
492
  #print_options(opt)
535
493
 
536
- if not tile_size:
537
- tile_size = check_multi_scale(Image.open('./images/target.png').convert('L'),
538
- img.convert('L'))
539
- tile_size = int(tile_size)
540
-
541
494
  # for those with multiple input modalities, find the correct size to calculate overlap_size
542
495
  input_no = opt.input_no if hasattr(opt, 'input_no') else 1
543
496
  img_size = (img.size[0] / input_no, img.size[1]) # (width, height)
@@ -545,23 +498,30 @@ def infer_modalities(img, tile_size, model_dir, eager_mode=False,
545
498
  images = inference(
546
499
  img,
547
500
  tile_size=tile_size,
548
- overlap_size=compute_overlap(img_size, tile_size),
501
+ #overlap_size=compute_overlap(img_size, tile_size),
502
+ overlap_size=tile_size//16,
549
503
  model_path=model_dir,
550
504
  eager_mode=eager_mode,
551
505
  color_dapi=color_dapi,
552
506
  color_marker=color_marker,
553
- opt=opt
507
+ opt=opt,
508
+ return_seg_intermediate=return_seg_intermediate,
509
+ seg_only=seg_only
554
510
  )
555
-
511
+
556
512
  if not hasattr(opt,'seg_gen') or (hasattr(opt,'seg_gen') and opt.seg_gen): # the first condition accounts for old settings of deepliif; the second refers to deepliifext models
557
513
  post_images, scoring = postprocess(img, images, tile_size, opt.model)
558
514
  images = {**images, **post_images}
515
+ if seg_only:
516
+ delete_keys = [k for k in images.keys() if 'Seg' not in k]
517
+ for name in delete_keys:
518
+ del images[name]
559
519
  return images, scoring
560
520
  else:
561
521
  return images, None
562
522
 
563
523
 
564
- def infer_results_for_wsi(input_dir, filename, output_dir, model_dir, tile_size, region_size=20000):
524
+ def infer_results_for_wsi(input_dir, filename, output_dir, model_dir, tile_size, region_size=20000, color_dapi=False, color_marker=False, seg_intermediate=False, seg_only=False):
565
525
  """
566
526
  This function infers modalities and segmentation mask for the given WSI image. It
567
527
 
@@ -573,35 +533,197 @@ def infer_results_for_wsi(input_dir, filename, output_dir, model_dir, tile_size,
573
533
  :param region_size: The size of each individual region to be processed at once.
574
534
  :return:
575
535
  """
576
- results_dir = os.path.join(output_dir, filename)
536
+ basename, _ = os.path.splitext(filename)
537
+ results_dir = os.path.join(output_dir, basename)
577
538
  if not os.path.exists(results_dir):
578
539
  os.makedirs(results_dir)
579
540
  size_x, size_y, size_z, size_c, size_t, pixel_type = get_information(os.path.join(input_dir, filename))
580
- print(filename, size_x, size_y, size_z, size_c, size_t, pixel_type)
541
+ rescale = (pixel_type != 'uint8')
542
+ print(filename, size_x, size_y, size_z, size_c, size_t, pixel_type, flush=True)
543
+
581
544
  results = {}
582
- start_x, start_y = 0, 0
583
- while start_x < size_x:
584
- while start_y < size_y:
585
- print(start_x, start_y)
586
- region_XYWH = (start_x, start_y, min(region_size, size_x - start_x), min(region_size, size_y - start_y))
587
- region = read_bioformats_image_with_reader(os.path.join(input_dir, filename), region=region_XYWH)
588
-
589
- region_modalities, region_scoring = infer_modalities(Image.fromarray((region * 255).astype(np.uint8)), tile_size, model_dir)
590
-
591
- for name, img in region_modalities.items():
592
- if name not in results:
593
- results[name] = np.zeros((size_y, size_x, 3), dtype=np.uint8)
594
- results[name][region_XYWH[1]: region_XYWH[1] + region_XYWH[3],
595
- region_XYWH[0]: region_XYWH[0] + region_XYWH[2]] = np.array(img)
596
- start_y += region_size
597
- start_y = 0
598
- start_x += region_size
599
-
600
- write_results_to_pickle_file(os.path.join(results_dir, "results.pickle"), results)
545
+ scoring = None
546
+
547
+ # javabridge already set up from previous call to get_information()
548
+ with bioformats.ImageReader(os.path.join(input_dir, filename)) as reader:
549
+ start_x, start_y = 0, 0
550
+
551
+ while start_x < size_x:
552
+ while start_y < size_y:
553
+ print(start_x, start_y, flush=True)
554
+ region_XYWH = (start_x, start_y, min(region_size, size_x - start_x), min(region_size, size_y - start_y))
555
+ region = reader.read(XYWH=region_XYWH, rescale=rescale)
556
+ img = Image.fromarray((region * 255).astype(np.uint8)) if rescale else Image.fromarray(region)
557
+
558
+ region_modalities, region_scoring = infer_modalities(img, tile_size, model_dir, color_dapi=color_dapi, color_marker=color_marker, return_seg_intermediate=seg_intermediate, seg_only=seg_only)
559
+ if region_scoring is not None:
560
+ if scoring is None:
561
+ scoring = {
562
+ 'num_pos': region_scoring['num_pos'],
563
+ 'num_neg': region_scoring['num_neg'],
564
+ }
565
+ else:
566
+ scoring['num_pos'] += region_scoring['num_pos']
567
+ scoring['num_neg'] += region_scoring['num_neg']
568
+
569
+ for name, img in region_modalities.items():
570
+ if name not in results:
571
+ results[name] = np.zeros((size_y, size_x, 3), dtype=np.uint8)
572
+ results[name][region_XYWH[1]: region_XYWH[1] + region_XYWH[3],
573
+ region_XYWH[0]: region_XYWH[0] + region_XYWH[2]] = np.array(img)
574
+ start_y += region_size
575
+ start_y = 0
576
+ start_x += region_size
577
+
578
+ # write_results_to_pickle_file(os.path.join(results_dir, "results.pickle"), results)
601
579
  # read_results_from_pickle_file(os.path.join(results_dir, "results.pickle"))
602
580
 
603
581
  for name, img in results.items():
604
- write_big_tiff_file(os.path.join(results_dir, filename.replace('.svs', '_' + name + '.ome.tiff')), img,
605
- tile_size)
582
+ write_big_tiff_file(os.path.join(results_dir, f'{basename}_{name}.ome.tiff'), img, tile_size)
583
+
584
+ if scoring is not None:
585
+ scoring['num_total'] = scoring['num_pos'] + scoring['num_neg']
586
+ scoring['percent_pos'] = round(scoring['num_pos'] / scoring['num_total'] * 100, 1) if scoring['num_pos'] > 0 else 0
587
+ with open(os.path.join(results_dir, f'{basename}.json'), 'w') as f:
588
+ json.dump(scoring, f, indent=2)
606
589
 
607
590
  javabridge.kill_vm()
591
+
592
+
593
+ def get_wsi_resolution(filename):
594
+ """
595
+ Use OpenSlide to get the resolution (magnification) of the slide
596
+ and the corresponding tile size to use by default for DeepLIIF.
597
+ If it cannot be found, return (None, None) instead.
598
+
599
+ Parameters
600
+ ----------
601
+ filename : str
602
+ Full path to the file.
603
+
604
+ Returns
605
+ -------
606
+ str :
607
+ Magnification (objective power) as found by OpenSlide.
608
+ int :
609
+ Corresponding tile size for DeepLIIF.
610
+ """
611
+ try:
612
+ image = openslide.OpenSlide(filename)
613
+ mag = image.properties.get(openslide.PROPERTY_NAME_OBJECTIVE_POWER)
614
+ tile_size = round((float(mag) / 40) * 512)
615
+ return mag, tile_size
616
+ except Exception as e:
617
+ return None, None
618
+
619
+
620
+ def infer_cells_for_wsi(filename, model_dir, tile_size, region_size=20000, version=3, print_log=False):
621
+ """
622
+ Perform inference on a slide and get the results individual cell data.
623
+
624
+ Parameters
625
+ ----------
626
+ filename : str
627
+ Full path to the file.
628
+ model_dir : str
629
+ Full path to the directory with the DeepLIIF model files.
630
+ tile_size : int
631
+ Size of tiles to extract and perform inference on.
632
+ region_size : int
633
+ Maximum size to split the slide for processing.
634
+ version : int
635
+ Version of cell data to return (3 or 4).
636
+ print_log : bool
637
+ Whether or not to print updates while processing.
638
+
639
+ Returns
640
+ -------
641
+ dict :
642
+ Individual cell data and associated values.
643
+ """
644
+
645
+ def print_info(*args):
646
+ if print_log:
647
+ print(*args, flush=True)
648
+
649
+ resolution = '40x' if tile_size > 384 else ('20x' if tile_size > 192 else '10x')
650
+
651
+ size_x, size_y, size_z, size_c, size_t, pixel_type = get_information(filename)
652
+ rescale = (pixel_type != 'uint8')
653
+ print_info('Info:', size_x, size_y, size_z, size_c, size_t, pixel_type)
654
+
655
+ num_regions_x = math.ceil(size_x / region_size)
656
+ num_regions_y = math.ceil(size_y / region_size)
657
+ stride_x = math.ceil(size_x / num_regions_x)
658
+ stride_y = math.ceil(size_y / num_regions_y)
659
+ print_info('Strides:', stride_x, stride_y)
660
+
661
+ data = None
662
+ default_marker_thresh, count_marker_thresh = 0, 0
663
+ default_size_thresh, count_size_thresh = 0, 0
664
+
665
+ # javabridge already set up from previous call to get_information()
666
+ with bioformats.ImageReader(filename) as reader:
667
+ start_x, start_y = 0, 0
668
+
669
+ while start_y < size_y:
670
+ while start_x < size_x:
671
+ region_XYWH = (start_x, start_y, min(stride_x, size_x-start_x), min(stride_y, size_y-start_y))
672
+ print_info('Region:', region_XYWH)
673
+
674
+ region = reader.read(XYWH=region_XYWH, rescale=rescale)
675
+ print_info(region.shape, region.dtype)
676
+ img = Image.fromarray((region * 255).astype(np.uint8)) if rescale else Image.fromarray(region)
677
+ print_info(img.size, img.mode)
678
+
679
+ images = inference(
680
+ img,
681
+ tile_size=tile_size,
682
+ overlap_size=tile_size//16,
683
+ model_path=model_dir,
684
+ eager_mode=False,
685
+ color_dapi=False,
686
+ color_marker=False,
687
+ opt=None,
688
+ return_seg_intermediate=False,
689
+ seg_only=True,
690
+ )
691
+ region_data = compute_cell_results(images['Seg'], images.get('Marker'), resolution, version=version)
692
+
693
+ if start_x != 0 or start_y != 0:
694
+ for i in range(len(region_data['cells'])):
695
+ cell = decode_cell_data_v4(region_data['cells'][i]) if version == 4 else region_data['cells'][i]
696
+ for j in range(2):
697
+ cell['bbox'][j] = (cell['bbox'][j][0] + start_x, cell['bbox'][j][1] + start_y)
698
+ cell['centroid'] = (cell['centroid'][0] + start_x, cell['centroid'][1] + start_y)
699
+ for j in range(len(cell['boundary'])):
700
+ cell['boundary'][j] = (cell['boundary'][j][0] + start_x, cell['boundary'][j][1] + start_y)
701
+ region_data['cells'][i] = encode_cell_data_v4(cell) if version == 4 else cell
702
+
703
+ if data is None:
704
+ data = region_data
705
+ else:
706
+ data['cells'] += region_data['cells']
707
+
708
+ if region_data['settings']['default_marker_thresh'] is not None and region_data['settings']['default_marker_thresh'] != 0:
709
+ default_marker_thresh += region_data['settings']['default_marker_thresh']
710
+ count_marker_thresh += 1
711
+ if region_data['settings']['default_size_thresh'] != 0:
712
+ default_size_thresh += region_data['settings']['default_size_thresh']
713
+ count_size_thresh += 1
714
+
715
+ start_x += stride_x
716
+
717
+ start_x = 0
718
+ start_y += stride_y
719
+
720
+ javabridge.kill_vm()
721
+
722
+ if count_marker_thresh == 0:
723
+ count_marker_thresh = 1
724
+ if count_size_thresh == 0:
725
+ count_size_thresh = 1
726
+ data['settings']['default_marker_thresh'] = round(default_marker_thresh / count_marker_thresh)
727
+ data['settings']['default_size_thresh'] = round(default_size_thresh / count_size_thresh)
728
+
729
+ return data