deepliif 1.1.13__py3-none-any.whl → 1.1.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -25,6 +25,8 @@ from functools import lru_cache
25
25
  from io import BytesIO
26
26
  import json
27
27
  import math
28
+ import importlib.metadata
29
+ import pathlib
28
30
 
29
31
  import requests
30
32
  import torch
@@ -37,7 +39,7 @@ from dask import delayed, compute
37
39
  from deepliif.util import *
38
40
  from deepliif.util.util import tensor_to_pil
39
41
  from deepliif.data import transform
40
- from deepliif.postprocessing import compute_final_results, compute_cell_results
42
+ from deepliif.postprocessing import compute_final_results, compute_cell_results, to_array
41
43
  from deepliif.postprocessing import encode_cell_data_v4, decode_cell_data_v4
42
44
  from deepliif.options import Options, print_options
43
45
 
@@ -167,7 +169,7 @@ def init_nets(model_dir, eager_mode=False, opt=None, phase='test'):
167
169
  opt = get_opt(model_dir, mode=phase)
168
170
  opt.use_dp = False
169
171
 
170
- if opt.model == 'DeepLIIF':
172
+ if opt.model in ['DeepLIIF','DeepLIIFKD']:
171
173
  net_groups = [
172
174
  ('G1', 'G52'),
173
175
  ('G2', 'G53'),
@@ -217,13 +219,15 @@ def compute_overlap(img_size, tile_size):
217
219
  return tile_size // 4
218
220
 
219
221
 
220
- def run_torchserve(img, model_path=None, eager_mode=False, opt=None, seg_only=False):
222
+ def run_torchserve(img, model_path=None, nets=None, eager_mode=False, opt=None, seg_only=False, seg_weights=None, use_dask=True, output_tensor=False):
221
223
  """
222
224
  eager_mode: not used in this function; put in place to be consistent with run_dask
223
225
  so that run_wrapper() could call either this function or run_dask with
224
226
  same syntax
225
227
  opt: same as eager_mode
226
228
  seg_only: same as eager_mode
229
+ seg_weights: same as eager_mode
230
+ nets: same as eager_mode
227
231
  """
228
232
  buffer = BytesIO()
229
233
  torch.save(transform(img.resize((opt.scale_size, opt.scale_size))), buffer)
@@ -242,16 +246,28 @@ def run_torchserve(img, model_path=None, eager_mode=False, opt=None, seg_only=Fa
242
246
  return {k: tensor_to_pil(deserialize_tensor(v)) for k, v in res.json().items()}
243
247
 
244
248
 
245
- def run_dask(img, model_path, eager_mode=False, opt=None, seg_only=False):
246
- model_dir = os.getenv('DEEPLIIF_MODEL_DIR', model_path)
247
- nets = init_nets(model_dir, eager_mode, opt)
248
- use_dask = True if opt.norm != 'spectral' else False
249
+ def run_dask(img, model_path=None, nets=None, eager_mode=False, opt=None, seg_only=False, seg_weights=None, use_dask=True, output_tensor=False):
250
+ """
251
+ Provide either the model path or the networks object.
252
+
253
+ `eager_mode` is only applicable if model_path is provided.
254
+ """
255
+ assert model_path is not None or nets is not None, 'Provide either the model path or the networks object.'
256
+ if nets is None:
257
+ model_dir = os.getenv('DEEPLIIF_MODEL_DIR', model_path)
258
+ nets = init_nets(model_dir, eager_mode, opt)
249
259
 
250
- if opt.input_no > 1 or opt.model == 'SDG':
251
- l_ts = [transform(img_i.resize((opt.scale_size,opt.scale_size))) for img_i in img]
252
- ts = torch.cat(l_ts, dim=1)
260
+ if use_dask: # check if use_dask should be overwritten
261
+ use_dask = True if opt.norm != 'spectral' else False
262
+
263
+ if isinstance(img,torch.Tensor): # if img input is already a tensor, pass
264
+ ts = img
253
265
  else:
254
- ts = transform(img.resize((opt.scale_size, opt.scale_size)))
266
+ if opt.input_no > 1 or opt.model == 'SDG':
267
+ l_ts = [transform(img_i.resize((opt.scale_size,opt.scale_size))) for img_i in img]
268
+ ts = torch.cat(l_ts, dim=1)
269
+ else:
270
+ ts = transform(img.resize((opt.scale_size, opt.scale_size)))
255
271
 
256
272
 
257
273
  if use_dask:
@@ -264,14 +280,23 @@ def run_dask(img, model_path, eager_mode=False, opt=None, seg_only=False):
264
280
  with torch.no_grad():
265
281
  return model(input.to(next(model.parameters()).device))
266
282
 
267
- if opt.model == 'DeepLIIF':
268
- weights = {
269
- 'G51': 0.25, # IHC
270
- 'G52': 0.25, # Hema
271
- 'G53': 0.25, # DAPI
272
- 'G54': 0.00, # Lap2
273
- 'G55': 0.25, # Marker
274
- }
283
+ if opt.model in ['DeepLIIF','DeepLIIFKD']:
284
+ if seg_weights is None:
285
+ weights = {
286
+ 'G51': 0.25, # IHC
287
+ 'G52': 0.25, # Hema
288
+ 'G53': 0.25, # DAPI
289
+ 'G54': 0.00, # Lap2
290
+ 'G55': 0.25, # Marker
291
+ }
292
+ else:
293
+ weights = {
294
+ 'G51': seg_weights[0], # IHC
295
+ 'G52': seg_weights[1], # Hema
296
+ 'G53': seg_weights[2], # DAPI
297
+ 'G54': seg_weights[3], # Lap2
298
+ 'G55': seg_weights[4], # Marker
299
+ }
275
300
 
276
301
  seg_map = {'G1': 'G52', 'G2': 'G53', 'G3': 'G54', 'G4': 'G55'}
277
302
  if seg_only:
@@ -282,19 +307,27 @@ def run_dask(img, model_path, eager_mode=False, opt=None, seg_only=False):
282
307
  lazy_gens['G4'] = forward(ts, nets['G4'])
283
308
  gens = compute(lazy_gens)[0]
284
309
 
285
- lazy_segs = {v: forward(gens[k], nets[v]).to(torch.device('cpu')) for k, v in seg_map.items()}
310
+ lazy_segs = {v: forward(gens[k], nets[v]) for k, v in seg_map.items()}
286
311
  if not seg_only or weights['G51'] != 0:
287
- lazy_segs['G51'] = forward(ts, nets['G51']).to(torch.device('cpu'))
312
+ lazy_segs['G51'] = forward(ts, nets['G51'])
288
313
  segs = compute(lazy_segs)[0]
289
-
290
- seg = torch.stack([torch.mul(segs[k], weights[k]) for k in segs.keys()]).sum(dim=0)
291
-
292
- if seg_only:
293
- res = {'G4': tensor_to_pil(gens['G4'])} if 'G4' in gens else {}
314
+
315
+ device = next(nets['G1'].parameters()).device # take the device of the first net and move all outputs there for seg aggregation
316
+ seg = torch.stack([torch.mul(segs[k].to(device), weights[k]) for k in segs.keys()]).sum(dim=0)
317
+
318
+ if output_tensor:
319
+ if seg_only:
320
+ res = {'G4': gens['G4']} if 'G4' in gens else {}
321
+ else:
322
+ res = {**gens, **segs}
323
+ res['G5'] = seg
294
324
  else:
295
- res = {k: tensor_to_pil(v) for k, v in gens.items()}
296
- res.update({k: tensor_to_pil(v) for k, v in segs.items()})
297
- res['G5'] = tensor_to_pil(seg)
325
+ if seg_only:
326
+ res = {'G4': tensor_to_pil(gens['G4'].to(torch.device('cpu')))} if 'G4' in gens else {}
327
+ else:
328
+ res = {k: tensor_to_pil(v.to(torch.device('cpu'))) for k, v in gens.items()}
329
+ res.update({k: tensor_to_pil(v.to(torch.device('cpu'))) for k, v in segs.items()})
330
+ res['G5'] = tensor_to_pil(seg.to(torch.device('cpu')))
298
331
 
299
332
  return res
300
333
  elif opt.model in ['DeepLIIFExt','SDG','CycleGAN']:
@@ -327,13 +360,13 @@ def run_dask(img, model_path, eager_mode=False, opt=None, seg_only=False):
327
360
def is_empty(tile):
    """Return True if a tile has too little gray-level variance to be worth predicting.

    For a pair/list of tiles, the whole input is considered empty only if ALL
    tiles fall below the variance threshold; a single tile is judged on its own.

    :param tile: a single tile image or a list of tile images
    :return: bool indicating whether prediction can be skipped
    """
    thresh = 15
    if isinstance(tile, list):
        # for a pair of tiles, only mark it as empty / no need for prediction if ALL tiles are empty
        return all(image_variance_gray(t) < thresh for t in tile)
    # comparison already yields a bool; no ternary needed
    return image_variance_gray(tile) < thresh
333
366
 
334
367
 
335
- def run_wrapper(tile, run_fn, model_path, eager_mode=False, opt=None, seg_only=False):
336
- if opt.model == 'DeepLIIF':
368
+ def run_wrapper(tile, run_fn, model_path=None, nets=None, eager_mode=False, opt=None, seg_only=False, seg_weights=None, use_dask=True, output_tensor=False):
369
+ if opt.model in ['DeepLIIF','DeepLIIFKD']:
337
370
  if is_empty(tile):
338
371
  if seg_only:
339
372
  return {
@@ -354,31 +387,38 @@ def run_wrapper(tile, run_fn, model_path, eager_mode=False, opt=None, seg_only=F
354
387
  'G55': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)),
355
388
  }
356
389
  else:
357
- return run_fn(tile, model_path, eager_mode, opt, seg_only)
390
+ return run_fn(tile, model_path, None, eager_mode, opt, seg_only, seg_weights)
358
391
  elif opt.model in ['DeepLIIFExt', 'SDG']:
359
392
  if is_empty(tile):
360
393
  res = {'G_' + str(i): Image.new(mode='RGB', size=(512, 512)) for i in range(1, opt.modalities_no + 1)}
361
394
  res.update({'GS_' + str(i): Image.new(mode='RGB', size=(512, 512)) for i in range(1, opt.modalities_no + 1)})
362
395
  return res
363
396
  else:
364
- return run_fn(tile, model_path, eager_mode, opt)
397
+ return run_fn(tile, model_path, None, eager_mode, opt)
365
398
  elif opt.model in ['CycleGAN']:
366
399
  if is_empty(tile):
367
400
  net_names = ['GB_{i+1}' for i in range(opt.modalities_no)] if opt.BtoA else [f'GA_{i+1}' for i in range(opt.modalities_no)]
368
401
  res = {net_name: Image.new(mode='RGB', size=(512, 512)) for net_name in net_names}
369
402
  return res
370
403
  else:
371
- return run_fn(tile, model_path, eager_mode, opt)
404
+ return run_fn(tile, model_path, None, eager_mode, opt)
372
405
  else:
373
406
  raise Exception(f'run_wrapper() not implemented for model {opt.model}')
374
407
 
375
408
 
376
409
  def inference(img, tile_size, overlap_size, model_path, use_torchserve=False,
377
410
  eager_mode=False, color_dapi=False, color_marker=False, opt=None,
378
- return_seg_intermediate=False, seg_only=False):
411
+ return_seg_intermediate=False, seg_only=False, seg_weights=None, opt_args={}):
412
+ """
413
+ opt_args: a dictionary of key and values to add/overwrite to opt
414
+ """
379
415
  if not opt:
380
416
  opt = get_opt(model_path)
381
417
  #print_options(opt)
418
+
419
+ for k,v in opt_args.items():
420
+ setattr(opt,k,v)
421
+ #print_options(opt)
382
422
 
383
423
  run_fn = run_torchserve if use_torchserve else run_dask
384
424
 
@@ -393,10 +433,11 @@ def inference(img, tile_size, overlap_size, model_path, use_torchserve=False,
393
433
 
394
434
  tiler = InferenceTiler(orig, tile_size, overlap_size)
395
435
  for tile in tiler:
396
- tiler.stitch(run_wrapper(tile, run_fn, model_path, eager_mode, opt, seg_only))
436
+ tiler.stitch(run_wrapper(tile, run_fn, model_path, None, eager_mode, opt, seg_only, seg_weights))
437
+
397
438
  results = tiler.results()
398
439
 
399
- if opt.model == 'DeepLIIF':
440
+ if opt.model in ['DeepLIIF','DeepLIIFKD']:
400
441
  if seg_only:
401
442
  images = {'Seg': results['G5']}
402
443
  if 'G4' in results:
@@ -445,7 +486,7 @@ def inference(img, tile_size, overlap_size, model_path, use_torchserve=False,
445
486
 
446
487
 
447
488
  def postprocess(orig, images, tile_size, model, seg_thresh=150, size_thresh='default', marker_thresh=None, size_thresh_upper=None):
448
- if model == 'DeepLIIF':
489
+ if model in ['DeepLIIF','DeepLIIFKD']:
449
490
  resolution = '40x' if tile_size > 384 else ('20x' if tile_size > 192 else '10x')
450
491
  overlay, refined, scoring = compute_final_results(
451
492
  orig, images['Seg'], images.get('Marker'), resolution,
@@ -477,7 +518,7 @@ def postprocess(orig, images, tile_size, model, seg_thresh=150, size_thresh='def
477
518
 
478
519
  def infer_modalities(img, tile_size, model_dir, eager_mode=False,
479
520
  color_dapi=False, color_marker=False, opt=None,
480
- return_seg_intermediate=False, seg_only=False):
521
+ return_seg_intermediate=False, seg_only=False, seg_weights=None):
481
522
  """
482
523
  This function is used to infer modalities for the given image using a trained model.
483
524
  :param img: The input image.
@@ -505,7 +546,8 @@ def infer_modalities(img, tile_size, model_dir, eager_mode=False,
505
546
  color_marker=color_marker,
506
547
  opt=opt,
507
548
  return_seg_intermediate=return_seg_intermediate,
508
- seg_only=seg_only
549
+ seg_only=seg_only,
550
+ seg_weights=seg_weights,
509
551
  )
510
552
 
511
553
  if not hasattr(opt,'seg_gen') or (hasattr(opt,'seg_gen') and opt.seg_gen): # the first condition accounts for old settings of deepliif; the second refers to deepliifext models
@@ -520,7 +562,7 @@ def infer_modalities(img, tile_size, model_dir, eager_mode=False,
520
562
  return images, None
521
563
 
522
564
 
523
- def infer_results_for_wsi(input_dir, filename, output_dir, model_dir, tile_size, region_size=20000, color_dapi=False, color_marker=False, seg_intermediate=False, seg_only=False):
565
+ def infer_results_for_wsi(input_dir, filename, output_dir, model_dir, tile_size, region_size=20000, color_dapi=False, color_marker=False, seg_intermediate=False, seg_only=False, seg_weights=None):
524
566
  """
525
567
  This function infers modalities and segmentation mask for the given WSI image. It
526
568
 
@@ -554,7 +596,7 @@ def infer_results_for_wsi(input_dir, filename, output_dir, model_dir, tile_size,
554
596
  region = reader.read(XYWH=region_XYWH, rescale=rescale)
555
597
  img = Image.fromarray((region * 255).astype(np.uint8)) if rescale else Image.fromarray(region)
556
598
 
557
- region_modalities, region_scoring = infer_modalities(img, tile_size, model_dir, color_dapi=color_dapi, color_marker=color_marker, return_seg_intermediate=seg_intermediate, seg_only=seg_only)
599
+ region_modalities, region_scoring = infer_modalities(img, tile_size, model_dir, color_dapi=color_dapi, color_marker=color_marker, return_seg_intermediate=seg_intermediate, seg_only=seg_only, seg_weights=seg_weights)
558
600
  if region_scoring is not None:
559
601
  if scoring is None:
560
602
  scoring = {
@@ -586,8 +628,6 @@ def infer_results_for_wsi(input_dir, filename, output_dir, model_dir, tile_size,
586
628
  with open(os.path.join(results_dir, f'{basename}.json'), 'w') as f:
587
629
  json.dump(scoring, f, indent=2)
588
630
 
589
- javabridge.kill_vm()
590
-
591
631
 
592
632
  def get_wsi_resolution(filename):
593
633
  """
@@ -595,9 +635,6 @@ def get_wsi_resolution(filename):
595
635
  the corresponding tile size to use by default for DeepLIIF.
596
636
  If it cannot be found, return (None, None) instead.
597
637
 
598
- Note: This will start the javabridge VM, but not kill it.
599
- It must be killed elsewhere.
600
-
601
638
  Parameters
602
639
  ----------
603
640
  filename : str
@@ -611,11 +648,10 @@ def get_wsi_resolution(filename):
611
648
  Corresponding tile size for DeepLIIF.
612
649
  """
613
650
 
614
- # make sure javabridge is already set up from with call to get_information()
615
- size_x, size_y, size_z, size_c, size_t, pixel_type = get_information(filename)
651
+ init_javabridge_bioformats()
652
+ metadata = bioformats.get_omexml_metadata(filename)
616
653
 
617
654
  mag = None
618
- metadata = bioformats.get_omexml_metadata(filename)
619
655
  try:
620
656
  omexml = bioformats.OMEXML(metadata)
621
657
  mag = omexml.instrument().Objective.NominalMagnification
@@ -648,7 +684,7 @@ def get_wsi_resolution(filename):
648
684
  return None, None
649
685
 
650
686
 
651
- def infer_cells_for_wsi(filename, model_dir, tile_size, region_size=20000, version=3, print_log=False):
687
+ def infer_cells_for_wsi(filename, model_dir, tile_size, region_size=20000, version=3, print_log=False, seg_weights=None):
652
688
  """
653
689
  Perform inference on a slide and get the results individual cell data.
654
690
 
@@ -666,6 +702,9 @@ def infer_cells_for_wsi(filename, model_dir, tile_size, region_size=20000, versi
666
702
  Version of cell data to return (3 or 4).
667
703
  print_log : bool
668
704
  Whether or not to print updates while processing.
705
+ seg_weights : list | tuple
706
+ Optional five weights to use for creating the combined Seg image.
707
+ If None, then the default weights are used.
669
708
 
670
709
  Returns
671
710
  -------
@@ -679,22 +718,21 @@ def infer_cells_for_wsi(filename, model_dir, tile_size, region_size=20000, versi
679
718
 
680
719
  resolution = '40x' if tile_size > 384 else ('20x' if tile_size > 192 else '10x')
681
720
 
682
- size_x, size_y, size_z, size_c, size_t, pixel_type = get_information(filename)
683
- rescale = (pixel_type != 'uint8')
684
- print_info('Info:', size_x, size_y, size_z, size_c, size_t, pixel_type)
685
-
686
- num_regions_x = math.ceil(size_x / region_size)
687
- num_regions_y = math.ceil(size_y / region_size)
688
- stride_x = math.ceil(size_x / num_regions_x)
689
- stride_y = math.ceil(size_y / num_regions_y)
690
- print_info('Strides:', stride_x, stride_y)
691
-
692
721
  data = None
693
722
  default_marker_thresh, count_marker_thresh = 0, 0
694
723
  default_size_thresh, count_size_thresh = 0, 0
695
724
 
696
- # javabridge already set up from previous call to get_information()
697
- with bioformats.ImageReader(filename) as reader:
725
+ with WSIReader(filename) as reader:
726
+ size_x = reader.width
727
+ size_y = reader.height
728
+ print_info('Info:', size_x, size_y)
729
+
730
+ num_regions_x = math.ceil(size_x / region_size)
731
+ num_regions_y = math.ceil(size_y / region_size)
732
+ stride_x = math.ceil(size_x / num_regions_x)
733
+ stride_y = math.ceil(size_y / num_regions_y)
734
+ print_info('Strides:', stride_x, stride_y)
735
+
698
736
  start_x, start_y = 0, 0
699
737
 
700
738
  while start_y < size_y:
@@ -702,10 +740,11 @@ def infer_cells_for_wsi(filename, model_dir, tile_size, region_size=20000, versi
702
740
  region_XYWH = (start_x, start_y, min(stride_x, size_x-start_x), min(stride_y, size_y-start_y))
703
741
  print_info('Region:', region_XYWH)
704
742
 
705
- region = reader.read(XYWH=region_XYWH, rescale=rescale)
743
+ region = reader.read(region_XYWH)
706
744
  print_info(region.shape, region.dtype)
707
- img = Image.fromarray((region * 255).astype(np.uint8)) if rescale else Image.fromarray(region)
745
+ img = Image.fromarray(region)
708
746
  print_info(img.size, img.mode)
747
+ del region
709
748
 
710
749
  images = inference(
711
750
  img,
@@ -718,8 +757,17 @@ def infer_cells_for_wsi(filename, model_dir, tile_size, region_size=20000, versi
718
757
  opt=None,
719
758
  return_seg_intermediate=False,
720
759
  seg_only=True,
760
+ seg_weights=seg_weights,
721
761
  )
722
- region_data = compute_cell_results(images['Seg'], images.get('Marker'), resolution, version=version)
762
+ del img
763
+
764
+ seg = to_array(images['Seg'])
765
+ del images['Seg']
766
+ marker = to_array(images['Marker'], True) if 'Marker' in images else None
767
+ del images
768
+ region_data = compute_cell_results(seg, marker, resolution, version=version)
769
+ del seg
770
+ del marker
723
771
 
724
772
  if start_x != 0 or start_y != 0:
725
773
  for i in range(len(region_data['cells'])):
@@ -748,8 +796,6 @@ def infer_cells_for_wsi(filename, model_dir, tile_size, region_size=20000, versi
748
796
  start_x = 0
749
797
  start_y += stride_y
750
798
 
751
- javabridge.kill_vm()
752
-
753
799
  if count_marker_thresh == 0:
754
800
  count_marker_thresh = 1
755
801
  if count_size_thresh == 0:
@@ -757,4 +803,18 @@ def infer_cells_for_wsi(filename, model_dir, tile_size, region_size=20000, versi
757
803
  data['settings']['default_marker_thresh'] = round(default_marker_thresh / count_marker_thresh)
758
804
  data['settings']['default_size_thresh'] = round(default_size_thresh / count_size_thresh)
759
805
 
806
+ data['settings']['tile_size'] = tile_size
807
+ data['settings']['region_size'] = region_size
808
+ data['settings']['seg_weights'] = seg_weights
809
+
810
+ try:
811
+ data['deepliifVersion'] = importlib.metadata.version('deepliif')
812
+ except Exception as e:
813
+ data['deepliifVersion'] = 'unknown'
814
+
815
+ try:
816
+ data['modelVersion'] = pathlib.PurePath(model_dir).name
817
+ except Exception as e:
818
+ data['modelVersion'] = 'unknown'
819
+
760
820
  return data
@@ -147,7 +147,7 @@ class BaseModel(ABC):
147
147
  if isinstance(name, str):
148
148
  if not hasattr(self, name):
149
149
  if len(name.split('_')) != 2:
150
- if self.opt.model == 'DeepLIIF':
150
+ if self.opt.model in ['DeepLIIF','DeepLIIFKD']:
151
151
  img_name = name[:-1] + '_' + name[-1]
152
152
  visual_ret[name] = getattr(self, img_name)
153
153
  else:
@@ -172,12 +172,14 @@ def define_G(
172
172
  norm_layer = get_norm_layer(norm_type=norm)
173
173
  use_spectral_norm = norm == 'spectral'
174
174
 
175
- if netG == 'resnet_9blocks':
176
- net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9,
177
- padding_type=padding_type, upsample=upsample, use_spectral_norm=use_spectral_norm)
178
- elif netG == 'resnet_6blocks':
179
- net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=6,
175
+ if netG.startswith('resnet_'):
176
+ n_blocks = int(netG.split('_')[1].replace('blocks',''))
177
+ net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=n_blocks,
180
178
  padding_type=padding_type, upsample=upsample, use_spectral_norm=use_spectral_norm)
179
+ elif netG == 'unet_32':
180
+ net = UnetGenerator(input_nc, output_nc, 5, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
181
+ elif netG == 'unet_64':
182
+ net = UnetGenerator(input_nc, output_nc, 6, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
181
183
  elif netG == 'unet_128':
182
184
  net = UnetGenerator(input_nc, output_nc, 7, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
183
185
  elif netG == 'unet_256':
@@ -184,32 +184,38 @@ def mark_background(mask):
184
184
  After the function executes, the pixels will be labeled as background, positive, negative, or unknown.
185
185
  """
186
186
 
187
- seeds = []
188
187
  for i in range(mask.shape[0]):
189
188
  if mask[i, 0] == LABEL_UNKNOWN:
190
- seeds.append((i, 0))
189
+ mask[i, 0] = LABEL_BACKGROUND
191
190
  if mask[i, mask.shape[1]-1] == LABEL_UNKNOWN:
192
- seeds.append((i, mask.shape[1]-1))
191
+ mask[i, mask.shape[1]-1] = LABEL_BACKGROUND
193
192
  for j in range(mask.shape[1]):
194
193
  if mask[0, j] == LABEL_UNKNOWN:
195
- seeds.append((0, j))
194
+ mask[0, j] = LABEL_BACKGROUND
196
195
  if mask[mask.shape[0]-1, j] == LABEL_UNKNOWN:
197
- seeds.append((mask.shape[0]-1, j))
198
-
199
- neighbors = [(-1, 0), (1, 0), (0, -1), (0, 1)]
200
-
201
- while len(seeds) > 0:
202
- seed = seeds.pop()
203
- if mask[seed] == LABEL_UNKNOWN:
204
- mask[seed] = LABEL_BACKGROUND
205
- for n in neighbors:
206
- idx = (seed[0] + n[0], seed[1] + n[1])
207
- if in_bounds(mask, idx) and mask[idx] == LABEL_UNKNOWN:
208
- seeds.append(idx)
196
+ mask[mask.shape[0]-1, j] = LABEL_BACKGROUND
197
+
198
+ count = 1
199
+ while count > 0:
200
+ count = 0
201
+ for i in range(mask.shape[0]):
202
+ for j in range(mask.shape[1]):
203
+ if mask[i, j] == LABEL_UNKNOWN:
204
+ if (mask[i-1, j] == LABEL_BACKGROUND or mask[i+1, j] == LABEL_BACKGROUND or
205
+ mask[i, j-1] == LABEL_BACKGROUND or mask[i, j+1] == LABEL_BACKGROUND):
206
+ mask[i, j] = LABEL_BACKGROUND
207
+ count += 1
208
+ if count > 0:
209
+ for i in range(mask.shape[0]-1, -1, -1):
210
+ for j in range(mask.shape[1]-1, -1, -1):
211
+ if mask[i, j] == LABEL_UNKNOWN:
212
+ if (mask[i-1, j] == LABEL_BACKGROUND or mask[i+1, j] == LABEL_BACKGROUND or
213
+ mask[i, j-1] == LABEL_BACKGROUND or mask[i, j+1] == LABEL_BACKGROUND):
214
+ mask[i, j] = LABEL_BACKGROUND
209
215
 
210
216
 
211
217
  @jit(nopython=True)
212
- def compute_cell_mapping(mask, marker, noise_thresh):
218
+ def compute_cell_mapping(mask, marker, noise_thresh, large_noise_thresh):
213
219
  """
214
220
  Compute the mapping from mask to positive and negative cells.
215
221
 
@@ -264,7 +270,7 @@ def compute_cell_mapping(mask, marker, noise_thresh):
264
270
  center_x += idx[1]
265
271
  count += 1
266
272
 
267
- if count > noise_thresh:
273
+ if count > noise_thresh and (large_noise_thresh is None or count < large_noise_thresh):
268
274
  center_y = int(round(center_y / count))
269
275
  center_x = int(round(center_x / count))
270
276
  positive = True if count_positive >= count_negative else False
@@ -273,7 +279,7 @@ def compute_cell_mapping(mask, marker, noise_thresh):
273
279
  return cells
274
280
 
275
281
 
276
- def get_cells_info(seg, marker, resolution, noise_thresh, seg_thresh):
282
+ def get_cells_info(seg, marker, resolution, noise_thresh, seg_thresh, large_noise_thresh):
277
283
  """
278
284
  Find all cells in the segmentation image that are larger than the noise threshold.
279
285
 
@@ -289,6 +295,8 @@ def get_cells_info(seg, marker, resolution, noise_thresh, seg_thresh):
289
295
  Threshold for tiny noise to ignore (include only cells larger than this value).
290
296
  seg_thresh : int
291
297
  Threshold to use in determining if a pixel should be labeled as positive/negative.
298
+ large_noise_thresh : int | None
299
+ Threshold for large noise to ignore (include only cells smaller than this value).
292
300
 
293
301
  Returns
294
302
  -------
@@ -303,9 +311,10 @@ def get_cells_info(seg, marker, resolution, noise_thresh, seg_thresh):
303
311
  seg = to_array(seg)
304
312
  if marker is not None:
305
313
  marker = to_array(marker, True)
314
+
306
315
  mask = create_posneg_mask(seg, seg_thresh)
307
316
  mark_background(mask)
308
- cellsinfo = compute_cell_mapping(mask, marker, noise_thresh)
317
+ cellsinfo = compute_cell_mapping(mask, marker, noise_thresh, large_noise_thresh)
309
318
 
310
319
  defaults = {}
311
320
  sizes = np.zeros(len(cellsinfo), dtype=np.int64)
@@ -1040,9 +1049,21 @@ def fill_cells(mask):
1040
1049
  mask[y, x] = LABEL_NEGATIVE
1041
1050
 
1042
1051
 
1052
def calculate_large_noise_thresh(large_noise_thresh, resolution):
    """Resolve the large-noise threshold to a concrete value.

    Explicit values (including None, which disables the upper bound) are
    passed through unchanged; the string 'default' selects a threshold
    based on the scan resolution.

    :param large_noise_thresh: int, None, or the string 'default'
    :param resolution: '10x', '20x', or anything else (treated as 40x)
    :return: the threshold to use (int or None)
    """
    if large_noise_thresh == 'default':
        # resolution-dependent defaults; unrecognized resolutions fall back to 40x
        return {'10x': 250, '20x': 1000}.get(resolution, 4000)
    return large_noise_thresh
1061
+
1062
+
1043
1063
  def compute_cell_results(seg, marker, resolution, version=3,
1044
1064
  seg_thresh=DEFAULT_SEG_THRESH,
1045
- noise_thresh=DEFAULT_NOISE_THRESH):
1065
+ noise_thresh=DEFAULT_NOISE_THRESH,
1066
+ large_noise_thresh='default'):
1046
1067
  """
1047
1068
  Perform postprocessing to compute individual cell results.
1048
1069
 
@@ -1060,6 +1081,9 @@ def compute_cell_results(seg, marker, resolution, version=3,
1060
1081
  Threshold to use in determining if a pixel should be labeled as positive/negative.
1061
1082
  noise_thresh : int
1062
1083
  Threshold for tiny noise to ignore (include only cells larger than this value).
1084
+ large_noise_thresh : int | string | None
1085
+ Threshold for large noise to ignore (include only cells smaller than this value).
1086
+ Valid arguments can be an integer value, the string value 'default', or None.
1063
1087
 
1064
1088
  Returns
1065
1089
  -------
@@ -1071,7 +1095,8 @@ def compute_cell_results(seg, marker, resolution, version=3,
1071
1095
  warnings.warn('Invalid cell data version provided, defaulting to version 3.')
1072
1096
  version = 3
1073
1097
 
1074
- mask, cellsinfo, defaults = get_cells_info(seg, marker, resolution, noise_thresh, seg_thresh)
1098
+ large_noise_thresh = calculate_large_noise_thresh(large_noise_thresh, resolution)
1099
+ mask, cellsinfo, defaults = get_cells_info(seg, marker, resolution, noise_thresh, seg_thresh, large_noise_thresh)
1075
1100
 
1076
1101
  cells = []
1077
1102
  for cell in cellsinfo:
@@ -1094,6 +1119,7 @@ def compute_cell_results(seg, marker, resolution, version=3,
1094
1119
  'default_marker_thresh': defaults['marker_thresh'] if 'marker_thresh' in defaults else None,
1095
1120
  'default_size_thresh': defaults['size_thresh'],
1096
1121
  'noise_thresh': noise_thresh,
1122
+ 'large_noise_thresh': large_noise_thresh,
1097
1123
  'seg_thresh': seg_thresh,
1098
1124
  },
1099
1125
  'dataVersion': version,
@@ -1107,7 +1133,8 @@ def compute_final_results(orig, seg, marker, resolution,
1107
1133
  marker_thresh=None,
1108
1134
  size_thresh_upper=None,
1109
1135
  seg_thresh=DEFAULT_SEG_THRESH,
1110
- noise_thresh=DEFAULT_NOISE_THRESH):
1136
+ noise_thresh=DEFAULT_NOISE_THRESH,
1137
+ large_noise_thresh='default'):
1111
1138
  """
1112
1139
  Perform postprocessing to compute final count and image results.
1113
1140
 
@@ -1131,6 +1158,9 @@ def compute_final_results(orig, seg, marker, resolution,
1131
1158
  Threshold to use in determining if a pixel should be labeled as positive/negative.
1132
1159
  noise_thresh : int
1133
1160
  Threshold for tiny noise to ignore (include only cells larger than this value).
1161
+ large_noise_thresh : int | string | None
1162
+ Threshold for large noise to ignore (include only cells smaller than this value).
1163
+ Valid arguments can be an integer value, the string value 'default', or None.
1134
1164
 
1135
1165
  Returns
1136
1166
  -------
@@ -1142,7 +1172,8 @@ def compute_final_results(orig, seg, marker, resolution,
1142
1172
  Dictionary with scoring and settings information.
1143
1173
  """
1144
1174
 
1145
- mask, cellsinfo, defaults = get_cells_info(seg, marker, resolution, noise_thresh, seg_thresh)
1175
+ large_noise_thresh = calculate_large_noise_thresh(large_noise_thresh, resolution)
1176
+ mask, cellsinfo, defaults = get_cells_info(seg, marker, resolution, noise_thresh, seg_thresh, large_noise_thresh)
1146
1177
 
1147
1178
  if size_thresh is None:
1148
1179
  size_thresh = 0