datamint 1.2.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of datamint might be problematic. Click here for more details.
- datamint/__init__.py +11 -0
- datamint-1.2.4.dist-info/METADATA +118 -0
- datamint-1.2.4.dist-info/RECORD +30 -0
- datamint-1.2.4.dist-info/WHEEL +4 -0
- datamint-1.2.4.dist-info/entry_points.txt +4 -0
- datamintapi/__init__.py +25 -0
- datamintapi/apihandler/annotation_api_handler.py +748 -0
- datamintapi/apihandler/api_handler.py +15 -0
- datamintapi/apihandler/base_api_handler.py +300 -0
- datamintapi/apihandler/dto/annotation_dto.py +149 -0
- datamintapi/apihandler/exp_api_handler.py +204 -0
- datamintapi/apihandler/root_api_handler.py +1013 -0
- datamintapi/client_cmd_tools/__init__.py +0 -0
- datamintapi/client_cmd_tools/datamint_config.py +168 -0
- datamintapi/client_cmd_tools/datamint_upload.py +483 -0
- datamintapi/configs.py +58 -0
- datamintapi/dataset/__init__.py +1 -0
- datamintapi/dataset/base_dataset.py +881 -0
- datamintapi/dataset/dataset.py +492 -0
- datamintapi/examples/__init__.py +1 -0
- datamintapi/examples/example_projects.py +75 -0
- datamintapi/experiment/__init__.py +1 -0
- datamintapi/experiment/_patcher.py +570 -0
- datamintapi/experiment/experiment.py +1049 -0
- datamintapi/logging.yaml +27 -0
- datamintapi/utils/dicom_utils.py +640 -0
- datamintapi/utils/io_utils.py +149 -0
- datamintapi/utils/logging_utils.py +55 -0
- datamintapi/utils/torchmetrics.py +70 -0
- datamintapi/utils/visualization.py +129 -0
|
@@ -0,0 +1,149 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import nibabel as nib
|
|
3
|
+
from PIL import Image
|
|
4
|
+
from .dicom_utils import load_image_normalized, is_dicom
|
|
5
|
+
import pydicom
|
|
6
|
+
import os
|
|
7
|
+
from typing import Any
|
|
8
|
+
import logging
|
|
9
|
+
from PIL import ImageFile
|
|
10
|
+
import cv2
|
|
11
|
+
|
|
12
|
+
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
|
13
|
+
|
|
14
|
+
_LOGGER = logging.getLogger(__name__)
|
|
15
|
+
|
|
16
|
+
IMAGE_EXTS = ('.png', '.jpg', '.jpeg')
|
|
17
|
+
NII_EXTS = ('.nii', '.nii.gz')
|
|
18
|
+
VIDEO_EXTS = ('.mp4', '.avi', '.mov', '.mkv')
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def read_video(file_path: str, index: int = None) -> np.ndarray:
|
|
22
|
+
cap = cv2.VideoCapture(file_path)
|
|
23
|
+
if not cap.isOpened():
|
|
24
|
+
raise ValueError(f"Failed to open video file: {file_path}")
|
|
25
|
+
try:
|
|
26
|
+
if index is None:
|
|
27
|
+
frames = []
|
|
28
|
+
while True:
|
|
29
|
+
ret, frame = cap.read()
|
|
30
|
+
if not ret:
|
|
31
|
+
break
|
|
32
|
+
# Convert BGR to RGB and transpose to (C, H, W) format
|
|
33
|
+
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
|
34
|
+
frame = frame.transpose(2, 0, 1)
|
|
35
|
+
frames.append(frame)
|
|
36
|
+
imgs = np.array(frames) # shape: (#frames, C, H, W)
|
|
37
|
+
else:
|
|
38
|
+
while index > 0:
|
|
39
|
+
cap.grab()
|
|
40
|
+
index -= 1
|
|
41
|
+
ret, frame = cap.read()
|
|
42
|
+
if not ret:
|
|
43
|
+
raise ValueError(f"Failed to read frame {index} from video file: {file_path}")
|
|
44
|
+
# Convert BGR to RGB and transpose to (C, H, W) format
|
|
45
|
+
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
|
46
|
+
imgs = frame.transpose(2, 0, 1)
|
|
47
|
+
finally:
|
|
48
|
+
cap.release()
|
|
49
|
+
|
|
50
|
+
if imgs is None or len(imgs) == 0:
|
|
51
|
+
raise ValueError(f"No frames found in video file: {file_path}")
|
|
52
|
+
|
|
53
|
+
return imgs
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def read_nifti(file_path: str) -> np.ndarray:
    """
    Load a NIfTI file and reorder its axes into a frame-first layout.

    Args:
        file_path: Path to the NIfTI file.

    Returns:
        A 4-dimensional array. A 2D volume has its two axes swapped and gets
        two leading singleton axes; a 3D volume has its axis order reversed
        and gets a singleton axis inserted at position 1.
        Assuming the stored data is (W, H) or (W, H, #frames), the result
        is (1, 1, H, W) or (#frames, 1, H, W) respectively.

    Raises:
        ValueError: If the stored data is neither 2- nor 3-dimensional.
    """
    volume = nib.load(file_path).get_fdata()
    rank = volume.ndim

    if rank not in (2, 3):
        raise ValueError(f"Unsupported number of dimensions in '{file_path}': {volume.ndim}")

    if rank == 2:
        # Swap the two spatial axes, then add frame and channel axes.
        return np.transpose(volume, (1, 0))[np.newaxis, np.newaxis]

    # Reverse the axis order, then add the channel axis.
    return np.transpose(volume, (2, 1, 0))[:, np.newaxis]
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def read_image(file_path: str) -> np.ndarray:
    """
    Load an image file with PIL and return it with a leading frame axis.

    Args:
        file_path: Path to the image file.

    Returns:
        For a 2D (H, W) image, an array of shape (1, 1, H, W).
        For a 3D (H, W, C) image, an array of shape (1, C, H, W).
        An array of any other rank is returned unchanged.
    """
    with Image.open(file_path) as pil_image:
        pixels = np.array(pil_image)

    rank = pixels.ndim
    if rank == 3:
        # (H, W, C) -> (C, H, W) -> (1, C, H, W)
        channels_first = pixels.transpose(2, 0, 1)
        return channels_first[np.newaxis]
    if rank == 2:
        # (H, W) -> (1, 1, H, W)
        return pixels[np.newaxis, np.newaxis]
    return pixels
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def read_array_normalized(file_path: str,
|
|
82
|
+
index: int = None,
|
|
83
|
+
return_metainfo: bool = False,
|
|
84
|
+
use_magic=False) -> np.ndarray | tuple[np.ndarray, Any]:
|
|
85
|
+
"""
|
|
86
|
+
Read an array from a file.
|
|
87
|
+
|
|
88
|
+
Args:
|
|
89
|
+
file_path: The path to the file.
|
|
90
|
+
Supported file formats are NIfTI (.nii, .nii.gz), PNG (.png), JPEG (.jpg, .jpeg) and npy (.npy).
|
|
91
|
+
|
|
92
|
+
Returns:
|
|
93
|
+
The array read from the file in shape (#frames, C, H, W), if `index=None`,
|
|
94
|
+
or (C, H, W) if `index` is specified.
|
|
95
|
+
"""
|
|
96
|
+
if not os.path.exists(file_path):
|
|
97
|
+
raise FileNotFoundError(f"File not found: {file_path}")
|
|
98
|
+
|
|
99
|
+
metainfo = None
|
|
100
|
+
|
|
101
|
+
try:
|
|
102
|
+
if is_dicom(file_path):
|
|
103
|
+
ds = pydicom.dcmread(file_path)
|
|
104
|
+
if index is not None:
|
|
105
|
+
imgs = load_image_normalized(ds, index=index)[0]
|
|
106
|
+
else:
|
|
107
|
+
imgs = load_image_normalized(ds)
|
|
108
|
+
# Free up memory
|
|
109
|
+
if hasattr(ds, '_pixel_array'):
|
|
110
|
+
ds._pixel_array = None
|
|
111
|
+
if hasattr(ds, 'PixelData'):
|
|
112
|
+
ds.PixelData = None
|
|
113
|
+
metainfo = ds
|
|
114
|
+
else:
|
|
115
|
+
if use_magic:
|
|
116
|
+
import magic # it is important to import here because magic has an OS lib dependency.
|
|
117
|
+
mime_type = magic.from_file(file_path, mime=True)
|
|
118
|
+
else:
|
|
119
|
+
mime_type = ""
|
|
120
|
+
|
|
121
|
+
if mime_type.startswith('video/') or file_path.endswith(VIDEO_EXTS):
|
|
122
|
+
imgs = read_video(file_path, index)
|
|
123
|
+
else:
|
|
124
|
+
if mime_type == 'image/x.nifti' or file_path.endswith(NII_EXTS):
|
|
125
|
+
imgs = read_nifti(file_path)
|
|
126
|
+
elif mime_type.startswith('image/') or file_path.endswith(IMAGE_EXTS):
|
|
127
|
+
imgs = read_image(file_path)
|
|
128
|
+
elif file_path.endswith('.npy') or mime_type == 'application/x-numpy-data':
|
|
129
|
+
imgs = np.load(file_path)
|
|
130
|
+
if imgs.ndim != 4:
|
|
131
|
+
raise ValueError(f"Unsupported number of dimensions in '{file_path}': {imgs.ndim}")
|
|
132
|
+
else:
|
|
133
|
+
raise ValueError(f"Unsupported file format '{mime_type}' of '{file_path}'")
|
|
134
|
+
|
|
135
|
+
if index is not None:
|
|
136
|
+
if len(imgs) > 1:
|
|
137
|
+
_LOGGER.warning(f"It is inefficient to load all frames from '{file_path}' to access a single frame." +
|
|
138
|
+
" Consider converting the file to a format that supports random access (DICOM), or" +
|
|
139
|
+
" convert to png/jpeg files or" +
|
|
140
|
+
" manually handle all frames at once instead of loading a specific frame.")
|
|
141
|
+
imgs = imgs[index]
|
|
142
|
+
|
|
143
|
+
if return_metainfo:
|
|
144
|
+
return imgs, metainfo
|
|
145
|
+
return imgs
|
|
146
|
+
|
|
147
|
+
except Exception as e:
|
|
148
|
+
_LOGGER.error(f"Failed to read array from '{file_path}': {e}")
|
|
149
|
+
raise e
|
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import logging.config
|
|
3
|
+
from rich.console import ConsoleRenderable
|
|
4
|
+
from rich.logging import RichHandler
|
|
5
|
+
from rich.traceback import Traceback
|
|
6
|
+
import yaml
|
|
7
|
+
import importlib
|
|
8
|
+
|
|
9
|
+
_LOGGER = logging.getLogger(__name__)
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class ConditionalRichHandler(RichHandler):
    """
    Class that uses 'show_level=True' only if the message level is WARNING or higher.
    """

    def handle(self, record):
        """
        Record whether the level should be shown, then delegate to RichHandler.

        Returns:
            Whatever the parent `handle` returns.
        """
        self.show_level = record.levelno >= logging.WARNING
        return super().handle(record)

    def render(self, *, record: logging.LogRecord,
               traceback: Traceback | None,
               message_renderable: ConsoleRenderable) -> ConsoleRenderable:
        """
        Render a record, adding the level column only for WARNING or higher.

        If the parent renderer raises, the failure is reported through
        `handleError` and the bare message renderable is returned instead, so
        a rendering problem never escapes into the code that emitted the log.
        """
        try:
            # if level is WARNING or higher, add the level column
            self._log_render.show_level = record.levelno >= logging.WARNING
            return super().render(record=record, traceback=traceback, message_renderable=message_renderable)
        except Exception:
            # Do not log through the logging system here: this handler may be
            # the one receiving that message, which would recurse.
            self.handleError(record)
            return message_renderable
        finally:
            # Restore the default even when rendering failed.
            self._log_render.show_level = False
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def load_cmdline_logging_config():
    """
    Configure logging for the command-line tools.

    A 'logging_dev.yaml' file in the current working directory takes
    precedence. If it cannot be opened or parsed, the 'logging.yaml' file
    packaged with 'datamintapi' is used. If neither configuration can be
    applied, a warning is printed and logging falls back to
    `logging.basicConfig(level=logging.INFO)`.
    """
    # `import importlib` alone does not guarantee that the `resources`
    # submodule is loaded, so import it explicitly before using it.
    import importlib.resources

    # Load the logging configuration file
    try:
        try:
            # try loading the developer's logging config
            with open('logging_dev.yaml', 'r') as f:
                config = yaml.safe_load(f)
        except Exception:
            # `Exception` rather than a bare `except`, so that
            # KeyboardInterrupt / SystemExit are not swallowed.
            with importlib.resources.open_text('datamintapi', 'logging.yaml') as f:
                config = yaml.safe_load(f.read())

        logging.config.dictConfig(config)
    except Exception as e:
        print(f"Warning: Error loading logging configuration file: {e}")
        _LOGGER.exception(e)
        logging.basicConfig(level=logging.INFO)
|
|
@@ -0,0 +1,70 @@
|
|
|
1
|
+
from torchmetrics.classification import Recall, Precision, F1Score, Specificity
|
|
2
|
+
from torchmetrics.wrappers.abstract import WrapperMetric
|
|
3
|
+
import torchmetrics
|
|
4
|
+
from torch import Tensor
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class SegmentationToClassificationWrapper(WrapperMetric):
    """
    Enables applying classification metrics to segmentation masks.
    The segmentation masks are converted to binary one-hot encoding vectors using an IoU threshold.

    Args:
        metric_cls: Metric instance that is updated with the converted
            (B, L) predictions and targets.
        iou_threshold: IoU threshold to convert segmentation to binary classification.
    """

    def __init__(self,
                 metric_cls: torchmetrics.Metric,
                 iou_threshold=0.5):
        super().__init__()
        self.metric = metric_cls
        self.iou_threshold = iou_threshold

    def update(self, preds: Tensor, target: Tensor):
        """Convert the masks to per-label classifications and update the wrapped metric."""
        cls_pred, cls_target = self.transform_mask_to_binary(preds, target, self.iou_threshold)
        self.metric.update(cls_pred, cls_target)

    def compute(self):
        """Return the value computed by the wrapped metric."""
        return self.metric.compute()

    def reset(self):
        """Reset both this wrapper and the wrapped metric."""
        super().reset()
        self.metric.reset()

    def forward(self, preds: Tensor, target: Tensor):
        """Convert the masks to per-label classifications and forward them to the wrapped metric."""
        cls_pred, cls_target = self.transform_mask_to_binary(preds, target, self.iou_threshold)
        return self.metric.forward(cls_pred, cls_target)

    @staticmethod
    def transform_mask_to_binary(pred: Tensor, target: Tensor, iou_threshold: float = 0.5) -> tuple[Tensor, Tensor]:
        """
        Convert both the segmentation masks with shape (B,L,H,W) to shape (B,L), so that classification metrics can be used.
        The conversion is done using an IoU threshold that must be satisfied to consider a prediction as true positive.

        Args:
            pred: Segmentation prediction of shape (B, L, H, W). Non-zero values are foreground.
            target: Segmentation target of shape (B, L, H, W). Non-zero values are foreground.
            iou_threshold: IoU threshold to convert segmentation to binary classification.

        Returns:
            Tuple of binary classification predictions (first) and targets (second) of shape (B, L).
            A (sample, label) pair whose prediction and target are both empty
            gets a prediction of 0.

        Raises:
            ValueError: If either tensor does not have exactly 4 dimensions.
        """

        if pred.ndim != 4 or target.ndim != 4:
            raise ValueError("Input tensors must have 4 dimensions (B, L, H, W).")

        # Work on boolean masks so that float (0./1.) inputs are accepted too.
        pred_mask = pred.bool()
        target_mask = target.bool()

        # Calculate IoU for each sample and label
        intersection = (pred_mask & target_mask).float().sum(dim=(2, 3))
        union = (pred_mask | target_mask).float().sum(dim=(2, 3))

        # An empty union would give 0/0 = NaN. Clamp the denominator and mask
        # those entries out explicitly instead of relying on NaN comparisons.
        has_pixels = union > 0
        iou = intersection / union.clamp(min=1.0)

        cls_target = target.amax(dim=(2, 3))
        cls_pred = ((iou >= iou_threshold) & has_pixels).float()

        return cls_pred, cls_target
|
|
64
|
+
|
|
65
|
+
# # test
|
|
66
|
+
# cls_metric = Recall(task="multilabel", average="macro", num_labels=3)
|
|
67
|
+
# metric = SegmentationToClassificationWrapper(cls_metric, iou_threshold=0.33)
|
|
68
|
+
# metric(preds=torch.randint(0, 2, size=(4, 3, 32, 32)),
|
|
69
|
+
# target=torch.randint(0, 2, size=(4, 3, 32, 32))
|
|
70
|
+
# )
|
|
@@ -0,0 +1,129 @@
|
|
|
1
|
+
import matplotlib.pyplot as plt
|
|
2
|
+
import numpy as np
|
|
3
|
+
from torchvision.transforms import functional as F
|
|
4
|
+
from torch import Tensor
|
|
5
|
+
import torchvision.utils
|
|
6
|
+
import torch
|
|
7
|
+
import colorsys
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def show(imgs: list[Tensor] | Tensor,
         figsize: tuple[int, int] | None = None,
         normalize: bool = False):
    """
    Show a list of images side by side in a single row.

    Args:
        imgs (list[Tensor] | Tensor): List of images to show, or a single image.
            Each image should be a tensor of shape (C, H, W) and dtype uint8 or float.
        figsize (tuple[int, int], optional): Size of the figure. Defaults to None.
        normalize (bool, optional): Whether to normalize the images to [0, 1] range by min-max scaling.
            The tensors passed in are never modified. An image whose values
            are all equal is shown as all zeros.
    """

    if not isinstance(imgs, (list, tuple)):
        imgs = [imgs]

    if normalize:
        # Build a new list so the caller's list is not modified in place.
        scaled = []
        for img in imgs:
            img = img.float()
            img = img - img.min()
            peak = img.max()
            # A constant image has peak == 0; skip the division to avoid NaNs.
            if peak > 0:
                img = img / peak
            scaled.append(img)
        imgs = scaled

    # figsize=None is matplotlib's own default, so one call covers both cases.
    fig, axs = plt.subplots(ncols=len(imgs), squeeze=False, figsize=figsize)
    for i, img in enumerate(imgs):
        img = img.detach()
        img = F.to_pil_image(img)
        axs[0, i].imshow(np.asarray(img))
        axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
# Fixed colors handed out first, before any random color is generated.
_COLORS = [
    (255, 0, 0), (0, 255, 0), (0, 0, 255),
    (200, 128, 0), (128, 200, 0), (0, 128, 200),
    (200, 0, 128), (128, 0, 200), (0, 200, 128),
]


def generate_color_palette(num_objects: int) -> list[tuple[int, int, int]]:
    """
    Build a list of RGB colors, one per segmentation mask.

    The predefined colors in `_COLORS` are used first. If more colors are
    requested than are predefined, the remainder are drawn with a random hue
    at saturation 0.9 and value 0.9.

    Args:
        num_objects (int): How many colors to return.

    Returns:
        A new list holding `num_objects` RGB tuples.

    Raises:
        ValueError: If `num_objects` is not greater than 0.
    """
    if num_objects <= 0:
        raise ValueError("Number of objects must be greater than 0.")

    palette = _COLORS[:num_objects]
    missing = num_objects - len(palette)

    # Top up with random colors when the predefined ones are not enough.
    while missing > 0:
        red, green, blue = colorsys.hsv_to_rgb(np.random.rand(), 0.9, 0.9)
        palette.append((int(red * 255), int(green * 255), int(blue * 255)))
        missing -= 1

    return palette
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
@torch.no_grad()
def draw_masks(
    image: Tensor,
    masks: Tensor,
    alpha: float = 0.5,
    colors: list[str | tuple[int, int, int]] | str | tuple[int, int, int] | None = None,
) -> Tensor:
    """
    Draws segmentation masks on given RGB image.
    Unlike a single call to `torchvision.utils.draw_segmentation_masks`, a stack of masks
    is drawn one mask at a time, so overlapping masks are blended on top of each other.
    The image values should be uint8 or float.

    Args:
        image (Tensor): Tensor of shape (3, H, W) and dtype uint8 or float.
            A single-channel (1, H, W) image is expanded to three channels.
            An image containing negative values is min-max scaled to [0, 1].
        masks (Tensor): Tensor of shape (num_masks, H, W) or (H, W) and dtype bool.
            Other dtypes are converted with `.bool()`.
        alpha (float): Float number between 0 and 1 denoting the transparency of the masks.
            0 means full transparency, 1 means no transparency.
        colors (color or list of colors, optional): List containing the colors
            of the masks or single color for all masks. The color can be represented as
            PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``.
            By default, the colors come from `generate_color_palette`.

    Returns:
        img (Tensor[C, H, W]): Image Tensor, with segmentation masks drawn on top.

    Raises:
        ValueError: If a list of colors holds fewer colors than there are masks.
    """

    if image.ndim == 3 and image.shape[0] == 1:
        # convert to RGB
        image = image.expand(3, -1, -1)

    if masks.dtype != torch.bool:
        masks = masks.bool()

    # if image has negative values, scale to [0, 1]
    if image.min() < 0:
        image = image.float()
        image = image - image.min()
        peak = image.max()
        # A constant image has peak == 0; skip the division to avoid NaNs.
        if peak > 0:
            image = image / peak

    if masks.ndim == 2:
        return torchvision.utils.draw_segmentation_masks(image=image,
                                                         masks=masks,
                                                         alpha=alpha,
                                                         colors=colors)

    num_masks = len(masks)
    if num_masks == 0:
        # Nothing to draw.
        return image

    is_single_color = isinstance(colors, str) or (
        isinstance(colors, tuple)
        and len(colors) == 3
        and all(isinstance(channel, int) for channel in colors)
    )
    if colors is None:
        colors = generate_color_palette(num_masks)
    elif is_single_color:
        # One color for every mask. Zipping the bare value with the masks
        # would iterate over its characters / channels instead.
        colors = [colors] * num_masks
    elif len(colors) < num_masks:
        # zip() would otherwise silently skip the masks without a color.
        raise ValueError(f"Got {len(colors)} colors for {num_masks} masks.")

    for color, mask in zip(colors, masks):
        image = torchvision.utils.draw_segmentation_masks(image=image,
                                                          masks=mask,
                                                          alpha=alpha,
                                                          colors=color)
    return image
|