nnInteractive 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nnInteractive/__init__.py +3 -0
- nnInteractive/inference/__init__.py +0 -0
- nnInteractive/inference/cvpr2025_challenge_baseline/__init__.py +0 -0
- nnInteractive/inference/cvpr2025_challenge_baseline/predict.py +173 -0
- nnInteractive/inference/inference_session.py +1400 -0
- nnInteractive/interaction/__init__.py +0 -0
- nnInteractive/interaction/point.py +166 -0
- nnInteractive/supervoxel/setup.py +4 -0
- nnInteractive/supervoxel/src/metadata.py +118 -0
- nnInteractive/supervoxel/src/reader.py +175 -0
- nnInteractive/supervoxel/src/run.py +136 -0
- nnInteractive/supervoxel/src/sam2/__init__.py +2 -0
- nnInteractive/supervoxel/src/sam2/sam2/__init__.py +11 -0
- nnInteractive/supervoxel/src/sam2/sam2/automatic_mask_generator.py +434 -0
- nnInteractive/supervoxel/src/sam2/sam2/benchmark.py +86 -0
- nnInteractive/supervoxel/src/sam2/sam2/build_sam.py +172 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/__init__.py +5 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/__init__.py +5 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/hieradet.py +305 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/image_encoder.py +132 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/utils.py +89 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/memory_attention.py +167 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/memory_encoder.py +179 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/position_encoding.py +217 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/__init__.py +5 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/mask_decoder.py +274 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/prompt_encoder.py +194 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/transformer.py +293 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/sam2_base.py +879 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/sam2_utils.py +315 -0
- nnInteractive/supervoxel/src/sam2/sam2/sam2_image_predictor.py +433 -0
- nnInteractive/supervoxel/src/sam2/sam2/sam2_video_predictor.py +1171 -0
- nnInteractive/supervoxel/src/sam2/sam2/sam2_video_predictor_legacy.py +1125 -0
- nnInteractive/supervoxel/src/sam2/sam2/utils/__init__.py +5 -0
- nnInteractive/supervoxel/src/sam2/sam2/utils/amg.py +332 -0
- nnInteractive/supervoxel/src/sam2/sam2/utils/misc.py +488 -0
- nnInteractive/supervoxel/src/sam2/sam2/utils/transforms.py +108 -0
- nnInteractive/supervoxel/src/sam2/setup.py +174 -0
- nnInteractive/supervoxel/src/sam2/training/__init__.py +5 -0
- nnInteractive/supervoxel/src/sam2/training/dataset/__init__.py +5 -0
- nnInteractive/supervoxel/src/sam2/training/dataset/sam2_datasets.py +176 -0
- nnInteractive/supervoxel/src/sam2/training/dataset/transforms.py +481 -0
- nnInteractive/supervoxel/src/sam2/training/dataset/utils.py +102 -0
- nnInteractive/supervoxel/src/sam2/training/dataset/vos_dataset.py +154 -0
- nnInteractive/supervoxel/src/sam2/training/dataset/vos_raw_dataset.py +290 -0
- nnInteractive/supervoxel/src/sam2/training/dataset/vos_sampler.py +103 -0
- nnInteractive/supervoxel/src/sam2/training/dataset/vos_segment_loader.py +289 -0
- nnInteractive/supervoxel/src/sam2/training/loss_fns.py +290 -0
- nnInteractive/supervoxel/src/sam2/training/model/__init__.py +5 -0
- nnInteractive/supervoxel/src/sam2/training/model/sam2.py +515 -0
- nnInteractive/supervoxel/src/sam2/training/optimizer.py +462 -0
- nnInteractive/supervoxel/src/sam2/training/scripts/sav_frame_extraction_submitit.py +157 -0
- nnInteractive/supervoxel/src/sam2/training/train.py +232 -0
- nnInteractive/supervoxel/src/sam2/training/trainer.py +1051 -0
- nnInteractive/supervoxel/src/sam2/training/utils/__init__.py +5 -0
- nnInteractive/supervoxel/src/sam2/training/utils/checkpoint_utils.py +328 -0
- nnInteractive/supervoxel/src/sam2/training/utils/data_utils.py +166 -0
- nnInteractive/supervoxel/src/sam2/training/utils/distributed.py +560 -0
- nnInteractive/supervoxel/src/sam2/training/utils/logger.py +236 -0
- nnInteractive/supervoxel/src/sam2/training/utils/train_utils.py +275 -0
- nnInteractive/supervoxel/src/supervoxel.py +198 -0
- nnInteractive/trainer/__init__.py +0 -0
- nnInteractive/trainer/nnInteractiveTrainer.py +24 -0
- nnInteractive/utils/__init__.py +0 -0
- nnInteractive/utils/bboxes.py +217 -0
- nnInteractive/utils/checkpoint_cleansing.py +9 -0
- nnInteractive/utils/crop.py +268 -0
- nnInteractive/utils/erosion_dilation.py +48 -0
- nnInteractive/utils/inference_helpers.py +45 -0
- nnInteractive/utils/os_shennanigans.py +16 -0
- nnInteractive/utils/rounding.py +13 -0
- nninteractive-2.0.0.dist-info/METADATA +511 -0
- nninteractive-2.0.0.dist-info/RECORD +76 -0
- nninteractive-2.0.0.dist-info/WHEEL +5 -0
- nninteractive-2.0.0.dist-info/licenses/LICENSE +201 -0
- nninteractive-2.0.0.dist-info/top_level.txt +1 -0
|
File without changes
|
|
File without changes
|
|
@@ -0,0 +1,173 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
"""
|
|
3
|
+
CVPR25 – Foundation Models for Interactive 3D Biomedical Image Segmentation
|
|
4
|
+
Skeleton inference script.
|
|
5
|
+
|
|
6
|
+
You only need to replace the `run_inference()` function with your model‑specific
|
|
7
|
+
code. Everything else takes care of
|
|
8
|
+
• reading the input image + prompts,
|
|
9
|
+
• passing the relevant information to your model,
|
|
10
|
+
• saving the prediction in the expected format.
|
|
11
|
+
|
|
12
|
+
During evaluation the script is called exactly once for every interaction
|
|
13
|
+
step (bbox prediction + 5 click refinements). The evaluator will overwrite
|
|
14
|
+
the same input file between calls, injecting updated clicks and the previous
|
|
15
|
+
prediction (`prev_pred`).
|
|
16
|
+
"""
|
|
17
|
+
|
|
18
|
+
from __future__ import annotations
|
|
19
|
+
|
|
20
|
+
import argparse
|
|
21
|
+
import os
|
|
22
|
+
import sys
|
|
23
|
+
from pathlib import Path
|
|
24
|
+
import numpy as np
|
|
25
|
+
import torch
|
|
26
|
+
from acvl_utils.cropping_and_padding.bounding_boxes import crop_and_pad_nd
|
|
27
|
+
|
|
28
|
+
from nnInteractive.inference.inference_session import nnInteractiveInferenceSession
|
|
29
|
+
|
|
30
|
+
from nnunetv2.utilities.helpers import empty_cache
|
|
31
|
+
|
|
32
|
+
# --------------------------------------------------------------------------- #
|
|
33
|
+
# === EDIT BELOW === #
|
|
34
|
+
# --------------------------------------------------------------------------- #
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def run_inference(
    image: np.ndarray,
    spacing: tuple[float, float, float],
    bbox: list[dict] | None,
    clicks: list[dict] | None,
    clicks_order: list[list[str]] | None,
    prev_pred: np.ndarray | None,
) -> np.ndarray:
    """
    Segment every prompted object in ``image`` with one nnInteractive session.

    For each object id (1 … N) the previous mask (if any), the bounding box
    (if any) and all clicks (if any) are registered with the session without
    triggering a prediction; a single prediction is then run and the resulting
    foreground is written into the output with the object's id.

    Parameters
    ----------
    image : (D, H, W) np.ndarray
        Raw image volume. It is cast to float32 and given a leading channel
        axis before being handed to the session.
    spacing : (3,) tuple of float
        Physical voxel spacing. Accepted for interface compatibility; this
        implementation does not use it.
    bbox : list of dict | None
        One bounding box per object with the keys ``z_min``, ``z_max``,
        ``z_mid_y_min``, ``z_mid_y_max``, ``z_mid_x_min``, ``z_mid_x_max``.
    clicks : list of dict | None
        One dict per object with the click lists under ``"fg"`` and ``"bg"``.
    clicks_order : list of list of str | None
        For every object, the order in which its clicks were placed. An entry
        equal to ``"fg"`` consumes the next foreground click; any other entry
        consumes the next background click. Required whenever ``clicks`` is
        given.
    prev_pred : (D, H, W) np.ndarray | None
        Segmentation from the previous iteration. May be ``None`` for the
        first call.

    Returns
    -------
    seg : (D, H, W) np.ndarray, dtype=uint8
        Multi-class segmentation mask. Background is 0, objects are 1 … N.
        Objects are written in ascending id order, so where masks overlap the
        higher id wins.

    Raises
    ------
    ValueError
        If neither ``bbox`` nor ``clicks`` is given, if both are given with
        different lengths, or if ``clicks`` is given without ``clicks_order``.
    """
    # Validate the prompts *before* building the session so that malformed
    # input fails fast with a clear message instead of an opaque TypeError
    # after the model has already been loaded.
    if bbox is None and clicks is None:
        raise ValueError(
            "Neither bounding boxes nor clicks were provided; "
            "at least one prompt type is required to determine the objects to segment."
        )
    # Raised explicitly (not asserted) so the check survives `python -O`.
    if bbox is not None and clicks is not None and len(bbox) != len(clicks):
        raise ValueError(
            "Both bboxs and clicks lists are provided but with different length "
            "suggesting different number of objects. This is not supported by this script "
            "and it was not communicated by the organizing team that such cases exist "
            "or how they are supposed to be handled."
        )
    if clicks is not None and clicks_order is None:
        raise ValueError(
            "`clicks` were provided without `clicks_order`; "
            "the click order is required to replay the interactions."
        )
    num_objects = len(bbox) if bbox is not None else len(clicks)

    session = nnInteractiveInferenceSession(
        device=torch.device("cuda", 0),
        use_torch_compile=False,
        verbose=False,
        torch_n_threads=os.cpu_count(),
        do_autozoom=True,
    )
    session.initialize_from_trained_model_folder(
        model_training_output_dir=CHECKPOINT_DIR,
        use_fold="all",
    )
    # The session receives the image with a leading channel axis.
    session.set_image(image[None].astype(np.float32))
    target_buffer = torch.zeros(image.shape, dtype=torch.uint8, device="cpu")
    session.set_target_buffer(target_buffer)
    result = torch.zeros(image.shape, dtype=torch.uint8)

    for oid in range(1, num_objects + 1):
        # place previous segmentation
        if prev_pred is not None:
            session.add_initial_seg_interaction(
                (prev_pred == oid).astype(np.uint8), run_prediction=False
            )
        else:
            session.reset_interactions()

        if bbox is not None:
            bbox_here = bbox[oid - 1]
            # Upper bounds get +1 -- presumably the prompt stores inclusive
            # maxima while the session expects half-open intervals; confirm
            # against the challenge description.
            bbox_here = [
                [bbox_here["z_min"], bbox_here["z_max"] + 1],
                [bbox_here["z_mid_y_min"], bbox_here["z_mid_y_max"] + 1],
                [bbox_here["z_mid_x_min"], bbox_here["z_mid_x_max"] + 1],
            ]
            session.add_bbox_interaction(bbox_here, include_interaction=True, run_prediction=False)

        if clicks is not None:
            clicks_here = clicks[oid - 1]
            clicks_order_here = clicks_order[oid - 1]
            # Independent cursors into the fg / bg click lists; the order list
            # decides which one advances next.
            fg_ptr = bg_ptr = 0
            for kind in clicks_order_here:
                if kind == "fg":
                    click = clicks_here["fg"][fg_ptr]
                    fg_ptr += 1
                else:
                    click = clicks_here["bg"][bg_ptr]
                    bg_ptr += 1

                print(f"Class {oid}: {kind} click at {click}")
                session.add_point_interaction(click, include_interaction=kind == "fg", run_prediction=False)

        # now run inference on the last interaction center
        session.new_interaction_centers = [session.new_interaction_centers[-1]]
        session.new_interaction_zoom_out_factors = [session.new_interaction_zoom_out_factors[-1]]
        session._predict()
        result[session.target_buffer > 0] = oid

    del session
    empty_cache(torch.device("cuda", 0))
    return result.cpu().numpy()
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
# --------------------------------------------------------------------------- #
|
|
131
|
+
# === DO NOT EDIT BELOW === #
|
|
132
|
+
# --------------------------------------------------------------------------- #
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
def parse_args() -> argparse.Namespace:
    """Parse the command line and return the namespace.

    Two options are registered, both mandatory: ``--case_path`` (the input
    ``*.npz``) and ``--save_path`` (where the output ``*.npz`` is written).
    """
    parser = argparse.ArgumentParser()
    required_options = (
        ("--case_path", "Path to the input *.npz"),
        ("--save_path", "Path to write output *.npz"),
    )
    for flag, help_text in required_options:
        parser.add_argument(flag, required=True, help=help_text)
    return parser.parse_args()
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
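# Adapt this to your checkpoint directory. NOTE: a relative path such as the one
# below is normally resolved against the current working directory, which is not
# necessarily the directory containing this script - verify how the script is launched.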
|
|
143
|
+
CHECKPOINT_DIR = "checkpoint_folder"
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
def main() -> None:
    """Command-line entry point: load one case, run inference, save the mask.

    Reads the ``*.npz`` given by ``--case_path`` (required keys ``imgs`` and
    ``spacing``; optional keys ``boxes``, ``clicks``, ``clicks_order`` and
    ``prev_pred``), passes everything to ``run_inference`` and writes the
    result as ``segs`` (uint8) to ``--save_path``, creating the parent
    directory if necessary. Exits with an error message if the input file
    does not exist.
    """
    args = parse_args()
    case_path = Path(args.case_path)
    save_path = Path(args.save_path)

    if not case_path.is_file():
        sys.exit(f"[predict.py] ERROR: {case_path} not found.")

    # ---------------------- Load input & prompts -------------------------- #
    # NOTE(security): allow_pickle=True unpickles objects stored in the file,
    # which can execute arbitrary code. Only run this on trusted inputs.
    # The archive is opened in a `with` block so its file handle is closed
    # deterministically; every array is read while the archive is still open.
    with np.load(case_path, allow_pickle=True) as data:
        available = set(data.files)

        def optional(key: str):
            # Membership test on `files` works on every NumPy version,
            # whereas `NpzFile.get` is only available on newer releases.
            return data[key] if key in available else None

        image = data["imgs"]
        spacing = tuple(data["spacing"])
        bbox = optional("boxes")  # bounding boxes
        clicks = optional("clicks")  # fg/bg clicks per class
        clicks_order = optional("clicks_order")  # order of click types
        prev_pred = optional("prev_pred")  # from last iteration

    # --------------------------- Inference -------------------------------- #
    seg = run_inference(image, spacing, bbox, clicks, clicks_order, prev_pred)

    # ------------------------- Save prediction ---------------------------- #
    save_path.parent.mkdir(parents=True, exist_ok=True)
    np.savez_compressed(save_path, segs=seg.astype(np.uint8))
    print(f"[predict.py] Saved prediction to {save_path}")
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
if __name__ == "__main__":
|
|
173
|
+
main()
|