deepliif 1.1.13__py3-none-any.whl → 1.1.15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cli.py +15 -22
- deepliif/data/aligned_dataset.py +2 -2
- deepliif/models/DeepLIIFKD_model.py +409 -0
- deepliif/models/__init__ - multiprocessing (failure).py +980 -0
- deepliif/models/__init__ - weights, empty, zarr, tile count.py +792 -0
- deepliif/models/__init__.py +130 -70
- deepliif/models/base_model.py +1 -1
- deepliif/models/networks.py +7 -5
- deepliif/postprocessing.py +55 -24
- deepliif/util/__init__.py +103 -16
- deepliif/util/checks.py +17 -0
- deepliif/util/util - modified tensor_to_pil.py +255 -0
- deepliif/util/util.py +42 -0
- {deepliif-1.1.13.dist-info → deepliif-1.1.15.dist-info}/METADATA +628 -622
- {deepliif-1.1.13.dist-info → deepliif-1.1.15.dist-info}/RECORD +19 -14
- {deepliif-1.1.13.dist-info → deepliif-1.1.15.dist-info}/WHEEL +5 -5
- {deepliif-1.1.13.dist-info → deepliif-1.1.15.dist-info}/LICENSE.md +0 -0
- {deepliif-1.1.13.dist-info → deepliif-1.1.15.dist-info}/entry_points.txt +0 -0
- {deepliif-1.1.13.dist-info → deepliif-1.1.15.dist-info}/top_level.txt +0 -0
deepliif/util/__init__.py
CHANGED
|
@@ -2,6 +2,10 @@
|
|
|
2
2
|
import os
|
|
3
3
|
import collections
|
|
4
4
|
|
|
5
|
+
import atexit
|
|
6
|
+
import functools
|
|
7
|
+
import threading
|
|
8
|
+
|
|
5
9
|
import torch
|
|
6
10
|
import numpy as np
|
|
7
11
|
from PIL import Image, ImageOps
|
|
@@ -22,6 +26,9 @@ import javabridge
|
|
|
22
26
|
import bioformats.omexml as ome
|
|
23
27
|
import tifffile as tf
|
|
24
28
|
|
|
29
|
+
from tifffile import TiffFile
|
|
30
|
+
import zarr
|
|
31
|
+
|
|
25
32
|
|
|
26
33
|
excluding_names = [
    'Hema', 'DAPI', 'DAPILap2', 'Ki67', 'Seg',
    'Marked', 'SegRefined', 'SegOverlaid', 'Marker', 'Lap2',
]
|
|
27
34
|
# Image extensions to consider
|
|
@@ -392,6 +399,30 @@ def image_variance_rgb(img):
|
|
|
392
399
|
return var
|
|
393
400
|
|
|
394
401
|
|
|
402
|
+
def init_javabridge_bioformats():
    """
    Initialize javabridge for use with bioformats.

    Starts the JVM with the bioformats JARs on the class path. Threads
    created while the VM starts are made daemon threads, so there is no need
    to explicitly call kill_vm (it is still registered with atexit). The root
    SLF4J logger level is then set to WARN.

    This function will only run once; repeat calls do nothing.
    """
    if hasattr(init_javabridge_bioformats, 'called'):
        return

    # https://github.com/LeeKamentsky/python-javabridge/issues/155
    # Temporarily patch Thread.__init__ so threads created by start_vm default
    # to daemon=True. Restore it in `finally` so a failing start_vm does not
    # leave the process-wide threading module monkeypatched.
    old_init = threading.Thread.__init__
    threading.Thread.__init__ = functools.partialmethod(old_init, daemon=True)
    try:
        javabridge.start_vm(class_path=bioformats.JARS)
    finally:
        threading.Thread.__init__ = old_init
    atexit.register(javabridge.kill_vm)

    # The VM is running now. Mark initialization as done before the logger
    # setup below, so a failure there cannot lead to a second start_vm call.
    init_javabridge_bioformats.called = True

    root_logger_name = javabridge.get_static_field("org/slf4j/Logger", "ROOT_LOGGER_NAME", "Ljava/lang/String;")
    root_logger = javabridge.static_call("org/slf4j/LoggerFactory", "getLogger",
                                         "(Ljava/lang/String;)Lorg/slf4j/Logger;", root_logger_name)
    log_level = javabridge.get_static_field("ch/qos/logback/classic/Level", "WARN", "Lch/qos/logback/classic/Level;")
    javabridge.call(root_logger, "setLevel", "(Lch/qos/logback/classic/Level;)V", log_level)
|
|
424
|
+
|
|
425
|
+
|
|
395
426
|
def read_bioformats_image_with_reader(path, channel=0, region=(0, 0, 0, 0)):
|
|
396
427
|
"""
|
|
397
428
|
Using this function, you can read a specific region of a large image by giving the region bounding box (XYWH format)
|
|
@@ -402,14 +433,7 @@ def read_bioformats_image_with_reader(path, channel=0, region=(0, 0, 0, 0)):
|
|
|
402
433
|
:param region: The bounding box around the region of interest (XYWH format).
|
|
403
434
|
:return: The specified region of interest image (numpy array).
|
|
404
435
|
"""
|
|
405
|
-
|
|
406
|
-
|
|
407
|
-
rootLoggerName = javabridge.get_static_field("org/slf4j/Logger", "ROOT_LOGGER_NAME", "Ljava/lang/String;")
|
|
408
|
-
rootLogger = javabridge.static_call("org/slf4j/LoggerFactory", "getLogger",
|
|
409
|
-
"(Ljava/lang/String;)Lorg/slf4j/Logger;", rootLoggerName)
|
|
410
|
-
logLevel = javabridge.get_static_field("ch/qos/logback/classic/Level", "WARN", "Lch/qos/logback/classic/Level;")
|
|
411
|
-
javabridge.call(rootLogger, "setLevel", "(Lch/qos/logback/classic/Level;)V", logLevel)
|
|
412
|
-
|
|
436
|
+
init_javabridge_bioformats()
|
|
413
437
|
with bioformats.ImageReader(path) as reader:
|
|
414
438
|
return reader.read(t=channel, XYWH=region)
|
|
415
439
|
|
|
@@ -421,14 +445,7 @@ def get_information(filename):
|
|
|
421
445
|
:param filename: The address to the ome image.
|
|
422
446
|
:return: size_x, size_y, size_z, size_c, size_t, pixel_type
|
|
423
447
|
"""
|
|
424
|
-
|
|
425
|
-
|
|
426
|
-
rootLoggerName = javabridge.get_static_field("org/slf4j/Logger", "ROOT_LOGGER_NAME", "Ljava/lang/String;")
|
|
427
|
-
rootLogger = javabridge.static_call("org/slf4j/LoggerFactory", "getLogger",
|
|
428
|
-
"(Ljava/lang/String;)Lorg/slf4j/Logger;", rootLoggerName)
|
|
429
|
-
logLevel = javabridge.get_static_field("ch/qos/logback/classic/Level", "WARN", "Lch/qos/logback/classic/Level;")
|
|
430
|
-
javabridge.call(rootLogger, "setLevel", "(Lch/qos/logback/classic/Level;)V", logLevel)
|
|
431
|
-
|
|
448
|
+
init_javabridge_bioformats()
|
|
432
449
|
metadata = bioformats.get_omexml_metadata(filename)
|
|
433
450
|
omexml = bioformats.OMEXML(metadata)
|
|
434
451
|
size_x, size_y, size_z, size_c, size_t, pixel_type = omexml.image().Pixels.SizeX, \
|
|
@@ -441,6 +458,76 @@ def get_information(filename):
|
|
|
441
458
|
return size_x, size_y, size_z, size_c, size_t, pixel_type
|
|
442
459
|
|
|
443
460
|
|
|
461
|
+
class WSIReader:
    """
    Region reader for a whole-slide image file.

    Assumes the file is a single image (e.g., not a stacked
    OME TIFF) and will always return uint8 pixel type data.

    Two backends are used:
      * uint8 images are read directly with tifffile + zarr;
      * any other pixel type, or a uint8 file the tifffile path fails to
        open, is read through a bioformats ImageReader, and non-uint8 data
        is rescaled to uint8.

    Use as a context manager or call ``close()`` to release file handles.
    """

    def __init__(self, path):
        """
        Open ``path`` and read its size and pixel type from OME-XML metadata.

        :param path: Path to the image file.
        """
        init_javabridge_bioformats()
        metadata = bioformats.get_omexml_metadata(path)
        omexml = bioformats.OMEXML(metadata)
        pixels = omexml.image().Pixels

        self._path = path
        self._width = pixels.SizeX
        self._height = pixels.SizeY
        self._pixel_type = pixels.PixelType

        # Always define every handle so close() never hits a missing attribute.
        self._file = None
        self._tif = None
        self._zarr = None
        self._bfreader = None
        self._rescale = False

        if self._pixel_type == 'uint8':
            try:
                self._file = open(path, 'rb')
                self._tif = TiffFile(self._file)
                self._zarr = zarr.open(self._tif.pages[0].aszarr(), mode='r')
            except Exception:
                # Deliberate fallback: release whatever was opened and use
                # bioformats below instead.
                self._close_tiff()

        if self._tif is None:
            self._rescale = (self._pixel_type != 'uint8')
            self._bfreader = bioformats.ImageReader(path)

        # Defensive: ImageReader either returns a reader or raises, so this
        # is not expected to trigger.
        if self._tif is None and self._bfreader is None:
            raise Exception('Cannot read WSI file.')

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    def _close_tiff(self):
        """Close the tifffile backend (TiffFile and its file handle), if open."""
        tif, fh = self._tif, self._file
        self._tif = self._file = self._zarr = None
        try:
            if tif is not None:
                tif.close()
        finally:
            # Close the raw file even if TiffFile.close() raised.
            if fh is not None:
                fh.close()

    def close(self):
        """Release all open handles. Safe to call more than once."""
        try:
            self._close_tiff()
        finally:
            bfreader, self._bfreader = self._bfreader, None
            if bfreader is not None:
                bfreader.close()

    @property
    def width(self):
        """Image width in pixels (OME-XML SizeX)."""
        return self._width

    @property
    def height(self):
        """Image height in pixels (OME-XML SizeY)."""
        return self._height

    def read(self, xywh):
        """
        Read a rectangular region of the image.

        :param xywh: Region as (x, y, width, height) in pixels.
        :return: numpy array of the region's pixels. With the tifffile
                 backend this is ``zarr[y:y+h, x:x+w]``; assumes the page's
                 first two axes are (rows, columns) — TODO confirm for
                 planar-separate TIFFs.
        """
        if self._tif is not None:
            x, y, w, h = xywh
            return self._zarr[y:y+h, x:x+w]

        px = self._bfreader.read(XYWH=xywh, rescale=self._rescale)
        if self._rescale:
            # Rescaled bioformats data is floating point, presumably in
            # [0, 1] — verify for float pixel types; map it to uint8.
            px = (px * 255).astype(np.uint8)
        return px
|
|
529
|
+
|
|
530
|
+
|
|
444
531
|
|
|
445
532
|
|
|
446
533
|
def write_results_to_pickle_file(output_addr, results):
|
deepliif/util/checks.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
|
|
2
|
+
|
|
3
|
+
def check_weights(model, modalities_no, seg_weights, loss_weights_g, loss_weights_d):
    """
    Validate the segmentation and loss weight lists for a model.

    Each list must sum to 1 (within floating-point tolerance) and contain one
    weight per modality to be generated. For 'DeepLIIF' and 'DeepLIIFKD' the
    input counts as an additional modality, so modalities_no + 1 weights are
    expected.

    :param model: Model name.
    :param modalities_no: Number of modalities to be generated.
    :param seg_weights: Segmentation weights.
    :param loss_weights_g: Generator loss weights.
    :param loss_weights_d: Discriminator loss weights.
    :raises AssertionError: If a check fails (assert-based, so the checks
                            are skipped when Python runs with -O).
    """
    import math

    # Compare with a tolerance: e.g. sum([0.7, 0.1, 0.1, 0.1]) is
    # 0.9999999999999999 in floating point, which an exact `== 1` rejects.
    assert math.isclose(sum(seg_weights), 1), 'seg weights should add up to 1'
    assert math.isclose(sum(loss_weights_g), 1), 'loss weights g should add up to 1'
    assert math.isclose(sum(loss_weights_d), 1), 'loss weights d should add up to 1'

    if model in ['DeepLIIF', 'DeepLIIFKD']:
        # +1 because input becomes an additional modality used in generating the final segmentation
        expected_no = modalities_no + 1
    else:
        expected_no = modalities_no

    assert len(seg_weights) == expected_no, 'seg weights should have the same number of elements as number of modalities to be generated'
    assert len(loss_weights_g) == expected_no, 'loss weights g should have the same number of elements as number of modalities to be generated'
    assert len(loss_weights_d) == expected_no, 'loss weights d should have the same number of elements as number of modalities to be generated'
|
|
@@ -0,0 +1,255 @@
|
|
|
1
|
+
"""This module contains simple helper functions """
|
|
2
|
+
import os
|
|
3
|
+
from time import time
|
|
4
|
+
from functools import wraps
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
import numpy as np
|
|
8
|
+
from PIL import Image
|
|
9
|
+
import cv2
|
|
10
|
+
from skimage.metrics import structural_similarity as ssim
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def timeit(f):
    """
    Decorator that prints how long each call to ``f`` takes.

    After every call it prints ``"<function name> <elapsed seconds>"`` and
    returns ``f``'s result unchanged.
    """
    # Imported locally: later `import time` statements in this module rebind
    # the module-level name `time` to the *module*, so calling `time()` here
    # would raise "'module' object is not callable".
    from time import perf_counter

    @wraps(f)
    def wrap(*args, **kwargs):
        ts = perf_counter()
        result = f(*args, **kwargs)
        print(f'{f.__name__} {perf_counter() - ts}')
        return result

    return wrap
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def diagnose_network(net, name='network'):
    """Print ``name`` followed by the mean absolute gradient of ``net``.

    For every parameter that currently has a gradient, mean(|grad|) is
    computed; these values are averaged and printed. If no parameter has a
    gradient, 0.0 is printed.

    Parameters:
        net (torch network) -- Torch network
        name (str) -- the name of the network
    """
    per_param = [torch.mean(torch.abs(p.grad.data))
                 for p in net.parameters() if p.grad is not None]
    mean = 0.0
    if per_param:
        # Start from 0.0 so accumulation matches a running float sum.
        mean = sum(per_param, 0.0) / len(per_param)
    print(name)
    print(mean)
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def save_image(image_numpy, image_path, aspect_ratio=1.0):
    """Save a numpy image to the disk

    An array with more than 3 channels is treated as several images stacked
    along the channel axis. It is split into images of 3 channels (if the
    channel count is divisible by 3), else 2 (if divisible by 2), else 1.
    The pieces are placed side by side before saving.

    Parameters:
        image_numpy (numpy array) -- input numpy array, (H, W) or (H, W, C)
        image_path (str)          -- the path of the image
        aspect_ratio (float)      -- if not 1.0, the image is resized with
                                     bicubic interpolation before saving
    """
    nc = image_numpy.shape[2] if image_numpy.ndim == 3 else 1

    if nc > 3:
        if nc % 3 == 0:
            nc_img = 3
        elif nc % 2 == 0:
            nc_img = 2
        else:
            nc_img = 1
        no_img = nc // nc_img
        print(f'image (numpy) has {nc}>3 channels, inferred to have {no_img} images each with {nc_img} channel(s)')
        l_image_numpy = np.dsplit(image_numpy, [nc_img * i for i in range(1, no_img)])
        image_numpy = np.concatenate(l_image_numpy, axis=1)  # stack horizontally

    if image_numpy.ndim == 3 and image_numpy.shape[2] == 1:
        # PIL cannot build an image from an (H, W, 1) array (e.g. the
        # 1-channel split above). Drop the axis to save it as grayscale.
        image_numpy = image_numpy[:, :, 0]

    image_pil = Image.fromarray(image_numpy)
    h, w = image_numpy.shape[:2]

    # NOTE(review): PIL's resize takes (width, height), but h is passed first
    # here. Kept unchanged to preserve existing output — confirm intent.
    if aspect_ratio > 1.0:
        image_pil = image_pil.resize((h, int(w * aspect_ratio)), Image.BICUBIC)
    if aspect_ratio < 1.0:
        image_pil = image_pil.resize((int(h / aspect_ratio), w), Image.BICUBIC)
    image_pil.save(image_path)
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def print_numpy(x, val=True, shp=False):
    """Print summary statistics of a numpy array.

    Parameters:
        x (numpy array) -- the array to describe (cast to float64 first)
        val (bool) -- if True, print mean, min, max, median and std of all values
        shp (bool) -- if True, print the array's shape (before the statistics)
    """
    arr = x.astype(np.float64)
    if shp:
        print('shape,', arr.shape)
    if not val:
        return
    flat = arr.flatten()
    stats = (np.mean(flat), np.min(flat), np.max(flat), np.median(flat), np.std(flat))
    print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % stats)
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def mkdirs(paths):
    """Create directories that don't exist yet.

    Parameters:
        paths (str or str list) -- a single directory path, or a list of them
    """
    targets = paths if isinstance(paths, list) else [paths]
    for target in targets:
        mkdir(target)
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
def mkdir(path):
    """Create a directory (with any missing parents) unless the path already exists.

    Parameters:
        path (str) -- a single directory path
    """
    if os.path.exists(path):
        return
    os.makedirs(path, exist_ok=True)
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
import time
|
|
118
|
+
# Running totals (seconds) of time spent in each stage of tensor2im /
# tensor_to_pil; printed by print_times().
time_tensor = time_convert = time_transpose = 0
time_astype = time_topil = time_scale = 0
|
|
124
|
+
def print_times():
    """Print the accumulated per-stage timings in seconds, rounded to 0.1 s."""
    print('Time to get tensor:', round(time_tensor, 1), flush=True)
    print('Time to convert:', round(time_convert, 1), flush=True)
    print('Time to transpose:', round(time_transpose, 1), flush=True)
    print('Time to scale:', round(time_scale, 1), flush=True)
    # Fixed: this line previously reported time_transpose.
    print('Time for astype:', round(time_astype, 1), flush=True)
    print('Time to pil:', round(time_topil, 1), flush=True)
|
|
131
|
+
|
|
132
|
+
def tensor2im(input_image, imtype=np.uint8, size=(256, 256)):
    """Converts a Tensor array into a numpy image array.

    For a torch tensor, the first item of the batch is processed as follows:
      1. moved to the CPU as float;
      2. tiled from 1 to 3 channels if it has a single channel;
      3. transposed from channel-first to channel-last;
      4. optionally resized;
      5. mapped with (x + 1) / 2 * 255, i.e. [-1, 1] to [0, 255];
      6. cast to ``imtype``.
    The time spent in each stage is added to the module-level ``time_*``
    counters.

    Parameters:
        input_image (tensor) -- the input image tensor array
        imtype (type)        -- the desired type of the converted numpy array
        size (tuple or None) -- (width, height) given to cv2.resize; defaults
                                to the previously hard-coded (256, 256).
                                Pass None to keep the original spatial size.

    Returns:
        The converted array. A numpy array input is only cast to ``imtype``;
        any other non-tensor input is returned unchanged.
    """
    global time_tensor, time_convert, time_transpose, time_scale, time_astype

    if isinstance(input_image, np.ndarray):
        # if it is a numpy array, only cast it
        return input_image.astype(imtype)
    if not isinstance(input_image, torch.Tensor):
        return input_image

    ts = time.time()
    image_tensor = input_image.data  # get the data from a variable
    time_tensor += time.time() - ts

    ts = time.time()
    image_numpy = image_tensor[0].cpu().float().numpy()  # convert it into a numpy array
    time_convert += time.time() - ts

    if image_numpy.shape[0] == 1:  # grayscale to RGB
        image_numpy = np.tile(image_numpy, (3, 1, 1))

    ts = time.time()
    image_numpy = np.transpose(image_numpy, (1, 2, 0))
    time_transpose += time.time() - ts

    ts = time.time()
    if size is not None:
        image_numpy = cv2.resize(image_numpy, tuple(size), interpolation=cv2.INTER_AREA)
    image_numpy = (image_numpy + 1) / 2.0 * 255.0
    time_scale += time.time() - ts

    ts = time.time()
    image_numpy = image_numpy.astype(imtype)
    time_astype += time.time() - ts
    return image_numpy
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
def tensor_to_pil(t):
    """
    Convert an image tensor to a PIL image via ``tensor2im``.

    The time spent in ``Image.fromarray`` is added to the module-level
    ``time_topil`` counter.
    """
    global time_topil
    arr = tensor2im(t)
    start = time.time()
    im = Image.fromarray(arr)
    time_topil += time.time() - start
    return im
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
def calculate_ssim(img1, img2):
    """Return the SSIM of img1 vs img2, using img2's value span (max - min) as data_range."""
    data_range = img2.max() - img2.min()
    return ssim(img1, img2, data_range=data_range)
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
def check_multi_scale(img1, img2, tile_sizes=range(100, 1000, 100)):
    """
    Find the tile size at which ``img2`` best matches ``img1``.

    For each candidate size, ``img2`` is cut into non-overlapping square
    tiles (partial tiles at the right/bottom edges are skipped). Each tile is
    resized to ``img1``'s size and compared with ``img1`` by SSIM. The size
    with the highest mean SSIM is returned.

    :param img1: Reference image (anything ``np.array`` accepts).
    :param img2: Image to tile (anything ``np.array`` accepts).
    :param tile_sizes: Candidate tile edge lengths in pixels.
    :return: The best tile size, or 512 if no candidate gives a positive mean SSIM.
    """
    img1 = np.array(img1)
    img2 = np.array(img2)
    # cv2.resize takes the destination size as (width, height), i.e.
    # (columns, rows); passing (rows, columns) breaks non-square img1.
    target_size = (img1.shape[1], img1.shape[0])

    best_size, best_ssim = 512, 0
    for tile_size in tile_sizes:
        image_ssim = 0
        tile_no = 0
        # Only start positions whose full tile fits inside img2.
        for i in range(0, img2.shape[0] - tile_size + 1, tile_size):
            for j in range(0, img2.shape[1] - tile_size + 1, tile_size):
                tile = img2[i: i + tile_size, j: j + tile_size]
                tile = cv2.resize(tile, target_size)
                image_ssim += calculate_ssim(img1, tile)
                tile_no += 1
        if tile_no > 0:
            image_ssim /= tile_no
        if best_ssim < image_ssim:
            best_size, best_ssim = tile_size, image_ssim
    return best_size
|
|
214
|
+
|
|
215
|
+
|
|
216
|
+
import subprocess
|
|
217
|
+
import os
|
|
218
|
+
from threading import Thread , Timer
|
|
219
|
+
import sched, time
|
|
220
|
+
|
|
221
|
+
# modified from https://stackoverflow.com/questions/67707828/how-to-get-every-seconds-gpu-usage-in-python
|
|
222
|
+
def get_gpu_memory(gpu_id=0):
    """
    Return the used memory of GPU ``gpu_id`` as an int, as reported by
    ``nvidia-smi --query-gpu=memory.used --format=csv``.

    Raises RuntimeError if nvidia-smi exits with a non-zero status.
    """
    command = "nvidia-smi --query-gpu=memory.used --format=csv".split()
    try:
        raw = subprocess.check_output(command, stderr=subprocess.STDOUT)
    except subprocess.CalledProcessError as err:
        raise RuntimeError("command '{}' return with error (code {}): {}".format(err.cmd, err.returncode, err.output))
    # Drop the empty string after the final newline, then the CSV header row.
    rows = raw.decode('ascii').split('\n')[:-1][1:]
    used = [int(row.split()[0]) for row in rows]
    return used[gpu_id]
|
|
237
|
+
|
|
238
|
+
class HardwareStatus():
    """
    Sample GPU memory usage once per second on a background Timer thread.

    Call ``get_status_every_sec(gpu_id)`` to start sampling into
    ``gpu_mem`` and ``stop_timer()`` to stop.
    """
    def __init__(self):
        self.gpu_mem = []  # values returned by get_gpu_memory, in sampling order
        self.timer = None  # the pending threading.Timer, if sampling started

    def get_status_every_sec(self, gpu_id=0):
        """
        This function calls itself every 1 sec and appends the gpu_memory
        of GPU ``gpu_id``.
        """
        # Forward gpu_id so every rescheduled sample queries the same GPU
        # (previously the repeats silently fell back to GPU 0).
        self.timer = Timer(1.0, self.get_status_every_sec, args=(gpu_id,))
        self.timer.start()
        self.gpu_mem.append(get_gpu_memory(gpu_id))

    def stop_timer(self):
        """Cancel the pending sample; a no-op if sampling never started."""
        if self.timer is not None:
            self.timer.cancel()
|
|
254
|
+
|
|
255
|
+
|
deepliif/util/util.py
CHANGED
|
@@ -163,3 +163,45 @@ def check_multi_scale(img1, img2):
|
|
|
163
163
|
if max_ssim[1] < image_ssim:
|
|
164
164
|
max_ssim = (tile_size, image_ssim)
|
|
165
165
|
return max_ssim[0]
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
import subprocess
|
|
169
|
+
import os
|
|
170
|
+
from threading import Thread , Timer
|
|
171
|
+
import sched, time
|
|
172
|
+
|
|
173
|
+
# modified from https://stackoverflow.com/questions/67707828/how-to-get-every-seconds-gpu-usage-in-python
|
|
174
|
+
def get_gpu_memory(gpu_id=0):
    """
    Query nvidia-smi for per-GPU used memory and return the value for ``gpu_id``.

    :raises RuntimeError: if ``nvidia-smi`` exits with a non-zero status.
    """
    query = "nvidia-smi --query-gpu=memory.used --format=csv"
    try:
        output = subprocess.check_output(query.split(), stderr=subprocess.STDOUT)
    except subprocess.CalledProcessError as exc:
        raise RuntimeError("command '{}' return with error (code {}): {}".format(exc.cmd, exc.returncode, exc.output))
    lines = output.decode('ascii').split('\n')[:-1]
    memory_used = []
    for record in lines[1:]:  # skip the CSV header line
        memory_used.append(int(record.split()[0]))
    return memory_used[gpu_id]
|
|
189
|
+
|
|
190
|
+
class HardwareStatus():
    """
    Sample GPU memory usage once per second on a background Timer thread.

    Call ``get_status_every_sec(gpu_id)`` to start sampling into
    ``gpu_mem`` and ``stop_timer()`` to stop.
    """
    def __init__(self):
        self.gpu_mem = []  # values returned by get_gpu_memory, in sampling order
        self.timer = None  # the pending threading.Timer, if sampling started

    def get_status_every_sec(self, gpu_id=0):
        """
        This function calls itself every 1 sec and appends the gpu_memory
        of GPU ``gpu_id``.
        """
        # Forward gpu_id so every rescheduled sample queries the same GPU
        # (previously the repeats silently fell back to GPU 0).
        self.timer = Timer(1.0, self.get_status_every_sec, args=(gpu_id,))
        self.timer.start()
        self.gpu_mem.append(get_gpu_memory(gpu_id))

    def stop_timer(self):
        """Cancel the pending sample; a no-op if sampling never started."""
        if self.timer is not None:
            self.timer.cancel()
|
|
206
|
+
|
|
207
|
+
|