nnInteractive 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nnInteractive/__init__.py +3 -0
- nnInteractive/inference/__init__.py +0 -0
- nnInteractive/inference/cvpr2025_challenge_baseline/__init__.py +0 -0
- nnInteractive/inference/cvpr2025_challenge_baseline/predict.py +173 -0
- nnInteractive/inference/inference_session.py +1400 -0
- nnInteractive/interaction/__init__.py +0 -0
- nnInteractive/interaction/point.py +166 -0
- nnInteractive/supervoxel/setup.py +4 -0
- nnInteractive/supervoxel/src/metadata.py +118 -0
- nnInteractive/supervoxel/src/reader.py +175 -0
- nnInteractive/supervoxel/src/run.py +136 -0
- nnInteractive/supervoxel/src/sam2/__init__.py +2 -0
- nnInteractive/supervoxel/src/sam2/sam2/__init__.py +11 -0
- nnInteractive/supervoxel/src/sam2/sam2/automatic_mask_generator.py +434 -0
- nnInteractive/supervoxel/src/sam2/sam2/benchmark.py +86 -0
- nnInteractive/supervoxel/src/sam2/sam2/build_sam.py +172 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/__init__.py +5 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/__init__.py +5 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/hieradet.py +305 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/image_encoder.py +132 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/utils.py +89 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/memory_attention.py +167 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/memory_encoder.py +179 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/position_encoding.py +217 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/__init__.py +5 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/mask_decoder.py +274 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/prompt_encoder.py +194 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/transformer.py +293 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/sam2_base.py +879 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/sam2_utils.py +315 -0
- nnInteractive/supervoxel/src/sam2/sam2/sam2_image_predictor.py +433 -0
- nnInteractive/supervoxel/src/sam2/sam2/sam2_video_predictor.py +1171 -0
- nnInteractive/supervoxel/src/sam2/sam2/sam2_video_predictor_legacy.py +1125 -0
- nnInteractive/supervoxel/src/sam2/sam2/utils/__init__.py +5 -0
- nnInteractive/supervoxel/src/sam2/sam2/utils/amg.py +332 -0
- nnInteractive/supervoxel/src/sam2/sam2/utils/misc.py +488 -0
- nnInteractive/supervoxel/src/sam2/sam2/utils/transforms.py +108 -0
- nnInteractive/supervoxel/src/sam2/setup.py +174 -0
- nnInteractive/supervoxel/src/sam2/training/__init__.py +5 -0
- nnInteractive/supervoxel/src/sam2/training/dataset/__init__.py +5 -0
- nnInteractive/supervoxel/src/sam2/training/dataset/sam2_datasets.py +176 -0
- nnInteractive/supervoxel/src/sam2/training/dataset/transforms.py +481 -0
- nnInteractive/supervoxel/src/sam2/training/dataset/utils.py +102 -0
- nnInteractive/supervoxel/src/sam2/training/dataset/vos_dataset.py +154 -0
- nnInteractive/supervoxel/src/sam2/training/dataset/vos_raw_dataset.py +290 -0
- nnInteractive/supervoxel/src/sam2/training/dataset/vos_sampler.py +103 -0
- nnInteractive/supervoxel/src/sam2/training/dataset/vos_segment_loader.py +289 -0
- nnInteractive/supervoxel/src/sam2/training/loss_fns.py +290 -0
- nnInteractive/supervoxel/src/sam2/training/model/__init__.py +5 -0
- nnInteractive/supervoxel/src/sam2/training/model/sam2.py +515 -0
- nnInteractive/supervoxel/src/sam2/training/optimizer.py +462 -0
- nnInteractive/supervoxel/src/sam2/training/scripts/sav_frame_extraction_submitit.py +157 -0
- nnInteractive/supervoxel/src/sam2/training/train.py +232 -0
- nnInteractive/supervoxel/src/sam2/training/trainer.py +1051 -0
- nnInteractive/supervoxel/src/sam2/training/utils/__init__.py +5 -0
- nnInteractive/supervoxel/src/sam2/training/utils/checkpoint_utils.py +328 -0
- nnInteractive/supervoxel/src/sam2/training/utils/data_utils.py +166 -0
- nnInteractive/supervoxel/src/sam2/training/utils/distributed.py +560 -0
- nnInteractive/supervoxel/src/sam2/training/utils/logger.py +236 -0
- nnInteractive/supervoxel/src/sam2/training/utils/train_utils.py +275 -0
- nnInteractive/supervoxel/src/supervoxel.py +198 -0
- nnInteractive/trainer/__init__.py +0 -0
- nnInteractive/trainer/nnInteractiveTrainer.py +24 -0
- nnInteractive/utils/__init__.py +0 -0
- nnInteractive/utils/bboxes.py +217 -0
- nnInteractive/utils/checkpoint_cleansing.py +9 -0
- nnInteractive/utils/crop.py +268 -0
- nnInteractive/utils/erosion_dilation.py +48 -0
- nnInteractive/utils/inference_helpers.py +45 -0
- nnInteractive/utils/os_shennanigans.py +16 -0
- nnInteractive/utils/rounding.py +13 -0
- nninteractive-2.0.0.dist-info/METADATA +511 -0
- nninteractive-2.0.0.dist-info/RECORD +76 -0
- nninteractive-2.0.0.dist-info/WHEEL +5 -0
- nninteractive-2.0.0.dist-info/licenses/LICENSE +201 -0
- nninteractive-2.0.0.dist-info/top_level.txt +1 -0
|
File without changes
|
|
@@ -0,0 +1,166 @@
|
|
|
1
|
+
from functools import lru_cache
|
|
2
|
+
from typing import Tuple, Optional
|
|
3
|
+
|
|
4
|
+
import numpy as np
|
|
5
|
+
import torch
|
|
6
|
+
from batchgeneratorsv2.helpers.scalar_type import sample_scalar, RandomScalar
|
|
7
|
+
from scipy.ndimage import distance_transform_edt
|
|
8
|
+
from skimage.morphology import disk, ball
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@lru_cache(maxsize=5)
def build_point(radii, use_distance_transform, binarize):
    """
    Build the structuring element used to paint a point interaction.

    A disk (2D) or ball (3D) of radius ``max(radii)`` is created. If needed, it is resized with
    (bi/tri)linear interpolation to the per-axis shape ``round(2 * r + 1)``, which allows anisotropic points.

    Results are memoized with ``lru_cache``: repeated calls with identical arguments return the
    SAME tensor object, so callers must never modify the returned tensor in place.

    Parameters:
        radii (tuple): One radius per spatial axis. Its length (2 or 3) sets the dimensionality.
            Must be hashable because of the cache.
        use_distance_transform (bool): If True, return the Euclidean distance transform of the
            binarized (>= 0.5) element, normalized to a maximum of 1.
        binarize (bool): If True and ``use_distance_transform`` is False, threshold the resized
            element at 0.5 and return a {0, 1} mask. Ignored when ``use_distance_transform`` is True.

    Returns:
        torch.Tensor: float32 tensor of shape ``[round(2 * r + 1) for r in radii]``.

    Raises:
        ValueError: If ``len(radii)`` is not 2 or 3.
    """
    max_radius = max(radii)
    ndim = len(radii)

    # Create a spherical (or circular) structuring element with max_radius
    if ndim == 2:
        structuring_element = disk(max_radius)
    elif ndim == 3:
        structuring_element = ball(max_radius)
    else:
        raise ValueError("Unsupported number of dimensions. Only 2D and 3D are supported.")

    structuring_element = torch.from_numpy(structuring_element.astype(np.float32))

    # Per-axis target shape derived from the (possibly anisotropic) radii
    target_shape = [round(2 * r + 1) for r in radii]

    if tuple(target_shape) != tuple(structuring_element.shape):
        structuring_element_resized = torch.nn.functional.interpolate(
            structuring_element.unsqueeze(0).unsqueeze(0),  # add batch and channel dims for interpolate
            size=target_shape,
            mode="trilinear" if ndim == 3 else "bilinear",
            align_corners=False,
        )[0, 0]  # drop batch and channel dims again
    else:
        structuring_element_resized = structuring_element

    if use_distance_transform:
        binary_structuring_element = (structuring_element_resized >= 0.5).numpy()
        distance = distance_transform_edt(binary_structuring_element)
        peak = distance.max()
        # Guard against 0/0 -> NaN when the binarized element is empty; NaNs would otherwise
        # be propagated into the interaction map by the max-compositing in place_point.
        if peak > 0:
            distance /= peak
        # distance_transform_edt returns float64; cast so every branch returns float32.
        structuring_element_resized = torch.from_numpy(distance.astype(np.float32))

    if binarize and not use_distance_transform:
        # Values near 1 are treated as the point region
        structuring_element_resized = (structuring_element_resized >= 0.5).float()
    return structuring_element_resized
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
class PointInteraction_stub:
    """
    Paints point interactions (disk/ball-shaped structuring elements) into interaction maps.
    """

    interaction_type = "point"

    def __init__(self, point_radius: "RandomScalar", use_distance_transform: bool = False):
        """
        Initializes the PointInteraction object.

        Parameters:
            point_radius (RandomScalar): Specifies the radius for the interaction points. It is
                sampled per spatial axis with ``sample_scalar`` each time a point is placed.
            use_distance_transform (bool): Determines whether to use a distance transform for smooth interactions.
        """
        super().__init__()
        self.point_radius = point_radius
        self.use_distance_transform = use_distance_transform

    @staticmethod
    def _placement_slices(position, strel_shape, spatial_shape):
        """
        Compute where a structuring element centred at ``position`` overlaps a map.

        Parameters:
            position: Centre coordinate per spatial axis.
            strel_shape: Shape of the structuring element.
            spatial_shape: Spatial shape of the map.

        Returns:
            (bbox, map_slices, strel_slices). ``bbox`` holds the half-open ``[start, stop)``
            interval of the full element per axis. ``map_slices`` / ``strel_slices`` select the
            overlapping region in the map and in the element. Both are None when the element
            does not overlap the map at all.
        """
        bbox = [
            [position[i] - strel_shape[i] // 2, position[i] + strel_shape[i] // 2 + strel_shape[i] % 2]
            for i in range(len(spatial_shape))
        ]
        # Intervals are half-open, so stop == 0 or start == size already means "no overlap".
        if any(stop <= 0 for _, stop in bbox) or any(
            start >= size for (start, _), size in zip(bbox, spatial_shape)
        ):
            return bbox, None, None

        map_slices = tuple(
            slice(max(0, start), min(size, stop)) for (start, stop), size in zip(bbox, spatial_shape)
        )
        # The offset into the element equals how far the bbox sticks out past the lower map border.
        strel_slices = tuple(
            slice(max(0, -start), max(0, -start) + (map_slice.stop - map_slice.start))
            for (start, _), map_slice in zip(bbox, map_slices)
        )
        return bbox, map_slices, strel_slices

    def place_point(
        self,
        position: Tuple[int, ...],
        interaction_map,
        binarize: bool = False,
        intensity_scale: float = 1.0,
        channel_idx: Optional[int] = None,
    ) -> torch.Tensor:
        """
        Places a point on the interaction map around the specified position.

        The structuring element is composited with an element-wise maximum, so existing larger
        values in the map are kept.

        Parameters:
            position (Tuple[int, ...]): The spatial coordinates where the point should be centred.
            interaction_map: A tensor (or blosc2 NDArray when channel_idx is provided) representing
                the interaction map where the point should be placed.
            binarize (bool): If True, inserts a binary mask. If False, may insert smooth values based on distance.
            intensity_scale (float): Scale factor applied to the structuring element values.
            channel_idx (int, optional): If provided, interaction_map is treated as a 4D blosc2 NDArray
                and only the structuring element subregion is read/written for
                channel channel_idx. Avoids decompressing the full channel.

        Returns:
            The updated interaction map (torch.Tensor for the default path; blosc2 NDArray for the
            channel_idx path). The map is modified in place. If the point lies completely outside
            the map, a message is printed and the map is returned unchanged.
        """
        # blosc2 path: the map is (channel, *spatial). Default path: the map itself is spatial.
        spatial_shape = interaction_map.shape[1:] if channel_idx is not None else interaction_map.shape
        ndim = len(spatial_shape)

        radius = tuple([sample_scalar(self.point_radius, d, spatial_shape) for d in range(ndim)])
        strel = build_point(radius, self.use_distance_transform, binarize)
        if intensity_scale != 1.0:
            # Out-of-place multiply: build_point returns a cached tensor that must not be mutated.
            strel = strel * intensity_scale

        bbox, map_slices, strel_slices = self._placement_slices(position, strel.shape, spatial_shape)
        if map_slices is None:
            print("Point is outside the interaction map! Ignoring")
            print(f"Position: {position}")
            print(f"Interaction map shape: {spatial_shape}")
            print(f"Point bbox would have been {bbox}")
            return interaction_map

        if channel_idx is not None:
            # Only read/write the affected subregion of one channel.
            target_slices = (channel_idx, *map_slices)
            current_sub = np.asarray(interaction_map[target_slices])
            strel_np = strel[strel_slices].numpy().astype(current_sub.dtype)
            np.maximum(current_sub, strel_np, out=current_sub)
            interaction_map[target_slices] = current_sub
            return interaction_map

        # Default torch path: write the element-wise maximum straight into the map view.
        torch.maximum(
            interaction_map[map_slices],
            strel[strel_slices].to(interaction_map.device),
            out=interaction_map[map_slices],
        )
        return interaction_map
|
|
@@ -0,0 +1,118 @@
|
|
|
1
|
+
import pandas as pd
|
|
2
|
+
import numpy as np
|
|
3
|
+
import math
|
|
4
|
+
import pickle as pkl
|
|
5
|
+
import os
|
|
6
|
+
from tqdm import tqdm
|
|
7
|
+
|
|
8
|
+
from reader import BloscReaderWriter
|
|
9
|
+
from typing import List, Tuple, Union
|
|
10
|
+
import multiprocessing as mp
|
|
11
|
+
from functools import partial
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def sample_foreground_locations(
    seg: np.ndarray,
    classes_or_regions: Union[List[int], List[Tuple[int, ...]]],
    seed: int = 1234,
    verbose: bool = False,
    num_samples: int = 100,
    max_foreground_samples: int = 10_000_000,
):
    """
    Randomly sample up to ``num_samples`` voxel coordinates for each class or region.

    Parameters:
        seg: Label map. Voxels with value 0 are background and never sampled.
        classes_or_regions: Entries to sample. An int selects that label. A tuple/list of ints
            is a region covering all listed labels; list entries are stored under a tuple key.
        seed: Seed for a local ``np.random.RandomState`` (the global RNG is not touched).
        verbose: Print subsampling info and the number of samples taken per entry.
        num_samples: Maximum number of coordinates kept per entry.
        max_foreground_samples: Upper bound on the number of foreground voxels considered.
            Larger foregrounds are subsampled with a regular stride to keep computation time reasonable.

    Returns:
        dict: key -> integer array of shape ``(k, seg.ndim)`` with the sampled coordinates, or
        ``[]`` when the entry has no (remaining) voxels.

    Note:
        Entries are processed in order, and voxels assigned to an entry are removed from the pool.
        A later entry overlapping an earlier one therefore only sees the remaining voxels.
    """
    rndst = np.random.RandomState(seed)
    class_locs = {}
    foreground_mask = seg != 0
    foreground_coords = np.argwhere(foreground_mask)
    seg = seg[foreground_mask]
    del foreground_mask
    unique_labels = pd.unique(seg.ravel())

    if len(foreground_coords) > max_foreground_samples:
        # ceil (not floor) so the stride actually brings the count down to the cap
        take_every = math.ceil(len(foreground_coords) / max_foreground_samples)
        if verbose:
            print(f"Subsampling foreground pixels 1:{take_every} for computational reasons")
        foreground_coords = foreground_coords[::take_every]
        seg = seg[::take_every]

    for c in classes_or_regions:
        k = tuple(c) if isinstance(c, list) else c
        labels = tuple(c) if isinstance(c, (tuple, list)) else (c,)

        # Skip early if none of the labels occur in seg at all
        if not any(label in unique_labels for label in labels):
            class_locs[k] = []
            continue

        mask = seg == labels[0]
        for label in labels[1:]:
            mask |= seg == label
        all_locs = foreground_coords[mask]
        if len(all_locs) == 0:
            class_locs[k] = []
            continue

        target_num_samples = min(num_samples, len(all_locs))
        selected = all_locs[rndst.choice(len(all_locs), target_num_samples, replace=False)]
        class_locs[k] = selected
        if verbose:
            print(c, target_num_samples)
        # Remove the voxels just assigned so later entries only see the remaining ones
        seg = seg[~mask]
        foreground_coords = foreground_coords[~mask]
    return class_locs
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def process_file(file, supervoxel_folder, bloscio):
    """
    Sample foreground locations for every supervoxel mask stored in one ``.b2nd`` file.

    The result is pickled into ``supervoxel_folder`` as ``<stem>.pkl``. It is a list with one
    ``sample_foreground_locations(submask, [1])`` dict per mask along the first axis.
    If that ``.pkl`` already exists, the file is skipped.

    Parameters:
        file (str): File name (not path) of the ``.b2nd`` array inside ``supervoxel_folder``.
        supervoxel_folder (str): Folder containing the supervoxel arrays; the output is written here too.
        bloscio: Reader exposing ``read(path) -> (array, metadata)``, e.g. ``BloscReaderWriter``.

    Raises:
        ValueError: If the loaded array is not 4D.
    """
    suffix = ".b2nd"
    # Replace only the trailing extension; str.replace would also rewrite ".b2nd" inside the stem.
    stem = file[: -len(suffix)] if file.endswith(suffix) else file
    out_file = os.path.join(supervoxel_folder, stem + ".pkl")

    if os.path.exists(out_file):
        return

    # Load the supervoxel file
    supervoxel_arr, _ = bloscio.read(os.path.join(supervoxel_folder, file))

    if supervoxel_arr.ndim != 4:
        raise ValueError("The supervoxel array should have 4 dimensions, failed for file: " + file)

    # Iterating a 4D array yields 3D masks, so no per-mask dimensionality check is needed.
    all_class_locs = [sample_foreground_locations(submask, [1]) for submask in supervoxel_arr]

    # Write to a temporary file and rename. An interrupted run then never leaves a truncated
    # .pkl behind that the exists-check above would treat as finished.
    tmp_file = out_file + ".tmp"
    with open(tmp_file, "wb") as f:
        pkl.dump(all_class_locs, f)
    os.replace(tmp_file, out_file)
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
def generate_fg_locations(supervoxel_folder, num_processes=4):
    """
    Generate the foreground locations for the supervoxels

    Every ``.b2nd`` file in ``supervoxel_folder`` is passed to ``process_file``, which writes a
    matching ``.pkl`` file into the same folder.

    Parameters:
        supervoxel_folder (str): The path to the folder containing the supervoxel files
        num_processes (int or None): ``1`` processes the files serially in this process. Any other
            value is passed to ``multiprocessing.Pool`` as the number of workers.
    """
    reader = BloscReaderWriter()
    todo = [name for name in os.listdir(supervoxel_folder) if name.endswith(".b2nd")]
    handle = partial(process_file, supervoxel_folder=supervoxel_folder, bloscio=reader)

    if num_processes == 1:
        for name in tqdm(todo):
            handle(name)
        return

    with mp.Pool(num_processes) as pool:
        # Drain the lazy imap iterator (wrapped in tqdm for progress) so every file gets processed.
        for _ in tqdm(pool.imap(handle, todo), total=len(todo)):
            pass
|
|
@@ -0,0 +1,175 @@
|
|
|
1
|
+
from copy import deepcopy
|
|
2
|
+
import math
|
|
3
|
+
from typing import Tuple, Union
|
|
4
|
+
import SimpleITK as sitk
|
|
5
|
+
import numpy as np
|
|
6
|
+
import blosc2
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class NfitiReaderWriter:
    """
    Reads and writes images through SimpleITK (e.g. NIfTI), exchanging voxel data as NumPy arrays.
    """

    def __init__(self):
        """Create the reader/writer. It keeps no state; every call works on explicit paths."""

    def read(self, image_path: str):
        """
        Load an image from disk.

        Parameters:
            image_path (str): Path of the image file.

        Returns:
            tuple: ``(array, sitk_image)``, i.e. the voxel data from ``sitk.GetArrayFromImage``
            and the ``SimpleITK.Image`` it was taken from. The latter can later be passed to
            ``write`` as the geometry reference.
        """
        loaded = sitk.ReadImage(image_path)
        voxels = sitk.GetArrayFromImage(loaded)
        return voxels, loaded

    def write(self, array: np.ndarray, sitk_image: sitk.Image, output_path: str):
        """
        Save ``array`` as an image whose geometry is copied from ``sitk_image``.

        Parameters:
            array (np.ndarray): Voxel data to store.
            sitk_image (SimpleITK.Image): Reference image providing direction, origin and spacing.
            output_path (str): Destination file path.
        """
        target = sitk.GetImageFromArray(array)
        target.SetDirection(sitk_image.GetDirection())
        target.SetOrigin(sitk_image.GetOrigin())
        target.SetSpacing(sitk_image.GetSpacing())
        sitk.WriteImage(target, output_path)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class BloscReaderWriter:
    """
    Reads and writes Blosc2 NDArray (``.b2nd``) files as NumPy arrays.
    """

    def __init__(self):
        """
        BloscReaderWriter class constructor. This class is responsible for reading and writing blosc files.

        Note: calls the module-level ``blosc2.set_nthreads(1)``, so Blosc2 is limited to one thread
        for the whole process, not just for this object.
        """
        blosc2.set_nthreads(1)

    def read(self, image_path: str):
        """
        Read the image file and fully decompress it into memory.

        Parameters:
            image_path (str): Path to the ``.b2nd`` file.

        Returns:
            tuple: ``(array, None)``, i.e. the decompressed ``np.ndarray`` plus a ``None`` placeholder
            where ``NfitiReaderWriter.read`` returns its image object.
        """
        dparams = {"nthreads": 1}
        im = blosc2.open(urlpath=image_path, mode="r", dparams=dparams, mmap_mode="r")
        return im[:], None

    def write(self, array: np.ndarray, properties, output_path: str):
        """
        Write the image file.

        Parameters:
            array (np.ndarray): The image array.
            properties: Unused; accepted for interface parity with ``NfitiReaderWriter.write``.
            output_path (str): The path to save the image file.
        """
        cparams = {
            "codec": blosc2.Codec.ZSTD,
            "clevel": 8,
        }
        # chunks/blocks are left to blosc2's defaults; comp_blosc2_params could supply tuned values.
        chunks, blocks = None, None
        blosc2.asarray(
            np.ascontiguousarray(array),
            urlpath=output_path,
            chunks=chunks,
            blocks=blocks,
            cparams=cparams,
            mmap_mode="w+",
        )

    def comp_blosc2_params(
        self,
        image_size: Tuple[int, int, int, int],
        patch_size: Union[Tuple[int, int], Tuple[int, int, int]],
        bytes_per_pixel: int = 4,  # 4 byte are float32
        l1_cache_size_per_core_in_bytes=32768,  # 1 Kibibyte (KiB) = 2^10 Byte; 32 KiB = 32768 Byte
        l3_cache_size_per_core_in_bytes=1441792,
        # 1 Mibibyte (MiB) = 2^20 Byte = 1.048.576 Byte; 1.375MiB = 1441792 Byte
        safety_factor: float = 0.8,  # we don't fill the caches to the brim. 0.8 means we target 80% of the caches
    ):
        """
        Computes a recommended block and chunk size for saving arrays with blosc v2.

        Bloscv2 NDIM doku: "Remember that having a second partition means that we have better flexibility to fit the
        different partitions at the different CPU cache levels; typically the first partition (aka chunks) should
        be made to fit in L3 cache, whereas the second partition (aka blocks) should rather fit in L2/L1 caches
        (depending on whether compression ratio or speed is desired)."
        (https://www.blosc.org/posts/blosc2-ndim-intro/)
        -> We are not 100% sure how to optimize for that. For now we try to fit the uncompressed block in L1. This
        might spill over into L2, which is fine in our books.

        Note: this is optimized for nnU-Net dataloading where each read operation is done by one core. We cannot use threading

        Cache default values computed based on old Intel 4110 CPU with 32K L1, 128K L2 and 1408K L3 cache per core.
        We cannot optimize further for more modern CPUs with more cache as the data will need be be read by the
        old ones as well.

        Args:
            image_size: Image size, must be 4D (c, x, y, z). For 2D images, make x=1
            patch_size: Patch size, spatial dimensions only. So (x, y) or (x, y, z)
            bytes_per_pixel: Number of bytes per element. Example: float32 -> 4 bytes
            l1_cache_size_per_core_in_bytes: The size of the L1 cache per core in Bytes.
            l3_cache_size_per_core_in_bytes: The size of the L3 cache exclusively accessible by each core. Usually the global size of the L3 cache divided by the number of cores.
            safety_factor: Fraction of each cache size that is targeted.

        Returns:
            (block_size, chunk_size): The recommended block and chunk size, one entry per axis of image_size.
        """
        num_channels = image_size[0]
        if len(patch_size) == 2:
            patch_size = [1, *patch_size]
        patch_size = np.array(patch_size)
        block_size = np.array((num_channels, *[2 ** (max(0, math.ceil(math.log2(i)))) for i in patch_size]))

        # shrink the block size until it fits in L1
        estimated_nbytes_block = np.prod(block_size) * bytes_per_pixel
        while estimated_nbytes_block > (l1_cache_size_per_core_in_bytes * safety_factor):
            # shrink the spatial axis whose block is largest relative to the patch, skipping axes already at 1
            axis_order = np.argsort(block_size[1:] / patch_size)[::-1]
            idx = 0
            picked_axis = axis_order[idx]
            while block_size[picked_axis + 1] == 1:
                idx += 1
                picked_axis = axis_order[idx]
            # now reduce that axis to the next lowest power of 2
            block_size[picked_axis + 1] = 2 ** (max(0, math.floor(math.log2(block_size[picked_axis + 1] - 1))))
            block_size[picked_axis + 1] = min(block_size[picked_axis + 1], image_size[picked_axis + 1])
            estimated_nbytes_block = np.prod(block_size) * bytes_per_pixel

        block_size = np.array([min(i, j) for i, j in zip(image_size, block_size)])

        # note: there is no use extending the chunk size to 3d when we have a 2d patch size! This would unnecessarily
        # load data into L3
        # now tile the blocks into chunks until we hit image_size or the l3 cache per core limit
        chunk_size = deepcopy(block_size)
        estimated_nbytes_chunk = np.prod(chunk_size) * bytes_per_pixel
        while estimated_nbytes_chunk < (l3_cache_size_per_core_in_bytes * safety_factor):
            # 2D patch: stop once the chunk spans the full in-plane extent
            if patch_size[0] == 1 and all([i == j for i, j in zip(chunk_size[2:], image_size[2:])]):
                break
            if all([i == j for i, j in zip(chunk_size, image_size)]):
                break
            # grow the axis that has been extended the least relative to block_size, skipping axes
            # that already span the image or whose patch size is 1
            axis_order = np.argsort(chunk_size[1:] / block_size[1:])
            idx = 0
            picked_axis = axis_order[idx]
            while chunk_size[picked_axis + 1] == image_size[picked_axis + 1] or patch_size[picked_axis] == 1:
                idx += 1
                picked_axis = axis_order[idx]
            chunk_size[picked_axis + 1] += block_size[picked_axis + 1]
            chunk_size[picked_axis + 1] = min(chunk_size[picked_axis + 1], image_size[picked_axis + 1])
            estimated_nbytes_chunk = np.prod(chunk_size) * bytes_per_pixel
            if np.mean([i / j for i, j in zip(chunk_size[1:], patch_size)]) > 1.5:
                # chunk size should not exceed patch size * 1.5 on average; undo the last step
                chunk_size[picked_axis + 1] -= block_size[picked_axis + 1]
                break
        # better safe than sorry
        chunk_size = [min(i, j) for i, j in zip(image_size, chunk_size)]

        return tuple(block_size), tuple(chunk_size)
|
|
@@ -0,0 +1,136 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import yaml
|
|
3
|
+
import argparse
|
|
4
|
+
from typing import List, Tuple
|
|
5
|
+
import numpy as np
|
|
6
|
+
from supervoxel import SuperVoxelGenerator
|
|
7
|
+
from metadata import generate_fg_locations
|
|
8
|
+
from tqdm import tqdm
|
|
9
|
+
import torch
|
|
10
|
+
import gc
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def list_and_clear_tensors_on_gpu():
    """
    Report every CUDA tensor tracked by the garbage collector, then collect and empty the CUDA cache.

    Debugging aid: it prints each CUDA tensor plus the count and total size, runs ``gc.collect()``
    and calls ``torch.cuda.empty_cache()``. It only frees memory of tensors that are no longer
    referenced elsewhere; tensors still held by other code stay allocated.
    """
    num_found = 0
    tensor_sizes = []  # nbytes of each CUDA tensor found

    # Keep no references to the tensors themselves. A local list of tensors would keep every one
    # of them alive through the gc.collect()/empty_cache() calls below.
    for obj in gc.get_objects():
        try:
            if torch.is_tensor(obj) and obj.is_cuda:
                print(f"Tensor: {type(obj)}, size: {obj.size()}, device: {obj.device}")
                num_found += 1
                tensor_sizes.append(obj.nbytes)
        except Exception:
            # Best effort: some tracked objects raise on inspection; skip them.
            continue
    obj = None  # the loop variable would otherwise still reference the last inspected object

    print(f"Found {num_found} tensors on GPU.")
    bytes_sum = sum(tensor_sizes)
    print(f"Total size: {bytes_sum / 1024 ** 3} GB")

    # Trigger garbage collection
    gc.collect()

    # Empty CUDA cache
    torch.cuda.empty_cache()

    print("Cleared GPU tensors and memory.")
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def run(input_folder: str, output_folder: str, config: str):
    """
    Run the SuperVoxel generation using SAM

    :param input_folder: Path to folder containing the files to process
    :param output_folder: Path to output folder. If None, a ``supervoxel`` folder next to ``input_folder`` is used.
    :param config: Path to the YAML configuration file. Keys read here: ``file_format`` (required file
        suffix; a list of suffixes is also accepted) and ``excluded_strings`` (a string or list of strings;
        files whose name contains any of them are skipped; may be omitted). The parsed dict is also
        passed to ``SuperVoxelGenerator``.
    """
    if output_folder is None:
        output_folder = os.path.join(input_folder, os.pardir, "supervoxel")
    os.makedirs(output_folder, exist_ok=True)

    # Load configuration file
    with open(config, "r") as file:
        cfg = yaml.safe_load(file)

    gen = SuperVoxelGenerator(input_folder, output_folder, cfg)

    file_format = cfg["file_format"]
    if isinstance(file_format, list):
        file_format = tuple(file_format)  # str.endswith accepts a str or a tuple, not a list
    excluded = cfg.get("excluded_strings")
    if excluded is None:
        excluded = []
    elif isinstance(excluded, str):
        excluded = [excluded]  # single string: plain substring test, as before

    list_of_files = [
        f for f in os.listdir(input_folder) if f.endswith(file_format) and not any(s in f for s in excluded)
    ]

    # Random order, presumably so several concurrent runs on the same folder mostly work on
    # different files (see the exists-checks before and after processing).
    np.random.shuffle(list_of_files)
    for image_name in tqdm(list_of_files):
        out_path = os.path.join(output_folder, image_name)
        if os.path.exists(out_path):
            continue

        image_data, metadata = gen.reader_writer.read(os.path.join(input_folder, image_name))

        if len(image_data.shape) == 4:
            image_data = image_data[0]  # only the first channel is used

        # Release the previous iteration's output before running the model again.
        out_img = None
        try:
            out_img = gen.sam_supervoxel(image_data)
        except torch.OutOfMemoryError:
            print("OOM error for image:", image_name)
            # Rebuild the generator from scratch so the GPU memory it held can be reclaimed.
            del gen
            gc.collect()
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
            gen = SuperVoxelGenerator(input_folder, output_folder, cfg)
            continue

        # Another process may have written this output while we were computing it.
        if os.path.exists(out_path):
            continue
        gen.reader_writer.write(out_img, metadata, out_path)
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def run_entrypoint():
    """Command-line entry point for SuperVoxel generation: parse the arguments and call ``run``."""
    parser = argparse.ArgumentParser(description="Run SuperVoxel generation using SAM")
    parser.add_argument(
        "-i",
        "-input_folder",
        type=str,
        required=True,  # run() cannot work without an input folder
        help="Path to folder containing the images. They can be in raw NIfTI format "
        "or using the new nnUNet supported bloscv2, depending on the file format provided in the config file.",
    )
    parser.add_argument(
        "-o",
        "-output_folder",
        type=str,
        help="Path to output folder. If not provided, the output will be saved in the "
        "same parent folder as the dataset and named 'supervoxel'.",
    )
    parser.add_argument(
        "-c",
        "-config",
        type=str,
        help="Path to configuration file containing the parameters for the SuperVoxel generation.",
        default=os.path.join(os.path.dirname(os.path.abspath(__file__)), "../configs/nnUNet_preprocessed.yaml"),
    )
    args = parser.parse_args()

    # With only single-dash option strings, argparse derives dest from the first one: args.i / args.o / args.c.
    run(args.i, args.o, args.c)
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
def run_save_fg_locations_entrypoint():
    """Command-line entry point: write the foreground-location ``.pkl`` files for a supervoxel folder."""
    parser = argparse.ArgumentParser(description="Run generation of pkl files for nnUNet")
    parser.add_argument(
        "-supervoxel_folder",
        type=str,
        required=True,  # replaces a default that pointed into a developer's home directory
        help="Path to folder containing the supervoxel masks",
    )
    # Without -np, None is forwarded and multiprocessing.Pool falls back to its default worker count.
    parser.add_argument("-np", "-num_processes", type=int, help="Number of processes to use")
    args = parser.parse_args()

    generate_fg_locations(args.supervoxel_folder, args.np)
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
|
|
4
|
+
# This source code is licensed under the license found in the
|
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
|
6
|
+
|
|
7
|
+
from hydra import initialize_config_module
|
|
8
|
+
from hydra.core.global_hydra import GlobalHydra
|
|
9
|
+
|
|
10
|
+
# Runs at import time: initialize Hydra with the "sam2" package as its config module
# (version_base "1.2"). This is only done when no global Hydra instance is initialized
# yet, so an instance set up earlier (e.g. by the importing application) is left untouched.
if not GlobalHydra.instance().is_initialized():
    initialize_config_module("sam2", version_base="1.2")
|