deepliif 1.1.14__py3-none-any.whl → 1.1.15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- deepliif/models/__init__ - multiprocessing (failure).py +980 -0
- deepliif/models/__init__.py +54 -51
- deepliif/util/__init__.py +103 -16
- deepliif/util/util - modified tensor_to_pil.py +255 -0
- {deepliif-1.1.14.dist-info → deepliif-1.1.15.dist-info}/METADATA +628 -622
- {deepliif-1.1.14.dist-info → deepliif-1.1.15.dist-info}/RECORD +10 -8
- {deepliif-1.1.14.dist-info → deepliif-1.1.15.dist-info}/WHEEL +5 -5
- {deepliif-1.1.14.dist-info → deepliif-1.1.15.dist-info}/LICENSE.md +0 -0
- {deepliif-1.1.14.dist-info → deepliif-1.1.15.dist-info}/entry_points.txt +0 -0
- {deepliif-1.1.14.dist-info → deepliif-1.1.15.dist-info}/top_level.txt +0 -0
deepliif/models/__init__.py
CHANGED
|
@@ -219,13 +219,14 @@ def compute_overlap(img_size, tile_size):
|
|
|
219
219
|
return tile_size // 4
|
|
220
220
|
|
|
221
221
|
|
|
222
|
-
def run_torchserve(img, model_path=None, nets=None, eager_mode=False, opt=None, seg_only=False, use_dask=True, output_tensor=False):
|
|
222
|
+
def run_torchserve(img, model_path=None, nets=None, eager_mode=False, opt=None, seg_only=False, seg_weights=None, use_dask=True, output_tensor=False):
|
|
223
223
|
"""
|
|
224
224
|
eager_mode: not used in this function; put in place to be consistent with run_dask
|
|
225
225
|
so that run_wrapper() could call either this function or run_dask with
|
|
226
226
|
same syntax
|
|
227
227
|
opt: same as eager_mode
|
|
228
228
|
seg_only: same as eager_mode
|
|
229
|
+
seg_weights: same as eager_mode
|
|
229
230
|
nets: same as eager_mode
|
|
230
231
|
"""
|
|
231
232
|
buffer = BytesIO()
|
|
@@ -245,7 +246,7 @@ def run_torchserve(img, model_path=None, nets=None, eager_mode=False, opt=None,
|
|
|
245
246
|
return {k: tensor_to_pil(deserialize_tensor(v)) for k, v in res.json().items()}
|
|
246
247
|
|
|
247
248
|
|
|
248
|
-
def run_dask(img, model_path=None, nets=None, eager_mode=False, opt=None, seg_only=False, use_dask=True, output_tensor=False):
|
|
249
|
+
def run_dask(img, model_path=None, nets=None, eager_mode=False, opt=None, seg_only=False, seg_weights=None, use_dask=True, output_tensor=False):
|
|
249
250
|
"""
|
|
250
251
|
Provide either the model path or the networks object.
|
|
251
252
|
|
|
@@ -280,20 +281,22 @@ def run_dask(img, model_path=None, nets=None, eager_mode=False, opt=None, seg_on
|
|
|
280
281
|
return model(input.to(next(model.parameters()).device))
|
|
281
282
|
|
|
282
283
|
if opt.model in ['DeepLIIF','DeepLIIFKD']:
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
|
|
284
|
+
if seg_weights is None:
|
|
285
|
+
weights = {
|
|
286
|
+
'G51': 0.25, # IHC
|
|
287
|
+
'G52': 0.25, # Hema
|
|
288
|
+
'G53': 0.25, # DAPI
|
|
289
|
+
'G54': 0.00, # Lap2
|
|
290
|
+
'G55': 0.25, # Marker
|
|
291
|
+
}
|
|
292
|
+
else:
|
|
293
|
+
weights = {
|
|
294
|
+
'G51': seg_weights[0], # IHC
|
|
295
|
+
'G52': seg_weights[1], # Hema
|
|
296
|
+
'G53': seg_weights[2], # DAPI
|
|
297
|
+
'G54': seg_weights[3], # Lap2
|
|
298
|
+
'G55': seg_weights[4], # Marker
|
|
299
|
+
}
|
|
297
300
|
|
|
298
301
|
seg_map = {'G1': 'G52', 'G2': 'G53', 'G3': 'G54', 'G4': 'G55'}
|
|
299
302
|
if seg_only:
|
|
@@ -357,12 +360,12 @@ def run_dask(img, model_path=None, nets=None, eager_mode=False, opt=None, seg_on
|
|
|
357
360
|
def is_empty(tile):
|
|
358
361
|
thresh = 15
|
|
359
362
|
if isinstance(tile, list): # for pair of tiles, only mark it as empty / no need for prediction if ALL tiles are empty
|
|
360
|
-
return all([True if
|
|
363
|
+
return all([True if image_variance_gray(t) < thresh else False for t in tile])
|
|
361
364
|
else:
|
|
362
|
-
return True if
|
|
365
|
+
return True if image_variance_gray(tile) < thresh else False
|
|
363
366
|
|
|
364
367
|
|
|
365
|
-
def run_wrapper(tile, run_fn, model_path=None, nets=None, eager_mode=False, opt=None, seg_only=False, use_dask=True, output_tensor=False):
|
|
368
|
+
def run_wrapper(tile, run_fn, model_path=None, nets=None, eager_mode=False, opt=None, seg_only=False, seg_weights=None, use_dask=True, output_tensor=False):
|
|
366
369
|
if opt.model in ['DeepLIIF','DeepLIIFKD']:
|
|
367
370
|
if is_empty(tile):
|
|
368
371
|
if seg_only:
|
|
@@ -384,7 +387,7 @@ def run_wrapper(tile, run_fn, model_path=None, nets=None, eager_mode=False, opt=
|
|
|
384
387
|
'G55': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)),
|
|
385
388
|
}
|
|
386
389
|
else:
|
|
387
|
-
return run_fn(tile, model_path, None, eager_mode, opt, seg_only)
|
|
390
|
+
return run_fn(tile, model_path, None, eager_mode, opt, seg_only, seg_weights)
|
|
388
391
|
elif opt.model in ['DeepLIIFExt', 'SDG']:
|
|
389
392
|
if is_empty(tile):
|
|
390
393
|
res = {'G_' + str(i): Image.new(mode='RGB', size=(512, 512)) for i in range(1, opt.modalities_no + 1)}
|
|
@@ -405,7 +408,7 @@ def run_wrapper(tile, run_fn, model_path=None, nets=None, eager_mode=False, opt=
|
|
|
405
408
|
|
|
406
409
|
def inference(img, tile_size, overlap_size, model_path, use_torchserve=False,
|
|
407
410
|
eager_mode=False, color_dapi=False, color_marker=False, opt=None,
|
|
408
|
-
return_seg_intermediate=False, seg_only=False, opt_args={}):
|
|
411
|
+
return_seg_intermediate=False, seg_only=False, seg_weights=None, opt_args={}):
|
|
409
412
|
"""
|
|
410
413
|
opt_args: a dictionary of key and values to add/overwrite to opt
|
|
411
414
|
"""
|
|
@@ -430,7 +433,7 @@ def inference(img, tile_size, overlap_size, model_path, use_torchserve=False,
|
|
|
430
433
|
|
|
431
434
|
tiler = InferenceTiler(orig, tile_size, overlap_size)
|
|
432
435
|
for tile in tiler:
|
|
433
|
-
tiler.stitch(run_wrapper(tile, run_fn, model_path, None, eager_mode, opt, seg_only))
|
|
436
|
+
tiler.stitch(run_wrapper(tile, run_fn, model_path, None, eager_mode, opt, seg_only, seg_weights))
|
|
434
437
|
|
|
435
438
|
results = tiler.results()
|
|
436
439
|
|
|
@@ -515,7 +518,7 @@ def postprocess(orig, images, tile_size, model, seg_thresh=150, size_thresh='def
|
|
|
515
518
|
|
|
516
519
|
def infer_modalities(img, tile_size, model_dir, eager_mode=False,
|
|
517
520
|
color_dapi=False, color_marker=False, opt=None,
|
|
518
|
-
return_seg_intermediate=False, seg_only=False):
|
|
521
|
+
return_seg_intermediate=False, seg_only=False, seg_weights=None):
|
|
519
522
|
"""
|
|
520
523
|
This function is used to infer modalities for the given image using a trained model.
|
|
521
524
|
:param img: The input image.
|
|
@@ -543,7 +546,8 @@ def infer_modalities(img, tile_size, model_dir, eager_mode=False,
|
|
|
543
546
|
color_marker=color_marker,
|
|
544
547
|
opt=opt,
|
|
545
548
|
return_seg_intermediate=return_seg_intermediate,
|
|
546
|
-
seg_only=seg_only
|
|
549
|
+
seg_only=seg_only,
|
|
550
|
+
seg_weights=seg_weights,
|
|
547
551
|
)
|
|
548
552
|
|
|
549
553
|
if not hasattr(opt,'seg_gen') or (hasattr(opt,'seg_gen') and opt.seg_gen): # the first condition accounts for old settings of deepliif; the second refers to deepliifext models
|
|
@@ -558,7 +562,7 @@ def infer_modalities(img, tile_size, model_dir, eager_mode=False,
|
|
|
558
562
|
return images, None
|
|
559
563
|
|
|
560
564
|
|
|
561
|
-
def infer_results_for_wsi(input_dir, filename, output_dir, model_dir, tile_size, region_size=20000, color_dapi=False, color_marker=False, seg_intermediate=False, seg_only=False):
|
|
565
|
+
def infer_results_for_wsi(input_dir, filename, output_dir, model_dir, tile_size, region_size=20000, color_dapi=False, color_marker=False, seg_intermediate=False, seg_only=False, seg_weights=None):
|
|
562
566
|
"""
|
|
563
567
|
This function infers modalities and segmentation mask for the given WSI image. It
|
|
564
568
|
|
|
@@ -592,7 +596,7 @@ def infer_results_for_wsi(input_dir, filename, output_dir, model_dir, tile_size,
|
|
|
592
596
|
region = reader.read(XYWH=region_XYWH, rescale=rescale)
|
|
593
597
|
img = Image.fromarray((region * 255).astype(np.uint8)) if rescale else Image.fromarray(region)
|
|
594
598
|
|
|
595
|
-
region_modalities, region_scoring = infer_modalities(img, tile_size, model_dir, color_dapi=color_dapi, color_marker=color_marker, return_seg_intermediate=seg_intermediate, seg_only=seg_only)
|
|
599
|
+
region_modalities, region_scoring = infer_modalities(img, tile_size, model_dir, color_dapi=color_dapi, color_marker=color_marker, return_seg_intermediate=seg_intermediate, seg_only=seg_only, seg_weights=seg_weights)
|
|
596
600
|
if region_scoring is not None:
|
|
597
601
|
if scoring is None:
|
|
598
602
|
scoring = {
|
|
@@ -624,8 +628,6 @@ def infer_results_for_wsi(input_dir, filename, output_dir, model_dir, tile_size,
|
|
|
624
628
|
with open(os.path.join(results_dir, f'{basename}.json'), 'w') as f:
|
|
625
629
|
json.dump(scoring, f, indent=2)
|
|
626
630
|
|
|
627
|
-
javabridge.kill_vm()
|
|
628
|
-
|
|
629
631
|
|
|
630
632
|
def get_wsi_resolution(filename):
|
|
631
633
|
"""
|
|
@@ -633,9 +635,6 @@ def get_wsi_resolution(filename):
|
|
|
633
635
|
the corresponding tile size to use by default for DeepLIIF.
|
|
634
636
|
If it cannot be found, return (None, None) instead.
|
|
635
637
|
|
|
636
|
-
Note: This will start the javabridge VM, but not kill it.
|
|
637
|
-
It must be killed elsewhere.
|
|
638
|
-
|
|
639
638
|
Parameters
|
|
640
639
|
----------
|
|
641
640
|
filename : str
|
|
@@ -649,11 +648,10 @@ def get_wsi_resolution(filename):
|
|
|
649
648
|
Corresponding tile size for DeepLIIF.
|
|
650
649
|
"""
|
|
651
650
|
|
|
652
|
-
|
|
653
|
-
|
|
651
|
+
init_javabridge_bioformats()
|
|
652
|
+
metadata = bioformats.get_omexml_metadata(filename)
|
|
654
653
|
|
|
655
654
|
mag = None
|
|
656
|
-
metadata = bioformats.get_omexml_metadata(filename)
|
|
657
655
|
try:
|
|
658
656
|
omexml = bioformats.OMEXML(metadata)
|
|
659
657
|
mag = omexml.instrument().Objective.NominalMagnification
|
|
@@ -686,7 +684,7 @@ def get_wsi_resolution(filename):
|
|
|
686
684
|
return None, None
|
|
687
685
|
|
|
688
686
|
|
|
689
|
-
def infer_cells_for_wsi(filename, model_dir, tile_size, region_size=20000, version=3, print_log=False):
|
|
687
|
+
def infer_cells_for_wsi(filename, model_dir, tile_size, region_size=20000, version=3, print_log=False, seg_weights=None):
|
|
690
688
|
"""
|
|
691
689
|
Perform inference on a slide and get the results individual cell data.
|
|
692
690
|
|
|
@@ -704,6 +702,9 @@ def infer_cells_for_wsi(filename, model_dir, tile_size, region_size=20000, versi
|
|
|
704
702
|
Version of cell data to return (3 or 4).
|
|
705
703
|
print_log : bool
|
|
706
704
|
Whether or not to print updates while processing.
|
|
705
|
+
seg_weights : list | tuple
|
|
706
|
+
Optional five weights to use for creating the combined Seg image.
|
|
707
|
+
If None, then the default weights are used.
|
|
707
708
|
|
|
708
709
|
Returns
|
|
709
710
|
-------
|
|
@@ -717,22 +718,21 @@ def infer_cells_for_wsi(filename, model_dir, tile_size, region_size=20000, versi
|
|
|
717
718
|
|
|
718
719
|
resolution = '40x' if tile_size > 384 else ('20x' if tile_size > 192 else '10x')
|
|
719
720
|
|
|
720
|
-
size_x, size_y, size_z, size_c, size_t, pixel_type = get_information(filename)
|
|
721
|
-
rescale = (pixel_type != 'uint8')
|
|
722
|
-
print_info('Info:', size_x, size_y, size_z, size_c, size_t, pixel_type)
|
|
723
|
-
|
|
724
|
-
num_regions_x = math.ceil(size_x / region_size)
|
|
725
|
-
num_regions_y = math.ceil(size_y / region_size)
|
|
726
|
-
stride_x = math.ceil(size_x / num_regions_x)
|
|
727
|
-
stride_y = math.ceil(size_y / num_regions_y)
|
|
728
|
-
print_info('Strides:', stride_x, stride_y)
|
|
729
|
-
|
|
730
721
|
data = None
|
|
731
722
|
default_marker_thresh, count_marker_thresh = 0, 0
|
|
732
723
|
default_size_thresh, count_size_thresh = 0, 0
|
|
733
724
|
|
|
734
|
-
|
|
735
|
-
|
|
725
|
+
with WSIReader(filename) as reader:
|
|
726
|
+
size_x = reader.width
|
|
727
|
+
size_y = reader.height
|
|
728
|
+
print_info('Info:', size_x, size_y)
|
|
729
|
+
|
|
730
|
+
num_regions_x = math.ceil(size_x / region_size)
|
|
731
|
+
num_regions_y = math.ceil(size_y / region_size)
|
|
732
|
+
stride_x = math.ceil(size_x / num_regions_x)
|
|
733
|
+
stride_y = math.ceil(size_y / num_regions_y)
|
|
734
|
+
print_info('Strides:', stride_x, stride_y)
|
|
735
|
+
|
|
736
736
|
start_x, start_y = 0, 0
|
|
737
737
|
|
|
738
738
|
while start_y < size_y:
|
|
@@ -740,9 +740,9 @@ def infer_cells_for_wsi(filename, model_dir, tile_size, region_size=20000, versi
|
|
|
740
740
|
region_XYWH = (start_x, start_y, min(stride_x, size_x-start_x), min(stride_y, size_y-start_y))
|
|
741
741
|
print_info('Region:', region_XYWH)
|
|
742
742
|
|
|
743
|
-
region = reader.read(
|
|
743
|
+
region = reader.read(region_XYWH)
|
|
744
744
|
print_info(region.shape, region.dtype)
|
|
745
|
-
img = Image.fromarray(
|
|
745
|
+
img = Image.fromarray(region)
|
|
746
746
|
print_info(img.size, img.mode)
|
|
747
747
|
del region
|
|
748
748
|
|
|
@@ -757,6 +757,7 @@ def infer_cells_for_wsi(filename, model_dir, tile_size, region_size=20000, versi
|
|
|
757
757
|
opt=None,
|
|
758
758
|
return_seg_intermediate=False,
|
|
759
759
|
seg_only=True,
|
|
760
|
+
seg_weights=seg_weights,
|
|
760
761
|
)
|
|
761
762
|
del img
|
|
762
763
|
|
|
@@ -795,8 +796,6 @@ def infer_cells_for_wsi(filename, model_dir, tile_size, region_size=20000, versi
|
|
|
795
796
|
start_x = 0
|
|
796
797
|
start_y += stride_y
|
|
797
798
|
|
|
798
|
-
javabridge.kill_vm()
|
|
799
|
-
|
|
800
799
|
if count_marker_thresh == 0:
|
|
801
800
|
count_marker_thresh = 1
|
|
802
801
|
if count_size_thresh == 0:
|
|
@@ -804,6 +803,10 @@ def infer_cells_for_wsi(filename, model_dir, tile_size, region_size=20000, versi
|
|
|
804
803
|
data['settings']['default_marker_thresh'] = round(default_marker_thresh / count_marker_thresh)
|
|
805
804
|
data['settings']['default_size_thresh'] = round(default_size_thresh / count_size_thresh)
|
|
806
805
|
|
|
806
|
+
data['settings']['tile_size'] = tile_size
|
|
807
|
+
data['settings']['region_size'] = region_size
|
|
808
|
+
data['settings']['seg_weights'] = seg_weights
|
|
809
|
+
|
|
807
810
|
try:
|
|
808
811
|
data['deepliifVersion'] = importlib.metadata.version('deepliif')
|
|
809
812
|
except Exception as e:
|
deepliif/util/__init__.py
CHANGED
|
@@ -2,6 +2,10 @@
|
|
|
2
2
|
import os
|
|
3
3
|
import collections
|
|
4
4
|
|
|
5
|
+
import atexit
|
|
6
|
+
import functools
|
|
7
|
+
import threading
|
|
8
|
+
|
|
5
9
|
import torch
|
|
6
10
|
import numpy as np
|
|
7
11
|
from PIL import Image, ImageOps
|
|
@@ -22,6 +26,9 @@ import javabridge
|
|
|
22
26
|
import bioformats.omexml as ome
|
|
23
27
|
import tifffile as tf
|
|
24
28
|
|
|
29
|
+
from tifffile import TiffFile
|
|
30
|
+
import zarr
|
|
31
|
+
|
|
25
32
|
|
|
26
33
|
excluding_names = ['Hema', 'DAPI', 'DAPILap2', 'Ki67', 'Seg', 'Marked', 'SegRefined', 'SegOverlaid', 'Marker', 'Lap2']
|
|
27
34
|
# Image extensions to consider
|
|
@@ -392,6 +399,30 @@ def image_variance_rgb(img):
|
|
|
392
399
|
return var
|
|
393
400
|
|
|
394
401
|
|
|
402
|
+
def init_javabridge_bioformats():
    """
    Initialize javabridge for use with bioformats.

    Starts the Java VM with the bioformats JARs on the class path, registers
    ``javabridge.kill_vm`` to run at interpreter exit, and sets the Java root
    logger level to WARN. Threads created while the VM is starting are forced
    to be daemon threads, so callers do not need to call kill_vm explicitly.

    This function will only run once; repeat calls do nothing. If starting the
    VM fails, the exception propagates and a later call will try again.
    """
    if hasattr(init_javabridge_bioformats, 'called'):
        return

    # https://github.com/LeeKamentsky/python-javabridge/issues/155
    # Temporarily make every Thread created during start_vm a daemon thread.
    old_init = threading.Thread.__init__
    threading.Thread.__init__ = functools.partialmethod(old_init, daemon=True)
    try:
        javabridge.start_vm(class_path=bioformats.JARS)
    finally:
        # Always undo the process-wide monkey-patch, even when the VM fails to
        # start; otherwise every thread created afterwards would stay daemonic.
        threading.Thread.__init__ = old_init
    atexit.register(javabridge.kill_vm)

    # Raise the Java-side root logger level to WARN.
    rootLoggerName = javabridge.get_static_field("org/slf4j/Logger", "ROOT_LOGGER_NAME", "Ljava/lang/String;")
    rootLogger = javabridge.static_call("org/slf4j/LoggerFactory", "getLogger",
                                        "(Ljava/lang/String;)Lorg/slf4j/Logger;", rootLoggerName)
    logLevel = javabridge.get_static_field("ch/qos/logback/classic/Level", "WARN", "Lch/qos/logback/classic/Level;")
    javabridge.call(rootLogger, "setLevel", "(Lch/qos/logback/classic/Level;)V", logLevel)

    init_javabridge_bioformats.called = True
|
|
424
|
+
|
|
425
|
+
|
|
395
426
|
def read_bioformats_image_with_reader(path, channel=0, region=(0, 0, 0, 0)):
|
|
396
427
|
"""
|
|
397
428
|
Using this function, you can read a specific region of a large image by giving the region bounding box (XYWH format)
|
|
@@ -402,14 +433,7 @@ def read_bioformats_image_with_reader(path, channel=0, region=(0, 0, 0, 0)):
|
|
|
402
433
|
:param region: The bounding box around the region of interest (XYWH format).
|
|
403
434
|
:return: The specified region of interest image (numpy array).
|
|
404
435
|
"""
|
|
405
|
-
|
|
406
|
-
|
|
407
|
-
rootLoggerName = javabridge.get_static_field("org/slf4j/Logger", "ROOT_LOGGER_NAME", "Ljava/lang/String;")
|
|
408
|
-
rootLogger = javabridge.static_call("org/slf4j/LoggerFactory", "getLogger",
|
|
409
|
-
"(Ljava/lang/String;)Lorg/slf4j/Logger;", rootLoggerName)
|
|
410
|
-
logLevel = javabridge.get_static_field("ch/qos/logback/classic/Level", "WARN", "Lch/qos/logback/classic/Level;")
|
|
411
|
-
javabridge.call(rootLogger, "setLevel", "(Lch/qos/logback/classic/Level;)V", logLevel)
|
|
412
|
-
|
|
436
|
+
init_javabridge_bioformats()
|
|
413
437
|
with bioformats.ImageReader(path) as reader:
|
|
414
438
|
return reader.read(t=channel, XYWH=region)
|
|
415
439
|
|
|
@@ -421,14 +445,7 @@ def get_information(filename):
|
|
|
421
445
|
:param filename: The address to the ome image.
|
|
422
446
|
:return: size_x, size_y, size_z, size_c, size_t, pixel_type
|
|
423
447
|
"""
|
|
424
|
-
|
|
425
|
-
|
|
426
|
-
rootLoggerName = javabridge.get_static_field("org/slf4j/Logger", "ROOT_LOGGER_NAME", "Ljava/lang/String;")
|
|
427
|
-
rootLogger = javabridge.static_call("org/slf4j/LoggerFactory", "getLogger",
|
|
428
|
-
"(Ljava/lang/String;)Lorg/slf4j/Logger;", rootLoggerName)
|
|
429
|
-
logLevel = javabridge.get_static_field("ch/qos/logback/classic/Level", "WARN", "Lch/qos/logback/classic/Level;")
|
|
430
|
-
javabridge.call(rootLogger, "setLevel", "(Lch/qos/logback/classic/Level;)V", logLevel)
|
|
431
|
-
|
|
448
|
+
init_javabridge_bioformats()
|
|
432
449
|
metadata = bioformats.get_omexml_metadata(filename)
|
|
433
450
|
omexml = bioformats.OMEXML(metadata)
|
|
434
451
|
size_x, size_y, size_z, size_c, size_t, pixel_type = omexml.image().Pixels.SizeX, \
|
|
@@ -441,6 +458,76 @@ def get_information(filename):
|
|
|
441
458
|
return size_x, size_y, size_z, size_c, size_t, pixel_type
|
|
442
459
|
|
|
443
460
|
|
|
461
|
+
class WSIReader:
    """
    Region reader for a whole slide image.

    Assumes the file is a single image (e.g., not a stacked
    OME TIFF) and will always return uint8 pixel type data.

    For uint8 images the file is first opened with tifffile and exposed as a
    zarr array; if that fails for any reason, or the pixel type is not uint8,
    a bioformats ImageReader is used instead. Use as a context manager, or
    call close() when done.
    """

    def __init__(self, path):
        """
        Read the image metadata via bioformats and open a reader for `path`.

        Raises whatever bioformats raises if the metadata or the fallback
        reader cannot be obtained.
        """
        init_javabridge_bioformats()
        metadata = bioformats.get_omexml_metadata(path)
        omexml = bioformats.OMEXML(metadata)

        self._path = path
        self._width = omexml.image().Pixels.SizeX
        self._height = omexml.image().Pixels.SizeY
        self._pixel_type = omexml.image().Pixels.PixelType

        # Define every handle up front so close() is safe on any code path.
        self._file = None
        self._tif = None
        self._zarr = None
        self._bfreader = None
        self._rescale = False

        if self._pixel_type == 'uint8':
            try:
                self._file = open(path, 'rb')
                self._tif = TiffFile(self._file)
                self._zarr = zarr.open(self._tif.pages[0].aszarr(), mode='r')
            except Exception:
                # The tifffile/zarr path is a best-effort fast path; on any
                # failure release what was opened and fall back to bioformats.
                self._close_tiff()

        if self._tif is None:
            self._rescale = (self._pixel_type != 'uint8')
            self._bfreader = bioformats.ImageReader(path)

        if self._tif is None and self._bfreader is None:
            raise Exception('Cannot read WSI file.')

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    def _close_tiff(self):
        """Release the tifffile/zarr handles, if any, and mark them as closed."""
        self._zarr = None
        if self._tif is not None:
            self._tif.close()
            self._tif = None
        if self._file is not None:
            self._file.close()
            self._file = None

    def close(self):
        """Release all underlying readers. Safe to call more than once."""
        self._close_tiff()
        if self._bfreader is not None:
            self._bfreader.close()
            self._bfreader = None

    @property
    def width(self):
        """Image width (SizeX from the OME metadata)."""
        return self._width

    @property
    def height(self):
        """Image height (SizeY from the OME metadata)."""
        return self._height

    def read(self, xywh):
        """
        Read a region given as (x, y, width, height).

        Returns the pixel data for the region; data read through bioformats
        with a non-uint8 pixel type is rescaled and converted to uint8.
        Raises ValueError if the reader has already been closed.
        """
        if self._tif is not None:
            x, y, w, h = xywh
            return self._zarr[y:y+h, x:x+w]

        if self._bfreader is None:
            raise ValueError('WSIReader is closed.')

        px = self._bfreader.read(XYWH=xywh, rescale=self._rescale)
        if self._rescale:
            px = (px * 255).astype(np.uint8)
        return px
|
|
529
|
+
|
|
530
|
+
|
|
444
531
|
|
|
445
532
|
|
|
446
533
|
def write_results_to_pickle_file(output_addr, results):
|
|
@@ -0,0 +1,255 @@
|
|
|
1
|
+
"""This module contains simple helper functions """
|
|
2
|
+
import os
|
|
3
|
+
from time import time
|
|
4
|
+
from functools import wraps
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
import numpy as np
|
|
8
|
+
from PIL import Image
|
|
9
|
+
import cv2
|
|
10
|
+
from skimage.metrics import structural_similarity as ssim
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def timeit(f):
    """Decorator that prints the wrapped function's name and elapsed wall time.

    Parameters:
        f (callable) -- the function to time

    Returns the wrapped function; its return value is passed through unchanged.
    """
    # Bind the clock locally under a private name. This module later executes
    # ``import time`` at top level, which rebinds the global name ``time`` from
    # the function to the module, so a bare ``time()`` call in the wrapper
    # would raise TypeError at call time.
    from time import time as _now

    @wraps(f)
    def wrap(*args, **kwargs):
        ts = _now()
        result = f(*args, **kwargs)
        print(f'{f.__name__} {_now() - ts}')

        return result

    return wrap
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def diagnose_network(net, name='network'):
    """Print the network name and the mean of its per-parameter mean absolute gradients.

    Parameters:
        net (torch network) -- Torch network
        name (str)          -- the name of the network

    Parameters without a gradient are ignored; if none has a gradient, 0.0 is printed.
    """
    grad_sum = 0.0
    n_with_grad = 0
    for p in net.parameters():
        if p.grad is None:
            continue
        grad_sum += torch.mean(torch.abs(p.grad.data))
        n_with_grad += 1
    if n_with_grad > 0:
        grad_sum = grad_sum / n_with_grad
    print(name)
    print(grad_sum)
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def save_image(image_numpy, image_path, aspect_ratio=1.0):
    """Save a numpy image to the disk

    Arrays with more than 3 channels are split along the channel axis into
    several images (3, 2 or 1 channel(s) each, whichever divides the channel
    count first) that are laid out side by side before saving.

    Parameters:
        image_numpy (numpy array) -- input numpy array with three dimensions
        image_path (str)          -- the path of the image
        aspect_ratio (float)      -- values other than 1.0 trigger a bicubic resize
    """
    _, _, nc = image_numpy.shape

    if nc > 3:
        if nc % 3 == 0:
            nc_img = 3
        elif nc % 2 == 0:
            nc_img = 2
        else:
            nc_img = 1
        no_img = nc // nc_img
        print(f'image (numpy) has {nc}>3 channels, inferred to have {no_img} images each with {nc_img} channel(s)')
        split_points = [nc_img*i for i in range(1, no_img)]
        pieces = np.dsplit(image_numpy, split_points)
        image_numpy = np.concatenate(pieces, axis=1)  # stack horizontally

    image_pil = Image.fromarray(image_numpy)
    h, w, _ = image_numpy.shape

    # NOTE(review): PIL's resize takes (width, height) but (h, ...) is passed
    # first here; this looks transposed for non-square images -- confirm.
    if aspect_ratio > 1.0:
        stretched = (h, int(w * aspect_ratio))
        image_pil = image_pil.resize(stretched, Image.BICUBIC)
    if aspect_ratio < 1.0:
        stretched = (int(h / aspect_ratio), w)
        image_pil = image_pil.resize(stretched, Image.BICUBIC)
    image_pil.save(image_path)
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def print_numpy(x, val=True, shp=False):
    """Print summary statistics (mean, min, max, median, std) and/or the shape of a numpy array

    Parameters:
        x (numpy array) -- the array to describe; it is converted to float64 first
        val (bool) -- if print the values of the numpy array
        shp (bool) -- if print the shape of the numpy array
    """
    as_float = x.astype(np.float64)
    if shp:
        print('shape,', as_float.shape)
    if not val:
        return
    flat = as_float.flatten()
    stats = (np.mean(flat), np.min(flat), np.max(flat), np.median(flat), np.std(flat))
    print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % stats)
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def mkdirs(paths):
    """Create each given directory if it does not exist yet

    Parameters:
        paths (str list) -- a list of directory paths; a single path is also accepted
    """
    is_path_list = isinstance(paths, list) and not isinstance(paths, str)
    targets = paths if is_path_list else [paths]
    for target in targets:
        mkdir(target)
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
def mkdir(path):
    """Create a directory (including missing parents) unless the path already exists

    Parameters:
        path (str) -- a single directory path
    """
    if os.path.exists(path):
        return
    os.makedirs(path, exist_ok=True)
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
import time
|
|
118
|
+
time_tensor = 0
|
|
119
|
+
time_convert = 0
|
|
120
|
+
time_transpose = 0
|
|
121
|
+
time_astype = 0
|
|
122
|
+
time_topil = 0
|
|
123
|
+
time_scale = 0
|
|
124
|
+
def print_times():
    """Print the accumulated time, in seconds rounded to one decimal, of each conversion stage.

    Reads the module-level accumulators that tensor2im() and tensor_to_pil()
    add to; it does not reset them.
    """
    print('Time to get tensor:', round(time_tensor, 1), flush=True)
    print('Time to convert:', round(time_convert, 1), flush=True)
    print('Time to transpose:', round(time_transpose, 1), flush=True)
    print('Time to scale:', round(time_scale, 1), flush=True)
    # Report the astype accumulator here (the transpose total was printed twice before).
    print('Time for astype:', round(time_astype, 1), flush=True)
    print('Time to pil:', round(time_topil, 1), flush=True)
|
|
131
|
+
|
|
132
|
+
def tensor2im(input_image, imtype=np.uint8):
    """Converts a Tensor array into a numpy image array.

    A torch tensor is converted as follows: take element 0 along the first
    axis, move to CPU as float, tile a single channel to 3 channels, transpose
    to channels-last, resize to 256x256 with cv2.INTER_AREA, then map values
    with (x + 1) / 2 * 255 and cast to `imtype`. The time spent in each stage
    is added to the module-level time_* accumulators.

    Parameters:
        input_image (tensor) --  the input image tensor array
        imtype (type)        --  the desired type of the converted numpy array

    Returns:
        A numpy array of type `imtype`. A numpy input is only cast to `imtype`;
        an input that is neither a numpy array nor a torch tensor is returned
        unchanged.
    """
    global time_tensor, time_convert, time_transpose, time_scale, time_astype

    if isinstance(input_image, np.ndarray):  # if it is a numpy array, only cast it
        return input_image.astype(imtype)
    if not isinstance(input_image, torch.Tensor):
        return input_image

    ts = time.time()
    image_tensor = input_image.data  # get the data from a variable
    time_tensor += (time.time() - ts)

    ts = time.time()
    image_numpy = image_tensor[0].cpu().float().numpy()  # convert it into a numpy array
    time_convert += (time.time() - ts)

    if image_numpy.shape[0] == 1:  # grayscale to RGB
        image_numpy = np.tile(image_numpy, (3, 1, 1))

    ts = time.time()
    image_numpy = np.transpose(image_numpy, (1, 2, 0))
    time_transpose += (time.time() - ts)

    # post-processing: resize first, then scale the smaller array
    ts = time.time()
    image_numpy = cv2.resize(image_numpy, (256, 256), interpolation=cv2.INTER_AREA)
    image_numpy = (image_numpy + 1) / 2.0 * 255.0
    time_scale += (time.time() - ts)

    # Time the cast before returning so the astype accumulator is actually updated.
    ts = time.time()
    image_numpy = image_numpy.astype(imtype)
    time_astype += (time.time() - ts)
    return image_numpy
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
def tensor_to_pil(t):
    """Convert a tensor to a PIL image via tensor2im().

    The time spent building the PIL image from the array is added to the
    module-level time_topil accumulator.
    """
    global time_topil
    pixels = tensor2im(t)
    started = time.time()
    pil_image = Image.fromarray(pixels)
    finished = time.time()
    time_topil += (finished - started)
    return pil_image
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
def calculate_ssim(img1, img2):
    """Return the structural similarity of img1 and img2, using img2's value span as data range."""
    value_span = img2.max() - img2.min()
    return ssim(img1, img2, data_range=value_span)
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
def check_multi_scale(img1, img2):
    """Find the tile size of img2 whose tiles look most like img1.

    For each tile size from 100 to 900 (step 100), img2 is cut into
    non-overlapping square tiles that fit entirely inside it; each tile is
    resized to img1's height and width and compared with img1 by SSIM. The
    tile size with the highest mean SSIM is returned.

    Parameters:
        img1 (image / array) -- the reference image
        img2 (image / array) -- the image to cut into tiles

    Returns:
        The best tile size (int), or 512 if no tile size gives a mean SSIM
        above 0 (including when no tile fits inside img2).
    """
    img1 = np.array(img1)
    img2 = np.array(img2)
    # cv2.resize expects dsize as (width, height), i.e. (columns, rows).
    target_size = (img1.shape[1], img1.shape[0])
    max_ssim = (512, 0)
    for tile_size in range(100, 1000, 100):
        image_ssim = 0
        tile_no = 0
        for i in range(0, img2.shape[0], tile_size):
            for j in range(0, img2.shape[1], tile_size):
                # only use tiles that lie completely inside img2
                if i + tile_size <= img2.shape[0] and j + tile_size <= img2.shape[1]:
                    tile = img2[i: i + tile_size, j: j + tile_size]
                    tile = cv2.resize(tile, target_size)
                    tile_ssim = calculate_ssim(img1, tile)
                    image_ssim += tile_ssim
                    tile_no += 1
        if tile_no > 0:
            image_ssim /= tile_no
            if max_ssim[1] < image_ssim:
                max_ssim = (tile_size, image_ssim)
    return max_ssim[0]
|
|
214
|
+
|
|
215
|
+
|
|
216
|
+
import subprocess
|
|
217
|
+
import os
|
|
218
|
+
from threading import Thread , Timer
|
|
219
|
+
import sched, time
|
|
220
|
+
|
|
221
|
+
# modified from https://stackoverflow.com/questions/67707828/how-to-get-every-seconds-gpu-usage-in-python
|
|
222
|
+
def get_gpu_memory(gpu_id=0):
    """
    Return the used-memory figure that nvidia-smi reports for the given gpu id.

    Runs ``nvidia-smi --query-gpu=memory.used --format=csv``, skips the header
    line and parses the leading integer of every remaining line. Raises
    RuntimeError if the command exits with a non-zero status.
    """
    command = "nvidia-smi --query-gpu=memory.used --format=csv"
    try:
        raw = subprocess.check_output(command.split(), stderr=subprocess.STDOUT)
        # drop the trailing empty entry after the final newline, then the csv header
        rows = raw.decode('ascii').split('\n')[:-1][1:]
    except subprocess.CalledProcessError as e:
        raise RuntimeError("command '{}' return with error (code {}): {}".format(e.cmd, e.returncode, e.output))
    used_per_gpu = [int(row.split()[0]) for row in rows]

    return used_per_gpu[gpu_id]
|
|
237
|
+
|
|
238
|
+
class HardwareStatus():
    """Collects a GPU memory reading about once per second on a background timer."""

    def __init__(self):
        self.gpu_mem = []   # readings appended by get_status_every_sec(), oldest first
        self.timer = None   # the currently scheduled Timer, or None before the first call

    def get_status_every_sec(self, gpu_id=0):
        """
        This function calls itself every 1 sec and appends the gpu_memory.

        The next call is scheduled before the current reading is taken, and the
        same gpu_id is passed on so that every reading comes from that GPU.
        """
        self.timer = Timer(1.0, self.get_status_every_sec, args=(gpu_id,))
        self.timer.start()
        self.gpu_mem.append(get_gpu_memory(gpu_id))

    def stop_timer(self):
        """Cancel the pending timer, if any. Does nothing when sampling was never started."""
        if self.timer is not None:
            self.timer.cancel()
|
|
254
|
+
|
|
255
|
+
|