deepliif 1.1.14__py3-none-any.whl → 1.1.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -219,13 +219,14 @@ def compute_overlap(img_size, tile_size):
219
219
  return tile_size // 4
220
220
 
221
221
 
222
- def run_torchserve(img, model_path=None, nets=None, eager_mode=False, opt=None, seg_only=False, use_dask=True, output_tensor=False):
222
+ def run_torchserve(img, model_path=None, nets=None, eager_mode=False, opt=None, seg_only=False, seg_weights=None, use_dask=True, output_tensor=False):
223
223
  """
224
224
  eager_mode: not used in this function; put in place to be consistent with run_dask
225
225
  so that run_wrapper() could call either this function or run_dask with
226
226
  same syntax
227
227
  opt: same as eager_mode
228
228
  seg_only: same as eager_mode
229
+ seg_weights: same as eager_mode
229
230
  nets: same as eager_mode
230
231
  """
231
232
  buffer = BytesIO()
@@ -245,7 +246,7 @@ def run_torchserve(img, model_path=None, nets=None, eager_mode=False, opt=None,
245
246
  return {k: tensor_to_pil(deserialize_tensor(v)) for k, v in res.json().items()}
246
247
 
247
248
 
248
- def run_dask(img, model_path=None, nets=None, eager_mode=False, opt=None, seg_only=False, use_dask=True, output_tensor=False):
249
+ def run_dask(img, model_path=None, nets=None, eager_mode=False, opt=None, seg_only=False, seg_weights=None, use_dask=True, output_tensor=False):
249
250
  """
250
251
  Provide either the model path or the networks object.
251
252
 
@@ -280,20 +281,22 @@ def run_dask(img, model_path=None, nets=None, eager_mode=False, opt=None, seg_on
280
281
  return model(input.to(next(model.parameters()).device))
281
282
 
282
283
  if opt.model in ['DeepLIIF','DeepLIIFKD']:
283
- #weights = {
284
- # 'G51': 0.25, # IHC
285
- # 'G52': 0.25, # Hema
286
- # 'G53': 0.25, # DAPI
287
- # 'G54': 0.00, # Lap2
288
- # 'G55': 0.25, # Marker
289
- #}
290
- weights = {
291
- 'G51': 0.5, # IHC
292
- 'G52': 0.0, # Hema
293
- 'G53': 0.0, # DAPI
294
- 'G54': 0.0, # Lap2
295
- 'G55': 0.5, # Marker
296
- }
284
+ if seg_weights is None:
285
+ weights = {
286
+ 'G51': 0.25, # IHC
287
+ 'G52': 0.25, # Hema
288
+ 'G53': 0.25, # DAPI
289
+ 'G54': 0.00, # Lap2
290
+ 'G55': 0.25, # Marker
291
+ }
292
+ else:
293
+ weights = {
294
+ 'G51': seg_weights[0], # IHC
295
+ 'G52': seg_weights[1], # Hema
296
+ 'G53': seg_weights[2], # DAPI
297
+ 'G54': seg_weights[3], # Lap2
298
+ 'G55': seg_weights[4], # Marker
299
+ }
297
300
 
298
301
  seg_map = {'G1': 'G52', 'G2': 'G53', 'G3': 'G54', 'G4': 'G55'}
299
302
  if seg_only:
@@ -357,12 +360,12 @@ def run_dask(img, model_path=None, nets=None, eager_mode=False, opt=None, seg_on
357
360
  def is_empty(tile):
358
361
  thresh = 15
359
362
  if isinstance(tile, list): # for pair of tiles, only mark it as empty / no need for prediction if ALL tiles are empty
360
- return all([True if np.max(image_variance_rgb(t)) < thresh else False for t in tile])
363
+ return all([True if image_variance_gray(t) < thresh else False for t in tile])
361
364
  else:
362
- return True if np.max(image_variance_rgb(tile)) < thresh else False
365
+ return True if image_variance_gray(tile) < thresh else False
363
366
 
364
367
 
365
- def run_wrapper(tile, run_fn, model_path=None, nets=None, eager_mode=False, opt=None, seg_only=False, use_dask=True, output_tensor=False):
368
+ def run_wrapper(tile, run_fn, model_path=None, nets=None, eager_mode=False, opt=None, seg_only=False, seg_weights=None, use_dask=True, output_tensor=False):
366
369
  if opt.model in ['DeepLIIF','DeepLIIFKD']:
367
370
  if is_empty(tile):
368
371
  if seg_only:
@@ -384,7 +387,7 @@ def run_wrapper(tile, run_fn, model_path=None, nets=None, eager_mode=False, opt=
384
387
  'G55': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)),
385
388
  }
386
389
  else:
387
- return run_fn(tile, model_path, None, eager_mode, opt, seg_only)
390
+ return run_fn(tile, model_path, None, eager_mode, opt, seg_only, seg_weights)
388
391
  elif opt.model in ['DeepLIIFExt', 'SDG']:
389
392
  if is_empty(tile):
390
393
  res = {'G_' + str(i): Image.new(mode='RGB', size=(512, 512)) for i in range(1, opt.modalities_no + 1)}
@@ -405,7 +408,7 @@ def run_wrapper(tile, run_fn, model_path=None, nets=None, eager_mode=False, opt=
405
408
 
406
409
  def inference(img, tile_size, overlap_size, model_path, use_torchserve=False,
407
410
  eager_mode=False, color_dapi=False, color_marker=False, opt=None,
408
- return_seg_intermediate=False, seg_only=False, opt_args={}):
411
+ return_seg_intermediate=False, seg_only=False, seg_weights=None, opt_args={}):
409
412
  """
410
413
  opt_args: a dictionary of key and values to add/overwrite to opt
411
414
  """
@@ -430,7 +433,7 @@ def inference(img, tile_size, overlap_size, model_path, use_torchserve=False,
430
433
 
431
434
  tiler = InferenceTiler(orig, tile_size, overlap_size)
432
435
  for tile in tiler:
433
- tiler.stitch(run_wrapper(tile, run_fn, model_path, None, eager_mode, opt, seg_only))
436
+ tiler.stitch(run_wrapper(tile, run_fn, model_path, None, eager_mode, opt, seg_only, seg_weights))
434
437
 
435
438
  results = tiler.results()
436
439
 
@@ -515,7 +518,7 @@ def postprocess(orig, images, tile_size, model, seg_thresh=150, size_thresh='def
515
518
 
516
519
  def infer_modalities(img, tile_size, model_dir, eager_mode=False,
517
520
  color_dapi=False, color_marker=False, opt=None,
518
- return_seg_intermediate=False, seg_only=False):
521
+ return_seg_intermediate=False, seg_only=False, seg_weights=None):
519
522
  """
520
523
  This function is used to infer modalities for the given image using a trained model.
521
524
  :param img: The input image.
@@ -543,7 +546,8 @@ def infer_modalities(img, tile_size, model_dir, eager_mode=False,
543
546
  color_marker=color_marker,
544
547
  opt=opt,
545
548
  return_seg_intermediate=return_seg_intermediate,
546
- seg_only=seg_only
549
+ seg_only=seg_only,
550
+ seg_weights=seg_weights,
547
551
  )
548
552
 
549
553
  if not hasattr(opt,'seg_gen') or (hasattr(opt,'seg_gen') and opt.seg_gen): # the first condition accounts for old settings of deepliif; the second refers to deepliifext models
@@ -558,7 +562,7 @@ def infer_modalities(img, tile_size, model_dir, eager_mode=False,
558
562
  return images, None
559
563
 
560
564
 
561
- def infer_results_for_wsi(input_dir, filename, output_dir, model_dir, tile_size, region_size=20000, color_dapi=False, color_marker=False, seg_intermediate=False, seg_only=False):
565
+ def infer_results_for_wsi(input_dir, filename, output_dir, model_dir, tile_size, region_size=20000, color_dapi=False, color_marker=False, seg_intermediate=False, seg_only=False, seg_weights=None):
562
566
  """
563
567
  This function infers modalities and segmentation mask for the given WSI image. It
564
568
 
@@ -592,7 +596,7 @@ def infer_results_for_wsi(input_dir, filename, output_dir, model_dir, tile_size,
592
596
  region = reader.read(XYWH=region_XYWH, rescale=rescale)
593
597
  img = Image.fromarray((region * 255).astype(np.uint8)) if rescale else Image.fromarray(region)
594
598
 
595
- region_modalities, region_scoring = infer_modalities(img, tile_size, model_dir, color_dapi=color_dapi, color_marker=color_marker, return_seg_intermediate=seg_intermediate, seg_only=seg_only)
599
+ region_modalities, region_scoring = infer_modalities(img, tile_size, model_dir, color_dapi=color_dapi, color_marker=color_marker, return_seg_intermediate=seg_intermediate, seg_only=seg_only, seg_weights=seg_weights)
596
600
  if region_scoring is not None:
597
601
  if scoring is None:
598
602
  scoring = {
@@ -624,8 +628,6 @@ def infer_results_for_wsi(input_dir, filename, output_dir, model_dir, tile_size,
624
628
  with open(os.path.join(results_dir, f'{basename}.json'), 'w') as f:
625
629
  json.dump(scoring, f, indent=2)
626
630
 
627
- javabridge.kill_vm()
628
-
629
631
 
630
632
  def get_wsi_resolution(filename):
631
633
  """
@@ -633,9 +635,6 @@ def get_wsi_resolution(filename):
633
635
  the corresponding tile size to use by default for DeepLIIF.
634
636
  If it cannot be found, return (None, None) instead.
635
637
 
636
- Note: This will start the javabridge VM, but not kill it.
637
- It must be killed elsewhere.
638
-
639
638
  Parameters
640
639
  ----------
641
640
  filename : str
@@ -649,11 +648,10 @@ def get_wsi_resolution(filename):
649
648
  Corresponding tile size for DeepLIIF.
650
649
  """
651
650
 
652
- # make sure javabridge is already set up from with call to get_information()
653
- size_x, size_y, size_z, size_c, size_t, pixel_type = get_information(filename)
651
+ init_javabridge_bioformats()
652
+ metadata = bioformats.get_omexml_metadata(filename)
654
653
 
655
654
  mag = None
656
- metadata = bioformats.get_omexml_metadata(filename)
657
655
  try:
658
656
  omexml = bioformats.OMEXML(metadata)
659
657
  mag = omexml.instrument().Objective.NominalMagnification
@@ -686,7 +684,7 @@ def get_wsi_resolution(filename):
686
684
  return None, None
687
685
 
688
686
 
689
- def infer_cells_for_wsi(filename, model_dir, tile_size, region_size=20000, version=3, print_log=False):
687
+ def infer_cells_for_wsi(filename, model_dir, tile_size, region_size=20000, version=3, print_log=False, seg_weights=None):
690
688
  """
691
689
  Perform inference on a slide and get the results individual cell data.
692
690
 
@@ -704,6 +702,9 @@ def infer_cells_for_wsi(filename, model_dir, tile_size, region_size=20000, versi
704
702
  Version of cell data to return (3 or 4).
705
703
  print_log : bool
706
704
  Whether or not to print updates while processing.
705
+ seg_weights : list | tuple
706
+ Optional five weights to use for creating the combined Seg image.
707
+ If None, then the default weights are used.
707
708
 
708
709
  Returns
709
710
  -------
@@ -717,22 +718,21 @@ def infer_cells_for_wsi(filename, model_dir, tile_size, region_size=20000, versi
717
718
 
718
719
  resolution = '40x' if tile_size > 384 else ('20x' if tile_size > 192 else '10x')
719
720
 
720
- size_x, size_y, size_z, size_c, size_t, pixel_type = get_information(filename)
721
- rescale = (pixel_type != 'uint8')
722
- print_info('Info:', size_x, size_y, size_z, size_c, size_t, pixel_type)
723
-
724
- num_regions_x = math.ceil(size_x / region_size)
725
- num_regions_y = math.ceil(size_y / region_size)
726
- stride_x = math.ceil(size_x / num_regions_x)
727
- stride_y = math.ceil(size_y / num_regions_y)
728
- print_info('Strides:', stride_x, stride_y)
729
-
730
721
  data = None
731
722
  default_marker_thresh, count_marker_thresh = 0, 0
732
723
  default_size_thresh, count_size_thresh = 0, 0
733
724
 
734
- # javabridge already set up from previous call to get_information()
735
- with bioformats.ImageReader(filename) as reader:
725
+ with WSIReader(filename) as reader:
726
+ size_x = reader.width
727
+ size_y = reader.height
728
+ print_info('Info:', size_x, size_y)
729
+
730
+ num_regions_x = math.ceil(size_x / region_size)
731
+ num_regions_y = math.ceil(size_y / region_size)
732
+ stride_x = math.ceil(size_x / num_regions_x)
733
+ stride_y = math.ceil(size_y / num_regions_y)
734
+ print_info('Strides:', stride_x, stride_y)
735
+
736
736
  start_x, start_y = 0, 0
737
737
 
738
738
  while start_y < size_y:
@@ -740,9 +740,9 @@ def infer_cells_for_wsi(filename, model_dir, tile_size, region_size=20000, versi
740
740
  region_XYWH = (start_x, start_y, min(stride_x, size_x-start_x), min(stride_y, size_y-start_y))
741
741
  print_info('Region:', region_XYWH)
742
742
 
743
- region = reader.read(XYWH=region_XYWH, rescale=rescale)
743
+ region = reader.read(region_XYWH)
744
744
  print_info(region.shape, region.dtype)
745
- img = Image.fromarray((region * 255).astype(np.uint8)) if rescale else Image.fromarray(region)
745
+ img = Image.fromarray(region)
746
746
  print_info(img.size, img.mode)
747
747
  del region
748
748
 
@@ -757,6 +757,7 @@ def infer_cells_for_wsi(filename, model_dir, tile_size, region_size=20000, versi
757
757
  opt=None,
758
758
  return_seg_intermediate=False,
759
759
  seg_only=True,
760
+ seg_weights=seg_weights,
760
761
  )
761
762
  del img
762
763
 
@@ -795,8 +796,6 @@ def infer_cells_for_wsi(filename, model_dir, tile_size, region_size=20000, versi
795
796
  start_x = 0
796
797
  start_y += stride_y
797
798
 
798
- javabridge.kill_vm()
799
-
800
799
  if count_marker_thresh == 0:
801
800
  count_marker_thresh = 1
802
801
  if count_size_thresh == 0:
@@ -804,6 +803,10 @@ def infer_cells_for_wsi(filename, model_dir, tile_size, region_size=20000, versi
804
803
  data['settings']['default_marker_thresh'] = round(default_marker_thresh / count_marker_thresh)
805
804
  data['settings']['default_size_thresh'] = round(default_size_thresh / count_size_thresh)
806
805
 
806
+ data['settings']['tile_size'] = tile_size
807
+ data['settings']['region_size'] = region_size
808
+ data['settings']['seg_weights'] = seg_weights
809
+
807
810
  try:
808
811
  data['deepliifVersion'] = importlib.metadata.version('deepliif')
809
812
  except Exception as e:
deepliif/util/__init__.py CHANGED
@@ -2,6 +2,10 @@
2
2
  import os
3
3
  import collections
4
4
 
5
+ import atexit
6
+ import functools
7
+ import threading
8
+
5
9
  import torch
6
10
  import numpy as np
7
11
  from PIL import Image, ImageOps
@@ -22,6 +26,9 @@ import javabridge
22
26
  import bioformats.omexml as ome
23
27
  import tifffile as tf
24
28
 
29
+ from tifffile import TiffFile
30
+ import zarr
31
+
25
32
 
26
33
  excluding_names = ['Hema', 'DAPI', 'DAPILap2', 'Ki67', 'Seg', 'Marked', 'SegRefined', 'SegOverlaid', 'Marker', 'Lap2']
27
34
  # Image extensions to consider
@@ -392,6 +399,30 @@ def image_variance_rgb(img):
392
399
  return var
393
400
 
394
401
 
402
def init_javabridge_bioformats():
    """
    Initialize javabridge for use with bioformats.
    Run as daemon so no need to explicitly call kill_vm.
    This function will only run once; repeat calls do nothing.
    """

    if not hasattr(init_javabridge_bioformats, 'called'):
        # Temporarily patch Thread.__init__ so every thread the JVM startup
        # spawns is a daemon thread; otherwise the JVM can keep the Python
        # process alive at interpreter exit.
        # https://github.com/LeeKamentsky/python-javabridge/issues/155
        old_init = threading.Thread.__init__
        threading.Thread.__init__ = functools.partialmethod(old_init, daemon=True)
        javabridge.start_vm(class_path=bioformats.JARS)
        threading.Thread.__init__ = old_init
        # Also register an orderly JVM shutdown at interpreter exit.
        atexit.register(javabridge.kill_vm)

        # Raise the root logback logger to WARN to silence bioformats'
        # very chatty per-read DEBUG/INFO output.
        rootLoggerName = javabridge.get_static_field("org/slf4j/Logger", "ROOT_LOGGER_NAME", "Ljava/lang/String;")
        rootLogger = javabridge.static_call("org/slf4j/LoggerFactory", "getLogger",
                                            "(Ljava/lang/String;)Lorg/slf4j/Logger;", rootLoggerName)
        logLevel = javabridge.get_static_field("ch/qos/logback/classic/Level", "WARN", "Lch/qos/logback/classic/Level;")
        javabridge.call(rootLogger, "setLevel", "(Lch/qos/logback/classic/Level;)V", logLevel)

        # Function attribute doubles as the "already initialized" flag.
        init_javabridge_bioformats.called = True
424
+
425
+
395
426
  def read_bioformats_image_with_reader(path, channel=0, region=(0, 0, 0, 0)):
396
427
  """
397
428
  Using this function, you can read a specific region of a large image by giving the region bounding box (XYWH format)
@@ -402,14 +433,7 @@ def read_bioformats_image_with_reader(path, channel=0, region=(0, 0, 0, 0)):
402
433
  :param region: The bounding box around the region of interest (XYWH format).
403
434
  :return: The specified region of interest image (numpy array).
404
435
  """
405
- javabridge.start_vm(class_path=bioformats.JARS)
406
-
407
- rootLoggerName = javabridge.get_static_field("org/slf4j/Logger", "ROOT_LOGGER_NAME", "Ljava/lang/String;")
408
- rootLogger = javabridge.static_call("org/slf4j/LoggerFactory", "getLogger",
409
- "(Ljava/lang/String;)Lorg/slf4j/Logger;", rootLoggerName)
410
- logLevel = javabridge.get_static_field("ch/qos/logback/classic/Level", "WARN", "Lch/qos/logback/classic/Level;")
411
- javabridge.call(rootLogger, "setLevel", "(Lch/qos/logback/classic/Level;)V", logLevel)
412
-
436
+ init_javabridge_bioformats()
413
437
  with bioformats.ImageReader(path) as reader:
414
438
  return reader.read(t=channel, XYWH=region)
415
439
 
@@ -421,14 +445,7 @@ def get_information(filename):
421
445
  :param filename: The address to the ome image.
422
446
  :return: size_x, size_y, size_z, size_c, size_t, pixel_type
423
447
  """
424
- javabridge.start_vm(class_path=bioformats.JARS)
425
-
426
- rootLoggerName = javabridge.get_static_field("org/slf4j/Logger", "ROOT_LOGGER_NAME", "Ljava/lang/String;")
427
- rootLogger = javabridge.static_call("org/slf4j/LoggerFactory", "getLogger",
428
- "(Ljava/lang/String;)Lorg/slf4j/Logger;", rootLoggerName)
429
- logLevel = javabridge.get_static_field("ch/qos/logback/classic/Level", "WARN", "Lch/qos/logback/classic/Level;")
430
- javabridge.call(rootLogger, "setLevel", "(Lch/qos/logback/classic/Level;)V", logLevel)
431
-
448
+ init_javabridge_bioformats()
432
449
  metadata = bioformats.get_omexml_metadata(filename)
433
450
  omexml = bioformats.OMEXML(metadata)
434
451
  size_x, size_y, size_z, size_c, size_t, pixel_type = omexml.image().Pixels.SizeX, \
@@ -441,6 +458,76 @@ def get_information(filename):
441
458
  return size_x, size_y, size_z, size_c, size_t, pixel_type
442
459
 
443
460
 
461
class WSIReader:
    """
    Assumes the file is a single image (e.g., not a stacked
    OME TIFF) and will always return uint8 pixel type data.

    Reading strategy: for uint8 TIFFs the file is opened with
    tifffile + zarr for fast random-access region reads; anything
    else falls back to a bioformats reader, with non-uint8 data
    rescaled to uint8 on read. Use as a context manager (or call
    close()) to release the underlying file handles.
    """

    def __init__(self, path):
        # Image dimensions and pixel type always come from the OME-XML metadata.
        init_javabridge_bioformats()
        metadata = bioformats.get_omexml_metadata(path)
        omexml = bioformats.OMEXML(metadata)

        self._path = path
        self._width = omexml.image().Pixels.SizeX
        self._height = omexml.image().Pixels.SizeY
        self._pixel_type = omexml.image().Pixels.PixelType

        # Fast path: try opening uint8 images directly with tifffile/zarr.
        self._tif = None
        if self._pixel_type == 'uint8':
            try:
                self._file = None
                self._file = open(path, 'rb')
                self._tif = TiffFile(self._file)
                self._zarr = zarr.open(self._tif.pages[0].aszarr(), mode='r')
            except Exception as e:
                # Not readable via tifffile; clean up and use the fallback below.
                if self._tif is not None:
                    self._tif.close()
                    self._tif = None
                if self._file is not None:
                    self._file.close()

        # Fallback path: bioformats reader.
        self._bfreader = None
        if self._tif is None:
            # Non-uint8 pixel data is read with rescale=True and mapped
            # back to uint8 in read().
            self._rescale = (self._pixel_type != 'uint8')
            self._bfreader = bioformats.ImageReader(path)

        if self._tif is None and self._bfreader is None:
            raise Exception('Cannot read WSI file.')

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    def close(self):
        # Release whichever backend was opened in __init__.
        if self._tif is not None:
            self._tif.close()
            self._file.close()
        if self._bfreader is not None:
            self._bfreader.close()

    @property
    def width(self):
        # Full image width in pixels (from OME-XML SizeX).
        return self._width

    @property
    def height(self):
        # Full image height in pixels (from OME-XML SizeY).
        return self._height

    def read(self, xywh):
        """
        Read the region given by an (x, y, width, height) tuple
        and return it as a numpy array (uint8).
        """
        if self._tif is not None:
            x, y, w, h = xywh
            # zarr slicing is row-major: rows (y) first, then columns (x).
            return self._zarr[y:y+h, x:x+w]

        px = self._bfreader.read(XYWH=xywh, rescale=self._rescale)
        if self._rescale:
            # rescaled reads come back as floats; map to 0-255 uint8
            px = (px * 255).astype(np.uint8)
        return px
529
+
530
+
444
531
 
445
532
 
446
533
  def write_results_to_pickle_file(output_addr, results):
@@ -0,0 +1,255 @@
1
+ """This module contains simple helper functions """
2
+ import os
3
+ from time import time
4
+ from functools import wraps
5
+
6
+ import torch
7
+ import numpy as np
8
+ from PIL import Image
9
+ import cv2
10
+ from skimage.metrics import structural_similarity as ssim
11
+
12
+
13
def timeit(f):
    """Decorator that prints the wall-clock runtime of every call to *f*."""
    @wraps(f)
    def timed(*args, **kwargs):
        started = time()
        outcome = f(*args, **kwargs)
        print(f'{f.__name__} {time() - started}')
        return outcome

    return timed
23
+
24
+
25
def diagnose_network(net, name='network'):
    """Calculate and print the mean of average absolute(gradients)

    Parameters:
        net (torch network) -- Torch network
        name (str) -- the name of the network
    """
    # Collect the mean absolute gradient of every parameter that has one.
    grad_means = [torch.mean(torch.abs(p.grad.data)) for p in net.parameters() if p.grad is not None]
    mean = sum(grad_means) / len(grad_means) if grad_means else 0.0
    print(name)
    print(mean)
42
+
43
+
44
def save_image(image_numpy, image_path, aspect_ratio=1.0):
    """Save a numpy image to the disk

    Parameters:
        image_numpy (numpy array) -- input numpy array
        image_path (str) -- the path of the image

    Arrays with more than 3 channels are interpreted as several images
    stacked along the channel axis; they are split apart and laid out
    side by side before saving. aspect_ratio != 1.0 resizes the output.
    """
    x, y, nc = image_numpy.shape

    if nc > 3:
        # Infer the channel count of each sub-image: prefer 3, then 2, then 1.
        if nc % 3 == 0:
            nc_img = 3
        elif nc % 2 == 0:
            nc_img = 2
        else:
            nc_img = 1
        no_img = nc // nc_img
        print(f'image (numpy) has {nc}>3 channels, inferred to have {no_img} images each with {nc_img} channel(s)')
        split_points = [nc_img * i for i in range(1, no_img)]
        image_numpy = np.concatenate(np.dsplit(image_numpy, split_points), axis=1)  # stack horizontally

    image_pil = Image.fromarray(image_numpy)
    h, w, _ = image_numpy.shape

    # NOTE(review): resize() takes (width, height); (h, ...) here looks
    # swapped but is kept as-is to preserve the original behavior.
    if aspect_ratio > 1.0:
        image_pil = image_pil.resize((h, int(w * aspect_ratio)), Image.BICUBIC)
    if aspect_ratio < 1.0:
        image_pil = image_pil.resize((int(h / aspect_ratio), w), Image.BICUBIC)
    image_pil.save(image_path)
76
+
77
+
78
def print_numpy(x, val=True, shp=False):
    """Print summary statistics (and optionally the shape) of a numpy array.

    Parameters:
        val (bool) -- if print the values of the numpy array
        shp (bool) -- if print the shape of the numpy array
    """
    arr = x.astype(np.float64)
    if shp:
        print('shape,', arr.shape)
    if val:
        flat = arr.flatten()
        stats = (np.mean(flat), np.min(flat), np.max(flat), np.median(flat), np.std(flat))
        print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % stats)
92
+
93
+
94
def mkdirs(paths):
    """Create every directory in *paths* (a single str or a list of str).

    Parameters:
        paths (str list) -- a list of directory paths
    """
    # Normalize to a list so a single path and a list share one code path.
    path_list = paths if isinstance(paths, list) and not isinstance(paths, str) else [paths]
    for path in path_list:
        mkdir(path)


def mkdir(path):
    """Create a single directory (including parents) if it does not exist.

    Parameters:
        path (str) -- a single directory path
    """
    if not os.path.exists(path):
        os.makedirs(path, exist_ok=True)
115
+
116
+
117
import time

# Module-level accumulators for coarse profiling of tensor -> PIL conversion
# (updated by tensor2im / tensor_to_pil below).
time_tensor = 0     # time spent reading .data off the input tensor
time_convert = 0    # time spent moving to cpu/float/numpy
time_transpose = 0  # time spent transposing CHW -> HWC
time_astype = 0     # time spent casting to the output dtype
time_topil = 0      # time spent building the PIL image
time_scale = 0      # time spent resizing + rescaling to [0, 255]


def print_times():
    """Print the accumulated conversion timings (seconds, one decimal)."""
    print('Time to get tensor:', round(time_tensor, 1), flush=True)
    print('Time to convert:', round(time_convert, 1), flush=True)
    print('Time to transpose:', round(time_transpose, 1), flush=True)
    print('Time to scale:', round(time_scale, 1), flush=True)
    # Bug fix: this line previously reported time_transpose under the
    # astype label, so the astype timing was never shown.
    print('Time for astype:', round(time_astype, 1), flush=True)
    print('Time to pil:', round(time_topil, 1), flush=True)
131
+
132
def tensor2im(input_image, imtype=np.uint8):
    """Converts a Tensor array into a numpy image array.

    Parameters:
        input_image (tensor) -- the input image tensor array
        imtype (type) -- the desired type of the converted numpy array

    Tensors are converted CHW -> HWC, resized to 256x256 and rescaled from
    [-1, 1] to [0, 255]; numpy arrays are passed through with only the final
    dtype cast; any other input is returned unchanged. Timing is accumulated
    in the module-level time_* counters.

    Bug fix: removed the unreachable statements that followed the final
    return (the time_astype accounting there could never execute).
    """
    if not isinstance(input_image, np.ndarray):
        if isinstance(input_image, torch.Tensor):  # get the data from a variable
            ts = time.time()
            image_tensor = input_image.data
            te = time.time()
            global time_tensor
            time_tensor += (te - ts)
        else:
            return input_image
        ts = time.time()
        image_numpy = image_tensor[0].cpu().float().numpy()  # convert it into a numpy array
        te = time.time()
        global time_convert
        time_convert += (te - ts)
        if image_numpy.shape[0] == 1:  # grayscale to RGB
            image_numpy = np.tile(image_numpy, (3, 1, 1))
        ts = time.time()
        image_numpy = np.transpose(image_numpy, (1, 2, 0))
        te = time.time()
        global time_transpose
        time_transpose += (te - ts)
        ts = time.time()
        # NOTE(review): output resolution is hard-coded to 256x256 here.
        image_numpy = cv2.resize(image_numpy, (256, 256), interpolation=cv2.INTER_AREA)
        image_numpy = (image_numpy + 1) / 2.0 * 255.0  # scale [-1, 1] -> [0, 255]
        te = time.time()
        global time_scale
        time_scale += (te - ts)
    else:  # if it is a numpy array, do nothing
        image_numpy = input_image
    return image_numpy.astype(imtype)
176
+
177
+
178
def tensor_to_pil(t):
    """Convert a tensor to a PIL Image via tensor2im, tracking PIL build time."""
    global time_topil
    converted = tensor2im(t)
    started = time.time()
    pil_img = Image.fromarray(converted)
    time_topil += time.time() - started
    return pil_img
188
+
189
+
190
def calculate_ssim(img1, img2):
    """Structural similarity of img1 vs img2, using img2's value range."""
    value_range = img2.max() - img2.min()
    return ssim(img1, img2, data_range=value_range)
192
+
193
+
194
def check_multi_scale(img1, img2):
    """Return the tile size (100..900) whose tiles of img2, resized to
    img1's dimensions, best match img1 by mean SSIM (default 512)."""
    img1 = np.array(img1)
    img2 = np.array(img2)
    best_size, best_score = 512, 0
    for candidate in range(100, 1000, 100):
        total, tiles = 0, 0
        # Only full tiles that fit entirely inside img2 are scored.
        for row in range(0, img2.shape[0], candidate):
            for col in range(0, img2.shape[1], candidate):
                if row + candidate <= img2.shape[0] and col + candidate <= img2.shape[1]:
                    patch = img2[row: row + candidate, col: col + candidate]
                    patch = cv2.resize(patch, (img1.shape[0], img1.shape[1]))
                    total += calculate_ssim(img1, patch)
                    tiles += 1
        if tiles > 0:
            total /= tiles
        if best_score < total:
            best_size, best_score = candidate, total
    return best_size
214
+
215
+
216
+ import subprocess
217
+ import os
218
+ from threading import Thread , Timer
219
+ import sched, time
220
+
221
# modified from https://stackoverflow.com/questions/67707828/how-to-get-every-seconds-gpu-usage-in-python
def get_gpu_memory(gpu_id=0):
    """
    Currently collects gpu memory info for a given gpu id.

    Parameters:
        gpu_id (int) -- index of the GPU row in nvidia-smi's output

    Returns the used memory of that GPU as an int (MiB, per the
    nvidia-smi query). Raises RuntimeError if nvidia-smi exits with
    an error code.
    """
    output_to_list = lambda x: x.decode('ascii').split('\n')[:-1]
    COMMAND = "nvidia-smi --query-gpu=memory.used --format=csv"
    try:
        # Skip the CSV header row; one remaining line per GPU.
        memory_use_info = output_to_list(subprocess.check_output(COMMAND.split(), stderr=subprocess.STDOUT))[1:]
    except subprocess.CalledProcessError as e:
        raise RuntimeError("command '{}' return with error (code {}): {}".format(e.cmd, e.returncode, e.output))
    # Each line looks like "123 MiB"; keep only the numeric part.
    # (Removed the unused ACCEPTABLE_AVAILABLE_MEMORY constant and the
    # unused enumerate index from the original implementation.)
    memory_use_values = [int(x.split()[0]) for x in memory_use_info]

    return memory_use_values[gpu_id]
237
+
238
class HardwareStatus():
    """Samples GPU memory usage once per second on a background Timer."""

    def __init__(self):
        self.gpu_mem = []  # collected memory samples (ints from get_gpu_memory)
        self.timer = None  # currently scheduled threading.Timer, if any

    def get_status_every_sec(self, gpu_id=0):
        """
        This function calls itself every 1 sec and appends the gpu_memory.
        """
        # Bug fix: forward gpu_id to the rescheduled call; previously the
        # Timer invoked get_status_every_sec with no args, so every sample
        # after the first silently fell back to GPU 0.
        self.timer = Timer(1.0, self.get_status_every_sec, args=(gpu_id,))
        self.timer.start()
        self.gpu_mem.append(get_gpu_memory(gpu_id))

    def stop_timer(self):
        # Robustness: no-op if polling was never started instead of
        # raising AttributeError on self.timer being None.
        if self.timer is not None:
            self.timer.cancel()
254
+
255
+