nnInteractive 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76) hide show
  1. nnInteractive/__init__.py +3 -0
  2. nnInteractive/inference/__init__.py +0 -0
  3. nnInteractive/inference/cvpr2025_challenge_baseline/__init__.py +0 -0
  4. nnInteractive/inference/cvpr2025_challenge_baseline/predict.py +173 -0
  5. nnInteractive/inference/inference_session.py +1400 -0
  6. nnInteractive/interaction/__init__.py +0 -0
  7. nnInteractive/interaction/point.py +166 -0
  8. nnInteractive/supervoxel/setup.py +4 -0
  9. nnInteractive/supervoxel/src/metadata.py +118 -0
  10. nnInteractive/supervoxel/src/reader.py +175 -0
  11. nnInteractive/supervoxel/src/run.py +136 -0
  12. nnInteractive/supervoxel/src/sam2/__init__.py +2 -0
  13. nnInteractive/supervoxel/src/sam2/sam2/__init__.py +11 -0
  14. nnInteractive/supervoxel/src/sam2/sam2/automatic_mask_generator.py +434 -0
  15. nnInteractive/supervoxel/src/sam2/sam2/benchmark.py +86 -0
  16. nnInteractive/supervoxel/src/sam2/sam2/build_sam.py +172 -0
  17. nnInteractive/supervoxel/src/sam2/sam2/modeling/__init__.py +5 -0
  18. nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/__init__.py +5 -0
  19. nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/hieradet.py +305 -0
  20. nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/image_encoder.py +132 -0
  21. nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/utils.py +89 -0
  22. nnInteractive/supervoxel/src/sam2/sam2/modeling/memory_attention.py +167 -0
  23. nnInteractive/supervoxel/src/sam2/sam2/modeling/memory_encoder.py +179 -0
  24. nnInteractive/supervoxel/src/sam2/sam2/modeling/position_encoding.py +217 -0
  25. nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/__init__.py +5 -0
  26. nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/mask_decoder.py +274 -0
  27. nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/prompt_encoder.py +194 -0
  28. nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/transformer.py +293 -0
  29. nnInteractive/supervoxel/src/sam2/sam2/modeling/sam2_base.py +879 -0
  30. nnInteractive/supervoxel/src/sam2/sam2/modeling/sam2_utils.py +315 -0
  31. nnInteractive/supervoxel/src/sam2/sam2/sam2_image_predictor.py +433 -0
  32. nnInteractive/supervoxel/src/sam2/sam2/sam2_video_predictor.py +1171 -0
  33. nnInteractive/supervoxel/src/sam2/sam2/sam2_video_predictor_legacy.py +1125 -0
  34. nnInteractive/supervoxel/src/sam2/sam2/utils/__init__.py +5 -0
  35. nnInteractive/supervoxel/src/sam2/sam2/utils/amg.py +332 -0
  36. nnInteractive/supervoxel/src/sam2/sam2/utils/misc.py +488 -0
  37. nnInteractive/supervoxel/src/sam2/sam2/utils/transforms.py +108 -0
  38. nnInteractive/supervoxel/src/sam2/setup.py +174 -0
  39. nnInteractive/supervoxel/src/sam2/training/__init__.py +5 -0
  40. nnInteractive/supervoxel/src/sam2/training/dataset/__init__.py +5 -0
  41. nnInteractive/supervoxel/src/sam2/training/dataset/sam2_datasets.py +176 -0
  42. nnInteractive/supervoxel/src/sam2/training/dataset/transforms.py +481 -0
  43. nnInteractive/supervoxel/src/sam2/training/dataset/utils.py +102 -0
  44. nnInteractive/supervoxel/src/sam2/training/dataset/vos_dataset.py +154 -0
  45. nnInteractive/supervoxel/src/sam2/training/dataset/vos_raw_dataset.py +290 -0
  46. nnInteractive/supervoxel/src/sam2/training/dataset/vos_sampler.py +103 -0
  47. nnInteractive/supervoxel/src/sam2/training/dataset/vos_segment_loader.py +289 -0
  48. nnInteractive/supervoxel/src/sam2/training/loss_fns.py +290 -0
  49. nnInteractive/supervoxel/src/sam2/training/model/__init__.py +5 -0
  50. nnInteractive/supervoxel/src/sam2/training/model/sam2.py +515 -0
  51. nnInteractive/supervoxel/src/sam2/training/optimizer.py +462 -0
  52. nnInteractive/supervoxel/src/sam2/training/scripts/sav_frame_extraction_submitit.py +157 -0
  53. nnInteractive/supervoxel/src/sam2/training/train.py +232 -0
  54. nnInteractive/supervoxel/src/sam2/training/trainer.py +1051 -0
  55. nnInteractive/supervoxel/src/sam2/training/utils/__init__.py +5 -0
  56. nnInteractive/supervoxel/src/sam2/training/utils/checkpoint_utils.py +328 -0
  57. nnInteractive/supervoxel/src/sam2/training/utils/data_utils.py +166 -0
  58. nnInteractive/supervoxel/src/sam2/training/utils/distributed.py +560 -0
  59. nnInteractive/supervoxel/src/sam2/training/utils/logger.py +236 -0
  60. nnInteractive/supervoxel/src/sam2/training/utils/train_utils.py +275 -0
  61. nnInteractive/supervoxel/src/supervoxel.py +198 -0
  62. nnInteractive/trainer/__init__.py +0 -0
  63. nnInteractive/trainer/nnInteractiveTrainer.py +24 -0
  64. nnInteractive/utils/__init__.py +0 -0
  65. nnInteractive/utils/bboxes.py +217 -0
  66. nnInteractive/utils/checkpoint_cleansing.py +9 -0
  67. nnInteractive/utils/crop.py +268 -0
  68. nnInteractive/utils/erosion_dilation.py +48 -0
  69. nnInteractive/utils/inference_helpers.py +45 -0
  70. nnInteractive/utils/os_shennanigans.py +16 -0
  71. nnInteractive/utils/rounding.py +13 -0
  72. nninteractive-2.0.0.dist-info/METADATA +511 -0
  73. nninteractive-2.0.0.dist-info/RECORD +76 -0
  74. nninteractive-2.0.0.dist-info/WHEEL +5 -0
  75. nninteractive-2.0.0.dist-info/licenses/LICENSE +201 -0
  76. nninteractive-2.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,198 @@
1
+ import os
2
+ import torch
3
+ from typing import List
4
+ import numpy as np
5
+ from reader import NfitiReaderWriter, BloscReaderWriter
6
+ from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
7
+ from sam2.sam2.build_sam import build_sam2_video_predictor
8
+ from tqdm import tqdm
9
+ import concurrent.futures
10
+ import cc3d
11
+
12
+
13
class SuperVoxelGenerator:
    """Generate 3D "supervoxel" masks by seeding SAM on one slice and propagating the masks with SAM2.

    For every volume, SAM's automatic mask generator segments one randomly chosen slice along axis 0.
    A random subset of those 2D masks is then propagated through the whole volume with the SAM2 video
    predictor (axis 0 plays the role of time), first forwards and then backwards on a flipped copy.
    """

    def __init__(self, input_dir: str, output_dir: str, config: dict):
        """
        SuperVoxelGenerator class constructor. Collects the files to process and loads SAM and SAM2.

        Parameters:
        input_dir (str): The folder containing the images to process.
        output_dir (str): The directory where the segmentation masks will be saved (output_dir / name_of_image).
        config (dict): Configuration for the supervoxel generation. Keys read by this class:
            "file_format" (".nii.gz" or ".b2nd"), "excluded_strings", "sam1_checkpoint",
            "sam2_checkpoint" and "masks_per_image".
        """
        self.input_dir = input_dir
        self.output_dir = output_dir
        self.file_format = config["file_format"]

        # Instantiate only the reader/writer for the configured format; an unknown format raises KeyError.
        reader_writer_classes = {".nii.gz": NfitiReaderWriter, ".b2nd": BloscReaderWriter}
        self.reader_writer = reader_writer_classes[self.file_format]()
        print(f"Using {self.reader_writer} to read and write files")

        # Files with the configured extension, skipping names that contain config["excluded_strings"].
        # NOTE(review): this is a substring test, so "excluded_strings" must be a single str;
        # a list of strings here would raise TypeError.
        self.list_of_files = [
            f
            for f in os.listdir(input_dir)
            if f.endswith(config["file_format"]) and config["excluded_strings"] not in f
        ]
        self.config = config

        self.sam = sam_model_registry["vit_h"](checkpoint=self.config["sam1_checkpoint"]).to("cuda")
        self.mask_generator = SamAutomaticMaskGenerator(
            model=self.sam,
            points_per_side=48,
            points_per_batch=256,
            pred_iou_thresh=0.85,
            stability_score_thresh=0.92,
            box_nms_thresh=0.6,
            crop_nms_thresh=0.6,
            crop_n_layers=1,
            crop_n_points_downscale_factor=2,
            min_mask_region_area=192,
        )
        model_cfg = "sam2/configs/sam2.1/sam2.1_hiera_l.yaml"
        self.sam2_predictor = build_sam2_video_predictor(model_cfg, ckpt_path=self.config["sam2_checkpoint"])

    def sam2_propagation(self, sam2_predictor_state, image_data: np.ndarray, masks: List[np.ndarray], slice_idx: int):
        """
        Propagate 2D seed masks through the volume using SAM2, in both directions along axis 0.

        Parameters:
        sam2_predictor_state: State returned by ``self.sam2_predictor.init_state`` for ``image_data``.
            It is modified: its "images" entry is left flipped along axis 0 when this method returns.
        image_data (np.ndarray): The image data to process. Shape: (z, y, x)
        masks (List[np.ndarray]): 2D seed masks of slice ``slice_idx``, one per object. Shape: (y, x)
        slice_idx (int): The index of the slice which contains the masks

        Returns:
        np.ndarray: uint8 array of shape (len(masks), z, y, x) with 1 where an object was propagated.
        """
        # uint8 is enough for binary masks (float64 would use 8x the memory for the same content).
        propagated_masks = np.zeros((len(masks), *image_data.shape), dtype=np.uint8)
        with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
            # Forward pass: prompt every object on the seed slice, then track towards the end of the volume.
            for obj_id, m in enumerate(masks):
                self.sam2_predictor.add_new_mask(sam2_predictor_state, slice_idx, obj_id, m.astype(np.int8))

            for frame_idx, object_ids, out_masks in self.sam2_predictor.propagate_in_video(sam2_predictor_state):
                for obj_n, out_mask in zip(object_ids, out_masks):
                    propagated_masks[obj_n, frame_idx] = (out_mask > 0).detach().cpu().numpy()[0].astype(np.uint8)

            # Backward pass: reset the tracking results, flip the volume and re-prompt at the mirrored index,
            # so that "forward" tracking now covers the slices before slice_idx.
            self.sam2_predictor.reset_state(sam2_predictor_state)
            sam2_predictor_state["images"] = torch.flip(sam2_predictor_state["images"], dims=(0,))
            # NOTE(review): assumes the state caches per-frame image features under "cached_features" keyed by
            # frame index (as upstream SAM2 does). After the flip those indices refer to different slices, so drop
            # any cached entry to avoid reusing features of the wrong slice; this only costs a recomputation.
            cached_features = sam2_predictor_state.get("cached_features")
            if cached_features is not None:
                cached_features.clear()
            max_frame = sam2_predictor_state["images"].shape[0] - 1

            for obj_id, m in enumerate(masks):
                self.sam2_predictor.add_new_mask(
                    sam2_predictor_state, max_frame - slice_idx, obj_id, m.astype(np.int8)
                )

            for frame_idx, object_ids, out_masks in self.sam2_predictor.propagate_in_video(sam2_predictor_state):
                for obj_n, out_mask in zip(object_ids, out_masks):
                    # Map the flipped frame index back to the original slice index.
                    propagated_masks[obj_n, max_frame - frame_idx] = (
                        (out_mask > 0).detach().cpu().numpy()[0].astype(np.uint8)
                    )

        return propagated_masks

    def remove_other_components(self, masks, frame_idx):
        """
        Reduce each mask to a single connected component that touches the seed slice.

        For every mask, the 26-connected components that have at least one voxel on slice ``frame_idx`` are
        collected and ONE of them is chosen at random; all other components are discarded. Masks with no
        component on the seed slice are dropped entirely, so the output may contain fewer masks than the input.

        Parameters:
        masks (np.ndarray): The masks to process. Shape: (n, z, y, x)
        frame_idx (int): The index of the seed slice

        Returns:
        np.ndarray: uint8 array of shape (k, z, y, x) with k <= n (k == 0 if nothing survives).
        """
        filtered_masks = []
        for m in masks:
            m_cc = cc3d.connected_components(m, connectivity=26, binary_image=True)

            # Component labels present on the seed slice (0 is background). np.unique returns them sorted,
            # i.e. in the same order a 1..N scan would produce.
            seed_labels = np.unique(m_cc[frame_idx])
            seed_labels = seed_labels[seed_labels != 0]
            if seed_labels.size == 0:
                continue
            final_component = np.random.choice(seed_labels)
            filtered_masks.append((m_cc == final_component).astype(np.uint8))

        if not filtered_masks:
            # np.stack([]) raises; return an empty stack with the expected spatial shape instead.
            return np.zeros((0, *np.shape(masks)[1:]), dtype=np.uint8)
        return np.stack(filtered_masks, axis=0)

    @staticmethod
    def _normalize_intensities(image_data: np.ndarray) -> np.ndarray:
        """Rescale intensities so the 5th percentile maps to 0 and the 95th to 1 (values outside are not clipped)."""
        low, high = np.percentile(image_data, [5, 95])
        scale = high - low
        if scale <= 0:
            # Constant volume: dividing by zero would fill the image with NaN/inf.
            scale = 1.0
        return (image_data - low) / scale

    def sam_supervoxel(self, image_data: np.ndarray):
        """
        Generate the supervoxels segmentation masks using SAM and SAM2

        Parameters:
        image_data (np.ndarray): The image data to process. Shape: (z, y, x)

        Returns:
        np.ndarray: uint8 array of shape (k, z, y, x); k == 0 when no supervoxel could be generated.
        """
        image_data = self._normalize_intensities(image_data)

        # Sample the seed slice with a Gaussian-shaped preference for the middle of the volume.
        z_len = image_data.shape[0]
        slice_weights = np.exp(-np.linspace(-1, 1, z_len) ** 2 * 2)
        slice_idx = np.random.choice(z_len, p=slice_weights / slice_weights.sum())

        def generate_sam_masks():
            # SAM expects an RGB-like (y, x, 3) image; repeat the grayscale slice over 3 channels.
            return self.mask_generator.generate(image_data[slice_idx, ..., None].repeat(3, axis=2))

        def init_sam2_predictor():
            return self.sam2_predictor.init_state(image_data)

        # Run SAM on the seed slice and SAM2 state initialisation concurrently.
        with concurrent.futures.ThreadPoolExecutor() as executor:
            future_sam_masks = executor.submit(generate_sam_masks)
            future_sam2_predictor = executor.submit(init_sam2_predictor)

            masks = future_sam_masks.result()
            sam2_predictor_state = future_sam2_predictor.result()

        if len(masks) == 0:
            # Nothing to propagate from this seed slice.
            return np.zeros((0, *image_data.shape), dtype=np.uint8)

        # Pick at most masks_per_image masks at random.
        selected_masks = (
            np.random.choice(masks, size=self.config["masks_per_image"], replace=False)
            if len(masks) > self.config["masks_per_image"]
            else masks
        )
        selected_masks = [m["segmentation"] for m in selected_masks]

        # Propagate the masks using SAM2
        propagated_masks = self.sam2_propagation(sam2_predictor_state, image_data, selected_masks, slice_idx)
        self.sam2_predictor.reset_state(sam2_predictor_state)

        # Keep one connected component per mask that touches the seed slice.
        propagated_masks = self.remove_other_components(propagated_masks, slice_idx)

        # Binarize the masks
        propagated_masks = (propagated_masks > 0).astype(np.uint8)

        return propagated_masks

    def process_images(self):
        """
        Generate supervoxel masks for every file in ``self.list_of_files`` and write them to ``self.output_dir``.

        Files whose output already exists are skipped, so an interrupted run can be resumed.
        Volumes for which no supervoxel could be generated are reported and not written.
        """
        print(f"Processing {len(self.list_of_files)} images")

        for image_name in tqdm(self.list_of_files):
            out_path = os.path.join(self.output_dir, image_name)
            if os.path.exists(out_path):
                continue

            image_data, metadata = self.reader_writer.read(os.path.join(self.input_dir, image_name))

            # 4D input: keep only the first entry along axis 0 (presumably the first channel).
            if len(image_data.shape) == 4:
                image_data = image_data[0]
            out_img = self.sam_supervoxel(image_data)

            if out_img.shape[0] == 0:
                print(f"No supervoxels generated for {image_name}, skipping")
                continue

            # Re-check after the (slow) generation; presumably guards against a concurrent run having
            # written the same output in the meantime.
            if os.path.exists(out_path):
                continue
            self.reader_writer.write(out_img, metadata, out_path)
File without changes
@@ -0,0 +1,24 @@
1
+ from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer
2
+ from nnunetv2.utilities.plans_handling.plans_handler import PlansManager, ConfigurationManager
3
+ from torch import nn
4
+
5
+
6
class nnInteractiveTrainer_stub:
    """Lightweight placeholder trainer that only knows how to rebuild the network architecture.

    The checkpoint-cleansing utility in this package sets a checkpoint's ``trainer_name`` to this class's
    name; it carries no training logic and merely delegates network construction to ``nnUNetTrainer``.
    """

    def __init__(self, *args, **kwargs):
        # Any constructor arguments are accepted and deliberately ignored; the stub keeps no state.
        return None

    @staticmethod
    def build_network_architecture(
        plans_manager: PlansManager,
        configuration_manager: ConfigurationManager,
        num_input_channels: int,
        num_output_channels: int,
        enable_deep_supervision: bool = True,
    ) -> nn.Module:
        """Build the standard nnU-Net network with a fixed two-channel output head.

        ``num_output_channels`` is accepted for interface compatibility but ignored: nnU-Net still treats a
        one-class segmentation as a cross-entropy problem, which needs two output channels.
        """
        ce_output_channels = 2
        network = nnUNetTrainer.build_network_architecture(
            plans_manager,
            configuration_manager,
            num_input_channels,
            ce_output_channels,
            enable_deep_supervision,
        )
        return network
File without changes
@@ -0,0 +1,217 @@
1
+ from time import time
2
+ from time import time
3
+ from typing import List, Union, Tuple
4
+
5
+ import numpy as np
6
+ import torch
7
+
8
+
9
def generate_bounding_boxes(
    mask,
    bbox_size=(192, 192, 192),
    stride: Union[List[int], Tuple[int, int, int], str] = (16, 16, 16),
    margin=(10, 10, 10),
    max_depth=5,
    current_depth=0,
):
    """
    Generate overlapping bounding boxes to cover a 3D binary segmentation mask using PyTorch tensors.

    Candidate box centers are foreground voxels on a regular grid (spacing ``stride``). A greedy set cover
    repeatedly picks the candidate whose box, shrunk by ``margin`` on every side, covers the most still-uncovered
    foreground voxels. Anything left over is covered by ``random_sampling_fallback``. The input ``mask`` is
    not modified.

    Parameters:
    - mask: 3D PyTorch tensor with values 0 or 1 (binary mask)
    - bbox_size: Three integers giving the box size along axes 0, 1, 2
    - stride: Three integers giving the grid spacing for candidate centers per axis, or "auto" to use roughly
      a quarter of the object's extent per axis
    - margin: Three integers; a box only counts voxels at least this far inside its border as covered
    - max_depth: Maximum recursion depth (grid refinements) before falling back to random sampling
    - current_depth: Current recursion depth (used internally)

    Returns:
    - List of boxes, each ``[[x0, x1], [y0, y1], [z0, z1]]`` giving a half-open interval per axis. Every box is
      exactly ``bbox_size`` large and is NOT clipped to the mask, so coordinates may be negative or exceed
      ``mask.shape``.
    """
    if not torch.any(mask):
        return []

    # Prevent infinite recursion. Pass a clone: random_sampling_fallback may zero covered voxels in the tensor
    # it receives, and here ``mask`` can still be the caller's own tensor.
    if current_depth > max_depth:
        return random_sampling_fallback(mask.clone(), bbox_size, margin, 25)

    bbox_size = list(bbox_size)
    margin = list(margin)

    # Compute half sizes for each dimension
    half_size = [bs // 2 for bs in bbox_size]
    # Adjust end offsets to ensure full bbox_size (handles odd sizes), e.g. 193 - 96 = 97
    end_offset = [bs - hs for bs, hs in zip(bbox_size, half_size)]

    # Step 1: Find all object voxels
    object_voxels = torch.nonzero(mask, as_tuple=False)
    if object_voxels.numel() == 0:
        return []

    # Step 2: Compute the object's bounding box to anchor the candidate grid
    min_coords = object_voxels.min(dim=0)[0]
    max_coords = object_voxels.max(dim=0)[0]

    if isinstance(stride, str) and stride == "auto":
        stride = [max(1, round((j.item() - i.item()) / 4)) for i, j in zip(min_coords, max_coords)]
    stride = list(stride)

    # Step 3: Candidate centers are the foreground voxels on the grid min_coords + k * stride.
    # torch.nonzero returns indices in lexicographic order, so the candidates come out in the same order as a
    # nested x/y/z loop over the grid would produce.
    stride_t = torch.as_tensor(stride, dtype=object_voxels.dtype, device=object_voxels.device)
    on_grid = ((object_voxels - min_coords) % stride_t == 0).all(dim=1)
    potential_centers = object_voxels[on_grid]

    if len(potential_centers) == 0:
        # No foreground voxel lies on the grid: retry with a finer grid.
        return generate_bounding_boxes(
            mask, bbox_size, [max(1, s // 2) for s in stride], margin, max_depth, current_depth + 1
        )

    # Step 4: Greedy set cover algorithm
    uncovered = mask.clone().byte()
    bboxes = []

    while len(potential_centers) > 0 and uncovered.any():
        best_center = None
        best_covered = 0
        best_bounds = None

        # Find the center whose (margin-shrunk) box covers the most uncovered voxels; ties keep the first.
        for idx, (c_x, c_y, c_z) in enumerate(potential_centers.tolist()):
            x_start = max(0, c_x - half_size[0] + margin[0])
            x_end = min(mask.shape[0], c_x + end_offset[0] - margin[0])
            y_start = max(0, c_y - half_size[1] + margin[1])
            y_end = min(mask.shape[1], c_y + end_offset[1] - margin[1])
            z_start = max(0, c_z - half_size[2] + margin[2])
            z_end = min(mask.shape[2], c_z + end_offset[2] - margin[2])

            num_covered = uncovered[x_start:x_end, y_start:y_end, z_start:z_end].sum().item()
            if num_covered > best_covered:
                best_covered = num_covered
                best_center = idx
                best_bounds = (x_start, x_end, y_start, y_end, z_start, z_end)

        # If no new voxels are covered, stop
        if best_covered == 0:
            break

        # Add the full-size (unclipped, margin-free) box around the best center
        c_x, c_y, c_z = potential_centers[best_center].tolist()
        bboxes.append(
            [
                [c_x - half_size[0], c_x + end_offset[0]],
                [c_y - half_size[1], c_y + end_offset[1]],
                [c_z - half_size[2], c_z + end_offset[2]],
            ]
        )

        # Mark voxels as covered, respecting the margin
        x_s, x_e, y_s, y_e, z_s, z_e = best_bounds
        uncovered[x_s:x_e, y_s:y_e, z_s:z_e] = 0

        # Drop every candidate whose own voxel is now covered (including the chosen one when the margin
        # leaves it inside its box).
        potential_centers = potential_centers[uncovered[tuple(potential_centers.T)] > 0]

    # Step 5: Cover whatever is left by random sampling. (A recursive call with a finer grid on ``uncovered``
    # used to be here but was disabled.)
    if uncovered.any():
        bboxes.extend(random_sampling_fallback(uncovered, bbox_size, margin, 10))

    return bboxes
146
+
147
+
148
def random_sampling_fallback(mask: torch.Tensor, bbox_size=(192, 192, 192), margin=(10, 10, 10), n_samples: int = 25):
    """
    Cover all foreground voxels of ``mask`` with boxes chosen by best-of-``n_samples`` random sampling.

    In each round, ``n_samples`` still-uncovered foreground voxels are drawn at random (with ``np.random``) as
    candidate centers. The candidate whose box, shrunk by ``margin`` on every side, covers the most uncovered
    voxels is kept, and those voxels are marked covered. This repeats until nothing is left. ``mask`` itself
    is not modified.

    Parameters:
    - mask: 3D binary PyTorch tensor
    - bbox_size: Three integers giving the box size along axes 0, 1, 2
    - margin: Three integers; a box only counts voxels at least this far inside its border as covered
    - n_samples: Number of random candidate centers evaluated per box

    Returns:
    - List of boxes, each ``[[x0, x1], [y0, y1], [z0, z1]]`` (half-open per axis, exactly ``bbox_size`` large,
      not clipped to the mask shape).

    Raises:
    - ValueError: if a margin is so large that a box would not cover its own center voxel (margin > bbox // 2,
      or margin >= bbox - bbox // 2). No progress would then be possible.
    """
    half_size = [bs // 2 for bs in bbox_size]
    # Adjust end offsets to ensure full bbox_size (handles odd sizes), e.g. 193 - 96 = 97
    end_offset = [bs - hs for bs, hs in zip(bbox_size, half_size)]

    if not mask.any():
        return []

    # The covered region of a box must contain its center voxel, otherwise a round can cover nothing.
    for axis, (m, hs, eo) in enumerate(zip(margin, half_size, end_offset)):
        if m > hs or m >= eo:
            raise ValueError(
                f"margin {m} along axis {axis} is too large for bbox size {hs + eo}: "
                f"boxes would not cover their own center voxel"
            )

    # Work on a copy so the caller's mask is not zeroed as voxels get covered.
    remaining = mask.clone()
    bboxes = []

    while remaining.any():
        indices = torch.nonzero(remaining)  # n x 3

        best_center = None
        best_covered = 0
        best_bounds = None

        # Find the sampled center that covers the most uncovered voxels
        for _ in range(n_samples):
            idx = np.random.choice(len(indices))
            center = indices[idx]
            c_x, c_y, c_z = [int(v) for v in center.tolist()]
            x_start = max(0, c_x - half_size[0] + margin[0])
            x_end = min(remaining.shape[0], c_x + end_offset[0] - margin[0])
            y_start = max(0, c_y - half_size[1] + margin[1])
            y_end = min(remaining.shape[1], c_y + end_offset[1] - margin[1])
            z_start = max(0, c_z - half_size[2] + margin[2])
            z_end = min(remaining.shape[2], c_z + end_offset[2] - margin[2])

            num_covered = remaining[x_start:x_end, y_start:y_end, z_start:z_end].sum().item()
            if num_covered > best_covered:
                best_covered = num_covered
                best_center = center
                best_bounds = (x_start, x_end, y_start, y_end, z_start, z_end)

        # Add the full-size (unclipped, margin-free) box around the best center
        c_x, c_y, c_z = [int(v) for v in best_center.tolist()]
        bboxes.append(
            [
                [c_x - half_size[0], c_x + end_offset[0]],
                [c_y - half_size[1], c_y + end_offset[1]],
                [c_z - half_size[2], c_z + end_offset[2]],
            ]
        )

        # Mark voxels as covered, respecting the margin
        x_s, x_e, y_s, y_e, z_s, z_e = best_bounds
        remaining[x_s:x_e, y_s:y_e, z_s:z_e] = 0

    return bboxes
198
+
199
+
200
if __name__ == "__main__":
    # Small timing demo for generate_bounding_boxes on a synthetic cubic object.
    times = []
    torch.set_num_threads(8)
    # Fall back to CPU so the demo also runs on machines without a GPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    for _ in range(1):
        st = time()
        mask = torch.zeros((256, 256, 256), dtype=torch.uint8, device=device)
        mask[50:150, 50:150, 50:150] = 1  # A cubic object

        # Generate bounding boxes with an odd size to test. (random_sampling_fallback takes no `stride`
        # argument, so the "auto" stride demo must go through generate_bounding_boxes.)
        bboxes = generate_bounding_boxes(
            mask, bbox_size=(193, 193, 193), stride="auto", margin=(10, 10, 10)  # Odd size
        )

        # Print results
        print(f"Number of bounding boxes: {len(bboxes)}")
        end = time()
        times.append(end - st)
    print(times)
@@ -0,0 +1,9 @@
1
+ import torch
2
+
3
+
4
def cleanse_checkpoint_for_release(checkpoint: str):
    """Strip training-only state from a checkpoint file in place so it can be released.

    Removes ``optimizer_state`` and ``init_args["dataset_json"]``, sets ``trainer_name`` to
    ``"nnInteractiveTrainer_stub"`` and overwrites ``checkpoint`` with the result. Running it again on an
    already-cleansed checkpoint is a no-op apart from rewriting the file.

    Parameters:
        checkpoint (str): Path of the checkpoint file; it is overwritten.

    Raises:
        KeyError: if the checkpoint has no ``init_args`` entry.
    """
    # weights_only=False unpickles arbitrary Python objects: only use this on trusted checkpoints
    # (e.g. your own training output), never on files from unknown sources.
    # map_location="cpu" lets this run on machines without the GPU the checkpoint was saved from.
    a = torch.load(checkpoint, map_location="cpu", weights_only=False)
    a.pop("optimizer_state", None)
    a["init_args"].pop("dataset_json", None)
    a["trainer_name"] = "nnInteractiveTrainer_stub"
    torch.save(a, checkpoint)