datamint 1.2.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

@@ -0,0 +1,149 @@
+ import numpy as np
+ import nibabel as nib
+ from PIL import Image
+ from .dicom_utils import load_image_normalized, is_dicom
+ import pydicom
+ import os
+ from typing import Any
+ import logging
+ from PIL import ImageFile
+ import cv2
+
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
+
+ _LOGGER = logging.getLogger(__name__)
+
+ IMAGE_EXTS = ('.png', '.jpg', '.jpeg')
+ NII_EXTS = ('.nii', '.nii.gz')
+ VIDEO_EXTS = ('.mp4', '.avi', '.mov', '.mkv')
+
+
+ def read_video(file_path: str, index: int | None = None) -> np.ndarray:
+     cap = cv2.VideoCapture(file_path)
+     if not cap.isOpened():
+         raise ValueError(f"Failed to open video file: {file_path}")
+     try:
+         if index is None:
+             frames = []
+             while True:
+                 ret, frame = cap.read()
+                 if not ret:
+                     break
+                 # Convert BGR to RGB and transpose to (C, H, W) format
+                 frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+                 frame = frame.transpose(2, 0, 1)
+                 frames.append(frame)
+             imgs = np.array(frames)  # shape: (#frames, C, H, W)
+         else:
+             # Skip frames up to `index` without decoding them; keep `index`
+             # intact so the error message below reports the requested frame.
+             for _ in range(index):
+                 cap.grab()
+             ret, frame = cap.read()
+             if not ret:
+                 raise ValueError(f"Failed to read frame {index} from video file: {file_path}")
+             # Convert BGR to RGB and transpose to (C, H, W) format
+             frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+             imgs = frame.transpose(2, 0, 1)
+     finally:
+         cap.release()
+
+     if len(imgs) == 0:
+         raise ValueError(f"No frames found in video file: {file_path}")
+
+     return imgs
+
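A minimal usage sketch of read_video (the file name is hypothetical):

frames = read_video('exam.mp4')           # all frames: (#frames, C, H, W), RGB
frame = read_video('exam.mp4', index=10)  # one frame: (C, H, W); frames 0-9 are grabbed but not decoded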
+ def read_nifti(file_path: str) -> np.ndarray:
+     imgs = nib.load(file_path).get_fdata()  # shape: (W, H, #frames) or (W, H)
+     if imgs.ndim == 2:
+         imgs = imgs.transpose(1, 0)
+         imgs = imgs[np.newaxis, np.newaxis]  # (H, W) -> (1, 1, H, W)
+     elif imgs.ndim == 3:
+         imgs = imgs.transpose(2, 1, 0)
+         imgs = imgs[:, np.newaxis]  # (#frames, H, W) -> (#frames, 1, H, W)
+     else:
+         raise ValueError(f"Unsupported number of dimensions in '{file_path}': {imgs.ndim}")
+
+     return imgs
+
+
+ def read_image(file_path: str) -> np.ndarray:
+     with Image.open(file_path) as pilimg:
+         imgs = np.array(pilimg)
+         if imgs.ndim == 2:  # (H, W)
+             imgs = imgs[np.newaxis, np.newaxis]  # (H, W) -> (1, 1, H, W)
+         elif imgs.ndim == 3:  # (H, W, C)
+             imgs = imgs.transpose(2, 0, 1)[np.newaxis]  # (H, W, C) -> (1, C, H, W)
+
+     return imgs
+
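Both readers normalize their output to the 4-D (#frames, C, H, W) layout used throughout this module. For illustration (file names are hypothetical):

read_image('xray.png').shape   # (1, 3, H, W) for an RGB PNG; (1, 1, H, W) for grayscale
read_nifti('ct.nii.gz').shape  # (#slices, 1, H, W)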
+ def read_array_normalized(file_path: str,
+                           index: int | None = None,
+                           return_metainfo: bool = False,
+                           use_magic: bool = False) -> np.ndarray | tuple[np.ndarray, Any]:
+     """
+     Read an array from a file.
+
+     Args:
+         file_path: The path to the file.
+             Supported file formats are DICOM, NIfTI (.nii, .nii.gz), PNG (.png), JPEG (.jpg, .jpeg),
+             video (.mp4, .avi, .mov, .mkv) and npy (.npy).
+         index: If specified, return only the frame at this index.
+         return_metainfo: If True, also return the file's metadata (the pydicom dataset for DICOM files, None otherwise).
+         use_magic: If True, detect the file type from its content using libmagic instead of relying on the extension alone.
+
+     Returns:
+         The array read from the file in shape (#frames, C, H, W), if `index=None`,
+         or (C, H, W) if `index` is specified.
+     """
+     if not os.path.exists(file_path):
+         raise FileNotFoundError(f"File not found: {file_path}")
+
+     metainfo = None
+
+     try:
+         if is_dicom(file_path):
+             ds = pydicom.dcmread(file_path)
+             if index is not None:
+                 imgs = load_image_normalized(ds, index=index)[0]
+             else:
+                 imgs = load_image_normalized(ds)
+             # Free up memory
+             if hasattr(ds, '_pixel_array'):
+                 ds._pixel_array = None
+             if hasattr(ds, 'PixelData'):
+                 ds.PixelData = None
+             metainfo = ds
+         else:
+             if use_magic:
+                 import magic  # imported here because magic depends on an OS-level library.
+                 mime_type = magic.from_file(file_path, mime=True)
+             else:
+                 mime_type = ""
+
+             if mime_type.startswith('video/') or file_path.endswith(VIDEO_EXTS):
+                 imgs = read_video(file_path, index)
+             else:
+                 if mime_type == 'image/x.nifti' or file_path.endswith(NII_EXTS):
+                     imgs = read_nifti(file_path)
+                 elif mime_type.startswith('image/') or file_path.endswith(IMAGE_EXTS):
+                     imgs = read_image(file_path)
+                 elif file_path.endswith('.npy') or mime_type == 'application/x-numpy-data':
+                     imgs = np.load(file_path)
+                     if imgs.ndim != 4:
+                         raise ValueError(f"Unsupported number of dimensions in '{file_path}': {imgs.ndim}")
+                 else:
+                     raise ValueError(f"Unsupported file format '{mime_type}' of '{file_path}'")
+
+                 if index is not None:
+                     if len(imgs) > 1:
+                         _LOGGER.warning(f"It is inefficient to load all frames from '{file_path}' to access a single frame."
+                                         " Consider converting the file to a format that supports random access (DICOM),"
+                                         " converting it to png/jpeg files, or"
+                                         " handling all frames at once instead of loading a specific frame.")
+                     imgs = imgs[index]
+
+         if return_metainfo:
+             return imgs, metainfo
+         return imgs
+
+     except Exception as e:
+         _LOGGER.error(f"Failed to read array from '{file_path}': {e}")
+         raise
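A short usage sketch of read_array_normalized (file paths are hypothetical):

imgs = read_array_normalized('study/ct.nii.gz')             # (#frames, 1, H, W)
frame = read_array_normalized('study/slice0.dcm', index=0)  # (C, H, W)
frame, ds = read_array_normalized('study/slice0.dcm', index=0, return_metainfo=True)
# ds is the pydicom dataset with its pixel data stripped; it is None for non-DICOM files.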
@@ -0,0 +1,55 @@
+ import logging
+ import logging.config
+ from rich.console import ConsoleRenderable
+ from rich.logging import RichHandler
+ from rich.traceback import Traceback
+ import yaml
+ import importlib.resources
+
+ _LOGGER = logging.getLogger(__name__)
+
+
+ class ConditionalRichHandler(RichHandler):
+     """
+     RichHandler variant that shows the level column only when the message level is WARNING or higher.
+     """
+
+     def handle(self, record):
+         self.show_level = record.levelno >= logging.WARNING
+         super().handle(record)
+
+     def render(self, *, record: logging.LogRecord,
+                traceback: Traceback | None,
+                message_renderable: ConsoleRenderable) -> ConsoleRenderable:
+         # Add the level column only for WARNING and above.
+         try:
+             self._log_render.show_level = record.levelno >= logging.WARNING
+             return super().render(record=record, traceback=traceback, message_renderable=message_renderable)
+         except Exception as e:
+             # Fall back to the unstyled message instead of returning an unbound variable.
+             _LOGGER.error(f"Error rendering log. {e}")
+             return message_renderable
+         finally:
+             self._log_render.show_level = False
+
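A minimal sketch of using the handler directly, without a config file:

import logging

handler = ConditionalRichHandler(show_path=False)
logging.basicConfig(level=logging.DEBUG, handlers=[handler], format='%(message)s')
logging.getLogger('demo').info('rendered without a level column')
logging.getLogger('demo').warning('rendered with a WARNING level column')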
+
+ def load_cmdline_logging_config():
+     """Load the logging configuration file."""
+     try:
+         try:
+             # Try the developer's logging config first.
+             with open('logging_dev.yaml', 'r') as f:
+                 config = yaml.safe_load(f)
+         except FileNotFoundError:
+             # Fall back to the default config shipped with the package.
+             with importlib.resources.open_text('datamintapi', 'logging.yaml') as f:
+                 config = yaml.safe_load(f)
+
+         logging.config.dictConfig(config)
+     except Exception as e:
+         print(f"Warning: Error loading logging configuration file: {e}")
+         _LOGGER.exception(e)
+         logging.basicConfig(level=logging.INFO)
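For context, a sketch of a dictConfig-style configuration that such a loader could consume. The '()' factory key is standard logging.config behavior; the handler's import path below is an assumption, not taken from the package:

import logging.config

config = {
    'version': 1,
    'formatters': {'plain': {'format': '%(message)s'}},
    'handlers': {
        'console': {
            # '()' tells dictConfig to instantiate a custom handler class;
            # this module path is hypothetical.
            '()': 'datamintapi.logging_utils.ConditionalRichHandler',
            'formatter': 'plain',
        },
    },
    'root': {'level': 'INFO', 'handlers': ['console']},
}
logging.config.dictConfig(config)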
@@ -0,0 +1,70 @@
+ from torchmetrics.classification import Recall, Precision, F1Score, Specificity
+ from torchmetrics.wrappers.abstract import WrapperMetric
+ import torchmetrics
+ from torch import Tensor
+
+
+ class SegmentationToClassificationWrapper(WrapperMetric):
+     """
+     Enables applying classification metrics to segmentation masks.
+     The segmentation masks are converted to binary one-hot encoding vectors using an IoU threshold.
+
+     Args:
+         metric_cls: The classification metric (instance) to wrap.
+         iou_threshold: IoU threshold above which a predicted mask counts as a positive prediction.
+     """
+
+     def __init__(self,
+                  metric_cls: torchmetrics.Metric,
+                  iou_threshold: float = 0.5):
+         super().__init__()
+         self.metric = metric_cls
+         self.iou_threshold = iou_threshold
+
+     def update(self, preds: Tensor, target: Tensor):
+         cls_pred, cls_target = self.transform_mask_to_binary(preds, target, self.iou_threshold)
+         self.metric.update(cls_pred, cls_target)
+
+     def compute(self):
+         return self.metric.compute()
+
+     def reset(self):
+         super().reset()
+         self.metric.reset()
+
+     def forward(self, preds: Tensor, target: Tensor):
+         cls_pred, cls_target = self.transform_mask_to_binary(preds, target, self.iou_threshold)
+         return self.metric.forward(cls_pred, cls_target)
+
+     @staticmethod
+     def transform_mask_to_binary(pred: Tensor, target: Tensor, iou_threshold: float = 0.5) -> tuple[Tensor, Tensor]:
+         """
+         Convert segmentation masks of shape (B, L, H, W) to classification vectors of shape (B, L),
+         so that classification metrics can be used. A prediction counts as a true positive only if
+         its IoU with the target mask reaches `iou_threshold`.
+
+         Args:
+             pred: Binary segmentation prediction of shape (B, L, H, W).
+             target: Binary segmentation target of shape (B, L, H, W).
+             iou_threshold: IoU threshold to convert segmentation to binary classification.
+
+         Returns:
+             Tuple of binary classification predictions (first) and targets (second) of shape (B, L).
+         """
+         if pred.ndim != 4 or target.ndim != 4:
+             raise ValueError("Input tensors must have 4 dimensions (B, L, H, W).")
+
+         # Calculate IoU for each sample and label. Clamping the union avoids a
+         # 0/0 NaN when both masks are empty (the IoU is then treated as 0).
+         intersection = (pred & target).float().sum(dim=(2, 3))
+         union = (pred | target).float().sum(dim=(2, 3))
+         iou = intersection / union.clamp(min=1)
+
+         cls_target = target.amax(dim=(2, 3))
+         cls_pred = (iou >= iou_threshold).float()
+
+         return cls_pred, cls_target
+
+ # Quick sanity check (requires torch to be imported):
+ # cls_metric = Recall(task="multilabel", average="macro", num_labels=3)
+ # metric = SegmentationToClassificationWrapper(cls_metric, iou_threshold=0.33)
+ # metric(preds=torch.randint(0, 2, size=(4, 3, 32, 32)),
+ #        target=torch.randint(0, 2, size=(4, 3, 32, 32)))
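To make the IoU-to-classification conversion concrete, a small worked example (shapes and values chosen purely for illustration):

import torch

pred = torch.zeros(1, 1, 4, 4, dtype=torch.bool)
target = torch.zeros(1, 1, 4, 4, dtype=torch.bool)
pred[0, 0, :2, :2] = True    # predicted 2x2 square (4 pixels)
target[0, 0, :2, :3] = True  # target 2x3 rectangle (6 pixels)

# intersection = 4, union = 6 -> IoU = 0.667
cls_pred, cls_target = SegmentationToClassificationWrapper.transform_mask_to_binary(
    pred, target, iou_threshold=0.5)
print(cls_pred)    # tensor([[1.]]): IoU 0.667 >= 0.5 counts as a positive prediction
print(cls_target)  # tensor([[True]]): the target mask is non-empty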
@@ -0,0 +1,129 @@
+ import matplotlib.pyplot as plt
+ import numpy as np
+ from torchvision.transforms import functional as F
+ from torch import Tensor
+ import torchvision.utils
+ import torch
+ import colorsys
+
+
+ def show(imgs: list[Tensor] | Tensor,
+          figsize: tuple[int, int] | None = None,
+          normalize: bool = False):
+     """
+     Show a list of images in a grid.
+
+     Args:
+         imgs (list[Tensor] | Tensor): List of images to show.
+             Each image should be a tensor of shape (C, H, W) and dtype uint8 or float.
+         figsize (tuple[int, int], optional): Size of the figure. Defaults to None.
+         normalize (bool, optional): Whether to normalize the images to the [0, 1] range by min-max scaling.
+     """
+     if not isinstance(imgs, list):
+         imgs = [imgs]
+
+     if normalize:
+         for i, img in enumerate(imgs):
+             img = img.float()
+             img = img - img.min()
+             img = img / img.max()
+             imgs[i] = img
+
+     # plt.subplots accepts figsize=None, so one call covers both cases.
+     fig, axs = plt.subplots(ncols=len(imgs), squeeze=False, figsize=figsize)
+     for i, img in enumerate(imgs):
+         img = img.detach()
+         img = F.to_pil_image(img)
+         axs[0, i].imshow(np.asarray(img))
+         axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
+
+
+ _COLORS = [(255, 0, 0), (0, 255, 0), (0, 0, 255),
+            (200, 128, 0), (128, 200, 0), (0, 128, 200), (200, 0, 128), (128, 0, 200), (0, 200, 128)]
+
+
+ def generate_color_palette(num_objects: int) -> list[tuple[int, int, int]]:
+     """
+     Generate a list of colors for segmentation masks.
+
+     Args:
+         num_objects (int): Number of objects to generate colors for.
+
+     Returns:
+         List of RGB colors.
+     """
+     if num_objects <= 0:
+         raise ValueError("Number of objects must be greater than 0.")
+
+     colors = _COLORS[:num_objects]
+     if len(colors) == num_objects:
+         return colors
+
+     num_objects -= len(colors)
+
+     # Generate the remaining colors with random hues at fixed saturation and value.
+     for _ in range(num_objects):
+         hue = np.random.rand()
+         rgb = colorsys.hsv_to_rgb(hue, 0.9, 0.9)
+         colors.append(tuple(int(c * 255) for c in rgb))
+
+     return colors
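For illustration: the first nine requested colors are deterministic, and any extras get random hues:

generate_color_palette(3)        # [(255, 0, 0), (0, 255, 0), (0, 0, 255)]
len(generate_color_palette(12))  # 12: the 9 fixed colors plus 3 random-hue colors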
+
+ @torch.no_grad()
+ def draw_masks(
+     image: Tensor,
+     masks: Tensor,
+     alpha: float = 0.5,
+     colors: list[str | tuple[int, int, int]] | str | tuple[int, int, int] | None = None,
+ ) -> Tensor:
+     """
+     Draws segmentation masks on a given RGB image.
+     This differs from `torchvision.utils.draw_segmentation_masks` in that overlapping masks are blended together correctly.
+     The image values should be uint8 or float.
+
+     Args:
+         image (Tensor): Tensor of shape (3, H, W) and dtype uint8 or float.
+         masks (Tensor): Tensor of shape (num_masks, H, W) or (H, W) and dtype bool.
+         alpha (float): Float number between 0 and 1 denoting the transparency of the masks.
+             0 means full transparency, 1 means no transparency.
+         colors (color or list of colors, optional): List containing the colors
+             of the masks or single color for all masks. The color can be represented as
+             PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``.
+             By default, colors are generated for each mask.
+
+     Returns:
+         img (Tensor[C, H, W]): Image Tensor, with segmentation masks drawn on top.
+     """
+     if image.ndim == 3 and image.shape[0] == 1:
+         # Expand single-channel images to RGB.
+         image = image.expand(3, -1, -1)
+
+     if masks.dtype != torch.bool:
+         masks = masks.bool()
+
+     # If the image has negative values, min-max scale it to [0, 1].
+     if image.min() < 0:
+         image = image.float()
+         image = image - image.min()
+         image = image / image.max()
+
+     if masks.ndim == 2:
+         return torchvision.utils.draw_segmentation_masks(image=image,
+                                                          masks=masks,
+                                                          alpha=alpha,
+                                                          colors=colors)
+
+     if colors is None:
+         colors = generate_color_palette(len(masks))
+
+     # Draw one mask at a time so overlapping masks are blended rather than overwritten.
+     for color, mask in zip(colors, masks):
+         image = torchvision.utils.draw_segmentation_masks(image=image,
+                                                           masks=mask,
+                                                           alpha=alpha,
+                                                           colors=color)
+     return image
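Putting the pieces together, an end-to-end sketch on synthetic data (no real files involved):

import torch

image = torch.randint(0, 256, (3, 64, 64), dtype=torch.uint8)  # synthetic RGB image
masks = torch.zeros(2, 64, 64, dtype=torch.bool)
masks[0, 10:40, 10:40] = True  # two overlapping square masks
masks[1, 25:55, 25:55] = True

overlaid = draw_masks(image, masks, alpha=0.5)  # colors come from generate_color_palette
show([image, overlaid], figsize=(8, 4))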