nnInteractive 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nnInteractive/__init__.py +3 -0
- nnInteractive/inference/__init__.py +0 -0
- nnInteractive/inference/cvpr2025_challenge_baseline/__init__.py +0 -0
- nnInteractive/inference/cvpr2025_challenge_baseline/predict.py +173 -0
- nnInteractive/inference/inference_session.py +1400 -0
- nnInteractive/interaction/__init__.py +0 -0
- nnInteractive/interaction/point.py +166 -0
- nnInteractive/supervoxel/setup.py +4 -0
- nnInteractive/supervoxel/src/metadata.py +118 -0
- nnInteractive/supervoxel/src/reader.py +175 -0
- nnInteractive/supervoxel/src/run.py +136 -0
- nnInteractive/supervoxel/src/sam2/__init__.py +2 -0
- nnInteractive/supervoxel/src/sam2/sam2/__init__.py +11 -0
- nnInteractive/supervoxel/src/sam2/sam2/automatic_mask_generator.py +434 -0
- nnInteractive/supervoxel/src/sam2/sam2/benchmark.py +86 -0
- nnInteractive/supervoxel/src/sam2/sam2/build_sam.py +172 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/__init__.py +5 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/__init__.py +5 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/hieradet.py +305 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/image_encoder.py +132 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/utils.py +89 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/memory_attention.py +167 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/memory_encoder.py +179 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/position_encoding.py +217 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/__init__.py +5 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/mask_decoder.py +274 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/prompt_encoder.py +194 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/transformer.py +293 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/sam2_base.py +879 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/sam2_utils.py +315 -0
- nnInteractive/supervoxel/src/sam2/sam2/sam2_image_predictor.py +433 -0
- nnInteractive/supervoxel/src/sam2/sam2/sam2_video_predictor.py +1171 -0
- nnInteractive/supervoxel/src/sam2/sam2/sam2_video_predictor_legacy.py +1125 -0
- nnInteractive/supervoxel/src/sam2/sam2/utils/__init__.py +5 -0
- nnInteractive/supervoxel/src/sam2/sam2/utils/amg.py +332 -0
- nnInteractive/supervoxel/src/sam2/sam2/utils/misc.py +488 -0
- nnInteractive/supervoxel/src/sam2/sam2/utils/transforms.py +108 -0
- nnInteractive/supervoxel/src/sam2/setup.py +174 -0
- nnInteractive/supervoxel/src/sam2/training/__init__.py +5 -0
- nnInteractive/supervoxel/src/sam2/training/dataset/__init__.py +5 -0
- nnInteractive/supervoxel/src/sam2/training/dataset/sam2_datasets.py +176 -0
- nnInteractive/supervoxel/src/sam2/training/dataset/transforms.py +481 -0
- nnInteractive/supervoxel/src/sam2/training/dataset/utils.py +102 -0
- nnInteractive/supervoxel/src/sam2/training/dataset/vos_dataset.py +154 -0
- nnInteractive/supervoxel/src/sam2/training/dataset/vos_raw_dataset.py +290 -0
- nnInteractive/supervoxel/src/sam2/training/dataset/vos_sampler.py +103 -0
- nnInteractive/supervoxel/src/sam2/training/dataset/vos_segment_loader.py +289 -0
- nnInteractive/supervoxel/src/sam2/training/loss_fns.py +290 -0
- nnInteractive/supervoxel/src/sam2/training/model/__init__.py +5 -0
- nnInteractive/supervoxel/src/sam2/training/model/sam2.py +515 -0
- nnInteractive/supervoxel/src/sam2/training/optimizer.py +462 -0
- nnInteractive/supervoxel/src/sam2/training/scripts/sav_frame_extraction_submitit.py +157 -0
- nnInteractive/supervoxel/src/sam2/training/train.py +232 -0
- nnInteractive/supervoxel/src/sam2/training/trainer.py +1051 -0
- nnInteractive/supervoxel/src/sam2/training/utils/__init__.py +5 -0
- nnInteractive/supervoxel/src/sam2/training/utils/checkpoint_utils.py +328 -0
- nnInteractive/supervoxel/src/sam2/training/utils/data_utils.py +166 -0
- nnInteractive/supervoxel/src/sam2/training/utils/distributed.py +560 -0
- nnInteractive/supervoxel/src/sam2/training/utils/logger.py +236 -0
- nnInteractive/supervoxel/src/sam2/training/utils/train_utils.py +275 -0
- nnInteractive/supervoxel/src/supervoxel.py +198 -0
- nnInteractive/trainer/__init__.py +0 -0
- nnInteractive/trainer/nnInteractiveTrainer.py +24 -0
- nnInteractive/utils/__init__.py +0 -0
- nnInteractive/utils/bboxes.py +217 -0
- nnInteractive/utils/checkpoint_cleansing.py +9 -0
- nnInteractive/utils/crop.py +268 -0
- nnInteractive/utils/erosion_dilation.py +48 -0
- nnInteractive/utils/inference_helpers.py +45 -0
- nnInteractive/utils/os_shennanigans.py +16 -0
- nnInteractive/utils/rounding.py +13 -0
- nninteractive-2.0.0.dist-info/METADATA +511 -0
- nninteractive-2.0.0.dist-info/RECORD +76 -0
- nninteractive-2.0.0.dist-info/WHEEL +5 -0
- nninteractive-2.0.0.dist-info/licenses/LICENSE +201 -0
- nninteractive-2.0.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,198 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import torch
|
|
3
|
+
from typing import List
|
|
4
|
+
import numpy as np
|
|
5
|
+
from reader import NfitiReaderWriter, BloscReaderWriter
|
|
6
|
+
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
|
|
7
|
+
from sam2.sam2.build_sam import build_sam2_video_predictor
|
|
8
|
+
from tqdm import tqdm
|
|
9
|
+
import concurrent.futures
|
|
10
|
+
import cc3d
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class SuperVoxelGenerator:
    """Generate supervoxel masks for 3D images: SAM seeds masks on one slice, SAM2 propagates them through the volume."""

    def __init__(self, input_dir, output_dir: str, config: dict):
        """
        SuperVoxelGenerator class constructor. Collects the images to process and loads the SAM and SAM2 models.

        Parameters:
        input_dir (str): The folder containing the images to process.
        output_dir (str): The directory where the segmentation masks will be saved, as output_dir / name_of_image.
        config (dict): Configuration of the supervoxel generation. Keys read by this class:
            - "file_format": extension of the files to process, ".nii.gz" or ".b2nd".
            - "excluded_strings": a substring, or a list of substrings; files whose name contains
              any of them are skipped.
            - "sam1_checkpoint": path to the SAM ViT-H checkpoint.
            - "sam2_checkpoint": path to the SAM2.1 Hiera-L checkpoint.
            - "masks_per_image": maximum number of SAM masks propagated per image.

        Raises:
        KeyError: if "file_format" is not one of the supported extensions.
        """
        self.input_dir = input_dir
        self.output_dir = output_dir
        self.file_format = config["file_format"]

        # Only instantiate the reader/writer that is actually needed.
        reader_writer_classes = {".nii.gz": NfitiReaderWriter, ".b2nd": BloscReaderWriter}
        self.reader_writer = reader_writer_classes[self.file_format]()
        print(f"Using {self.reader_writer} to read and write files")

        # Get the list of files to process. A single string keeps the original substring semantics.
        excluded = config["excluded_strings"]
        if isinstance(excluded, str):
            excluded = [excluded]
        self.list_of_files = [
            f
            for f in os.listdir(input_dir)
            if f.endswith(self.file_format) and not any(s in f for s in excluded)
        ]
        self.config = config

        self.sam = sam_model_registry["vit_h"](checkpoint=self.config["sam1_checkpoint"]).to("cuda")
        self.mask_generator = SamAutomaticMaskGenerator(
            model=self.sam,
            points_per_side=48,
            points_per_batch=256,
            pred_iou_thresh=0.85,
            stability_score_thresh=0.92,
            box_nms_thresh=0.6,
            crop_nms_thresh=0.6,
            crop_n_layers=1,
            crop_n_points_downscale_factor=2,
            min_mask_region_area=192,
        )
        model_cfg = "sam2/configs/sam2.1/sam2.1_hiera_l.yaml"
        self.sam2_predictor = build_sam2_video_predictor(model_cfg, ckpt_path=self.config["sam2_checkpoint"])

    @staticmethod
    def _normalize_intensity(image_data: np.ndarray) -> np.ndarray:
        """
        Rescale intensities so the 5th percentile maps to 0 and the 95th percentile to 1.

        Values outside that range are not clipped. For a constant image (95th == 5th percentile)
        the image is only shifted, instead of dividing by zero.
        """
        low, high = np.percentile(image_data, [5, 95])
        scale = high - low
        if scale <= 0:
            scale = 1.0
        return (image_data - low) / scale

    def sam2_propagation(self, sam2_predictor_state, image_data: np.ndarray, masks: List[np.ndarray], slice_idx: int):
        """
        Propagate 2D seed masks through the whole volume with SAM2, in both directions.

        The volume is treated as a video along axis 0. Every mask is added as its own object on
        ``slice_idx`` and propagated towards the last slice. The state is then reset, the frame
        order flipped, and the masks propagated again to fill the slices down to slice 0.

        Parameters:
        sam2_predictor_state: SAM2 inference state created by ``init_state`` for this volume.
            Its "images" entry is left flipped along the slice axis on return.
        image_data (np.ndarray): The image volume, shape (z, y, x). Only its shape is used here.
        masks (List[np.ndarray]): 2D seed masks for slice ``slice_idx``.
        slice_idx (int): The index of the slice which contains the seed masks.

        Returns:
        np.ndarray: uint8 array of shape (len(masks), z, y, x) holding the binary propagated masks.
        """
        # Binary results: uint8 instead of the default float64 saves 8x memory.
        propagated_masks = np.zeros((len(masks), *image_data.shape), dtype=np.uint8)
        if len(masks) == 0:
            # Nothing to track; SAM2 refuses to propagate without any prompt.
            return propagated_masks

        with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
            # Forward pass: seed slice -> last slice.
            for obj_id, m in enumerate(masks):
                self.sam2_predictor.add_new_mask(sam2_predictor_state, slice_idx, obj_id, m.astype(np.int8))
            for frame_idx, object_ids, out_masks in self.sam2_predictor.propagate_in_video(sam2_predictor_state):
                for obj_n, out_mask in zip(object_ids, out_masks):
                    propagated_masks[obj_n, frame_idx] = (out_mask > 0).detach().cpu().numpy()[0].astype(np.uint8)

            # Backward pass: reset the prompts and flip the frame order, so that "forward"
            # propagation now runs from the seed slice towards slice 0.
            self.sam2_predictor.reset_state(sam2_predictor_state)
            sam2_predictor_state["images"] = torch.flip(sam2_predictor_state["images"], dims=(0,))
            # NOTE(review): assuming the vendored SAM2 behaves like upstream, the predictor caches the
            # backbone features of the most recently processed frame in state["cached_features"],
            # keyed by frame index, and reset_state() does not clear that cache. After the flip the
            # index refers to a different slice, so drop the cache (costs at most one re-encode).
            # Verify against the vendored sam2_video_predictor.
            if "cached_features" in sam2_predictor_state:
                sam2_predictor_state["cached_features"] = {}
            max_frame = sam2_predictor_state["images"].shape[0] - 1

            for obj_id, m in enumerate(masks):
                self.sam2_predictor.add_new_mask(
                    sam2_predictor_state, max_frame - slice_idx, obj_id, m.astype(np.int8)
                )
            for frame_idx, object_ids, out_masks in self.sam2_predictor.propagate_in_video(sam2_predictor_state):
                for obj_n, out_mask in zip(object_ids, out_masks):
                    # Map the flipped frame index back to the original slice index.
                    propagated_masks[obj_n, max_frame - frame_idx] = (
                        (out_mask > 0).detach().cpu().numpy()[0].astype(np.uint8)
                    )

        return propagated_masks

    def remove_other_components(self, masks, frame_idx):
        """
        Keep, per mask, a single 26-connected component that touches the seed slice.

        For every mask the 3D connected components are computed. Among the components present on
        slice ``frame_idx``, one is picked at random (NumPy global RNG) and kept. Masks without any
        component on that slice are dropped, so the output can hold fewer masks than the input.

        Parameters:
        masks (np.ndarray): The masks to process. Shape: (n, z, y, x)
        frame_idx (int): The index of the seed slice

        Returns:
        np.ndarray: uint8 array of shape (k, z, y, x) with k <= n (k == 0 if nothing survives).
        """
        filtered_masks = []
        for m in masks:
            m_cc = cc3d.connected_components(m, connectivity=26, binary_image=True)

            # Component labels present on the seed slice (label 0 is background), sorted ascending.
            seed_labels = np.unique(m_cc[frame_idx])
            seed_labels = seed_labels[seed_labels != 0]
            if seed_labels.size > 0:
                final_component = np.random.choice(seed_labels)
                filtered_masks.append((m_cc == final_component).astype(np.uint8))

        if not filtered_masks:
            # np.stack([]) would raise; return an empty stack with the right spatial shape instead.
            return np.zeros((0, *np.shape(masks)[1:]), dtype=np.uint8)
        return np.stack(filtered_masks, axis=0)

    def sam_supervoxel(self, image_data: np.ndarray):
        """
        Generate the supervoxel segmentation masks of one volume using SAM and SAM2.

        A seed slice is drawn at random, favouring the centre of the volume. SAM proposes 2D masks on
        it, up to config["masks_per_image"] of them are propagated through the volume with SAM2, and
        each result is reduced to one connected component touching the seed slice.

        Parameters:
        image_data (np.ndarray): The image data to process. Shape: (z, y, x)

        Returns:
        np.ndarray: uint8 binary masks, shape (k, z, y, x).
        """
        image_data = self._normalize_intensity(image_data)
        z_len = image_data.shape[0]

        # Sample the seed slice with a Gaussian-shaped preference for the centre of the volume.
        slice_weights = np.exp(-np.linspace(-1, 1, z_len) ** 2 * 2)
        slice_idx = np.random.choice(z_len, p=slice_weights / slice_weights.sum())

        def generate_sam_masks():
            # NOTE(review): SAM is given the float-normalized slice replicated to 3 channels, not a
            # uint8 image; confirm this is the intended input range.
            return self.mask_generator.generate(image_data[slice_idx, ..., None].repeat(3, axis=2))

        def init_sam2_predictor():
            return self.sam2_predictor.init_state(image_data)

        # Run the SAM mask generation and the SAM2 state initialization concurrently.
        with concurrent.futures.ThreadPoolExecutor() as executor:
            future_sam_masks = executor.submit(generate_sam_masks)
            future_sam2_state = executor.submit(init_sam2_predictor)
            masks = future_sam_masks.result()
            sam2_predictor_state = future_sam2_state.result()

        # Pick at most n masks randomly (index sampling: same RNG draws as sampling the list itself).
        n_keep = self.config["masks_per_image"]
        if len(masks) > n_keep:
            keep_idx = np.random.choice(len(masks), size=n_keep, replace=False)
            masks = [masks[i] for i in keep_idx]
        selected_masks = [m["segmentation"] for m in masks]

        propagated_masks = self.sam2_propagation(sam2_predictor_state, image_data, selected_masks, slice_idx)
        self.sam2_predictor.reset_state(sam2_predictor_state)

        propagated_masks = self.remove_other_components(propagated_masks, slice_idx)

        # Binarize the masks
        return (propagated_masks > 0).astype(np.uint8)

    def process_images(self):
        """
        Generate and write the supervoxel masks for every file in list_of_files.

        Files whose output already exists in output_dir are skipped.
        """
        print(f"Processing {len(self.list_of_files)} images")

        for image_name in tqdm(self.list_of_files):
            out_path = os.path.join(self.output_dir, image_name)
            if os.path.exists(out_path):
                continue

            image_data, metadata = self.reader_writer.read(os.path.join(self.input_dir, image_name))

            # 4D input: keep only the first entry along the leading axis.
            if len(image_data.shape) == 4:
                image_data = image_data[0]
            out_img = self.sam_supervoxel(image_data)

            # Re-check before writing; presumably guards against another worker having written the
            # same output while this one was computing.
            if os.path.exists(out_path):
                continue
            self.reader_writer.write(out_img, metadata, out_path)
|
|
File without changes
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer
|
|
2
|
+
from nnunetv2.utilities.plans_handling.plans_handler import PlansManager, ConfigurationManager
|
|
3
|
+
from torch import nn
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class nnInteractiveTrainer_stub:
    """
    Lightweight placeholder that only exposes ``build_network_architecture``.

    Construction accepts and ignores any arguments. Network construction is delegated to
    ``nnUNetTrainer.build_network_architecture`` with the output channel count forced to 2.
    """

    def __init__(self, *args, **kwargs):
        """Accept and discard all constructor arguments."""

    @staticmethod
    def build_network_architecture(
        plans_manager: PlansManager,
        configuration_manager: ConfigurationManager,
        num_input_channels: int,
        num_output_channels: int,
        enable_deep_supervision: bool = True,
    ) -> nn.Module:
        """
        Build the nnU-Net network for the given plans/configuration with exactly two outputs.

        ``num_output_channels`` is accepted for interface compatibility but not used.
        """
        # nnunet handles one class segmentation still as CE, so two outputs (background / foreground)
        # are needed regardless of what the caller asks for.
        forced_output_channels = 2
        return nnUNetTrainer.build_network_architecture(
            plans_manager,
            configuration_manager,
            num_input_channels,
            forced_output_channels,
            enable_deep_supervision,
        )
|
|
File without changes
|
|
@@ -0,0 +1,217 @@
|
|
|
1
|
+
from time import time
|
|
2
|
+
from time import time
|
|
3
|
+
from typing import List, Union, Tuple
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
import torch
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def generate_bounding_boxes(
    mask,
    bbox_size=(192, 192, 192),
    stride: Union[List[int], Tuple[int, int, int], str] = (16, 16, 16),
    margin=(10, 10, 10),
    max_depth=5,
    current_depth=0,
):
    """
    Greedily choose overlapping fixed-size bounding boxes that together cover a 3D binary mask.

    Candidate box centers are the foreground voxels lying on a regular grid (step ``stride``)
    that starts at the min corner of the object's tight bounding box. In every round, the
    candidate is chosen whose window covers the most still-uncovered voxels. The window is the
    box shrunk by ``margin`` on every side and clipped to the volume. Voxels that remain once no
    candidate gains anything are handed to ``random_sampling_fallback``.

    Parameters:
    - mask: 3D PyTorch tensor; non-zero voxels are foreground. The tensor is not modified.
    - bbox_size: Tuple or list of three integers, the box size per dimension (odd sizes supported).
    - stride: Tuple or list of three integers, the grid step for candidate centers per dimension,
      or "auto" to use about a quarter of the object's extent per dimension.
    - margin: Tuple or list of three integers, the voxels at each box border that do not count as
      covered. Presumably each margin should stay below half the box size.
    - max_depth: Maximum recursion depth (stride-halving retries when the grid hits no foreground
      voxel); beyond it the random fallback is used.
    - current_depth: Current recursion depth (used internally)

    Returns:
    - List of boxes [[x_min, x_max], [y_min, y_max], [z_min, z_max]] as half-open intervals of
      Python ints. Every box spans exactly bbox_size and may extend beyond the volume (negative
      starts, ends past the shape), so callers must pad or clip.

    Raises:
    - ValueError: if stride is a string other than "auto".
    """
    if not torch.any(mask):
        return []

    # Prevent infinite recursion. random_sampling_fallback clears covered voxels in the tensor it
    # receives, so hand it a copy to leave the caller's mask intact.
    if current_depth > max_depth:
        return random_sampling_fallback(mask.clone(), bbox_size, margin, 25)

    bbox_size = list(bbox_size)
    margin = list(margin)

    # A box centered at c spans [c - half_size, c + end_offset); end_offset absorbs odd sizes
    # (e.g. 193 -> 96 / 97).
    half_size = [bs // 2 for bs in bbox_size]
    end_offset = [bs - hs for bs, hs in zip(bbox_size, half_size)]

    # Tight (inclusive) bounding box of the object limits where centers are searched.
    object_voxels = torch.nonzero(mask)
    min_coords = object_voxels.min(dim=0)[0].tolist()
    max_coords = object_voxels.max(dim=0)[0].tolist()

    if isinstance(stride, str):
        if stride != "auto":
            raise ValueError(f"stride must be a sequence of ints or 'auto', got {stride!r}")
        stride = [max(1, round((hi - lo) / 4)) for lo, hi in zip(min_coords, max_coords)]
    stride = list(stride)

    # Candidate centers: foreground voxels on the strided grid. A strided view plus nonzero
    # replaces a per-voxel Python loop; nonzero yields x-major order, same as nested x/y/z loops.
    grid = mask[
        min_coords[0] : max_coords[0] + 1 : stride[0],
        min_coords[1] : max_coords[1] + 1 : stride[1],
        min_coords[2] : max_coords[2] + 1 : stride[2],
    ]
    step = torch.tensor(stride, device=mask.device)
    origin = torch.tensor(min_coords, device=mask.device)
    potential_centers = torch.nonzero(grid) * step + origin

    if len(potential_centers) == 0:
        # The grid missed every foreground voxel; retry with a finer grid.
        return generate_bounding_boxes(
            mask, bbox_size, [max(1, s // 2) for s in stride], margin, max_depth, current_depth + 1
        )

    # Greedy set cover
    uncovered = mask.clone().byte()
    bboxes = []

    while len(potential_centers) > 0 and uncovered.any():
        best_center = None
        best_covered = 0
        best_bounds = None

        # Find the center whose shrunken window covers the most uncovered voxels (first wins on ties).
        for center in potential_centers.tolist():
            c_x, c_y, c_z = center
            x_start = max(0, c_x - half_size[0] + margin[0])
            x_end = min(mask.shape[0], c_x + end_offset[0] - margin[0])
            y_start = max(0, c_y - half_size[1] + margin[1])
            y_end = min(mask.shape[1], c_y + end_offset[1] - margin[1])
            z_start = max(0, c_z - half_size[2] + margin[2])
            z_end = min(mask.shape[2], c_z + end_offset[2] - margin[2])

            num_covered = uncovered[x_start:x_end, y_start:y_end, z_start:z_end].sum().item()
            if num_covered > best_covered:
                best_covered = num_covered
                best_center = center
                best_bounds = (x_start, x_end, y_start, y_end, z_start, z_end)

        # If no new voxels are covered, stop
        if best_covered == 0:
            break

        # Record the full-size box of the best center.
        c_x, c_y, c_z = best_center
        bboxes.append(
            [
                [c_x - half_size[0], c_x + end_offset[0]],
                [c_y - half_size[1], c_y + end_offset[1]],
                [c_z - half_size[2], c_z + end_offset[2]],
            ]
        )

        # Mark voxels as covered, respecting the margin
        x_s, x_e, y_s, y_e, z_s, z_e = best_bounds
        uncovered[x_s:x_e, y_s:y_e, z_s:z_e] = 0

        # Drop candidates that are now covered themselves (including the chosen one whenever
        # its window contains it).
        potential_centers = potential_centers[uncovered[tuple(potential_centers.T)] > 0]

    # Voxels no grid center could reach are covered by random sampling. (A recursive
    # stride-halving alternative was disabled in favour of this fallback.)
    if uncovered.any():
        bboxes.extend(random_sampling_fallback(uncovered, bbox_size, margin, 10))

    return bboxes
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
def random_sampling_fallback(mask: torch.Tensor, bbox_size=(192, 192, 192), margin=(10, 10, 10), n_samples: int = 25):
    """
    Cover the foreground of a 3D mask with boxes centered on randomly sampled foreground voxels.

    Each round draws ``n_samples`` foreground voxels independently (NumPy global RNG). It keeps the
    one whose window covers the most remaining foreground; the window is the box shrunk by
    ``margin`` on every side and clipped to the volume. The round records that center's full-size
    box and clears its window. Rounds repeat until no foreground is left.

    Parameters:
    - mask: 3D tensor; non-zero voxels are foreground. It is not modified (a copy is consumed).
    - bbox_size: box size per dimension (odd sizes supported).
    - margin: voxels at each box border that do not count as covered.
    - n_samples: number of random centers tried per box.

    Returns:
    - List of boxes [[x_min, x_max], [y_min, y_max], [z_min, z_max]] as half-open intervals of
      Python ints; boxes may extend beyond the volume.

    Raises:
    - ValueError: if no sampled window covers any remaining voxel (e.g. a margin of at least half
      the box size, or n_samples < 1), which would otherwise loop without progress.
    """
    # Work on a copy so covered voxels can be cleared without touching the caller's tensor.
    remaining = mask.clone()

    half_size = [bs // 2 for bs in bbox_size]
    # Adjust end offsets to ensure full bbox_size (handles odd sizes), e.g. 193 - 96 = 97
    end_offset = [bs - hs for bs, hs in zip(bbox_size, half_size)]

    bboxes = []

    while remaining.any():
        indices = torch.nonzero(remaining)  # n x 3

        best_center = None
        best_covered = 0
        best_bounds = None

        # Find the sampled center that covers the most uncovered voxels (first wins on ties).
        for _ in range(n_samples):
            idx = np.random.choice(len(indices))
            c_x, c_y, c_z = indices[idx].tolist()
            x_start = max(0, c_x - half_size[0] + margin[0])
            x_end = min(remaining.shape[0], c_x + end_offset[0] - margin[0])
            y_start = max(0, c_y - half_size[1] + margin[1])
            y_end = min(remaining.shape[1], c_y + end_offset[1] - margin[1])
            z_start = max(0, c_z - half_size[2] + margin[2])
            z_end = min(remaining.shape[2], c_z + end_offset[2] - margin[2])

            num_covered = remaining[x_start:x_end, y_start:y_end, z_start:z_end].sum().item()
            if num_covered > best_covered:
                best_covered = num_covered
                best_center = (c_x, c_y, c_z)
                best_bounds = (x_start, x_end, y_start, y_end, z_start, z_end)

        if best_center is None:
            raise ValueError(
                "random_sampling_fallback: no sampled box covers any remaining voxel "
                "(check that each margin is smaller than half the bbox size and n_samples >= 1)"
            )

        # Add the best bounding box
        c_x, c_y, c_z = best_center
        bboxes.append(
            [
                [c_x - half_size[0], c_x + end_offset[0]],
                [c_y - half_size[1], c_y + end_offset[1]],
                [c_z - half_size[2], c_z + end_offset[2]],
            ]
        )

        # Mark voxels as covered, respecting the margin
        x_s, x_e, y_s, y_e, z_s, z_e = best_bounds
        remaining[x_s:x_e, y_s:y_e, z_s:z_e] = 0

    return bboxes
|
|
198
|
+
|
|
199
|
+
|
|
200
|
+
if __name__ == "__main__":
    # Small smoke test / timing run for generate_bounding_boxes on a synthetic cube.
    times = []
    torch.set_num_threads(8)
    # Use the GPU when present; the original hard-coded CUDA device 0.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    for _ in range(1):
        st = time()
        mask = torch.zeros((256, 256, 256), dtype=torch.uint8, device=device)
        mask[50:150, 50:150, 50:150] = 1  # A cubic object

        # Generate bounding boxes with an odd size to test. stride="auto" is only understood by
        # generate_bounding_boxes (random_sampling_fallback has no stride parameter).
        bboxes = generate_bounding_boxes(
            mask, bbox_size=(193, 193, 193), stride="auto", margin=(10, 10, 10)  # Odd size
        )

        # Print results
        print(f"Number of bounding boxes: {len(bboxes)}")
        end = time()
        times.append(end - st)
    print(times)
|