nnInteractive 1.1.3__tar.gz → 2.0.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {nninteractive-1.1.3 → nninteractive-2.0.0}/PKG-INFO +21 -12
- {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/__init__.py +1 -1
- {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/inference/cvpr2025_challenge_baseline/predict.py +32 -25
- nninteractive-2.0.0/nnInteractive/inference/inference_session.py +1400 -0
- nninteractive-2.0.0/nnInteractive/interaction/point.py +166 -0
- nninteractive-2.0.0/nnInteractive/supervoxel/src/metadata.py +118 -0
- {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/supervoxel/src/reader.py +27 -23
- {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/supervoxel/src/run.py +32 -14
- {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/supervoxel/src/sam2/__init__.py +1 -3
- {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/supervoxel/src/sam2/sam2/automatic_mask_generator.py +10 -30
- {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/supervoxel/src/sam2/sam2/benchmark.py +2 -8
- {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/supervoxel/src/sam2/sam2/build_sam.py +1 -3
- {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/hieradet.py +6 -18
- {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/image_encoder.py +1 -3
- {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/utils.py +2 -6
- {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/supervoxel/src/sam2/sam2/modeling/memory_attention.py +1 -3
- {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/supervoxel/src/sam2/sam2/modeling/memory_encoder.py +1 -3
- {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/supervoxel/src/sam2/sam2/modeling/position_encoding.py +8 -30
- {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/mask_decoder.py +10 -31
- {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/prompt_encoder.py +4 -12
- {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/transformer.py +9 -27
- {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/supervoxel/src/sam2/sam2/modeling/sam2_base.py +17 -49
- {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/supervoxel/src/sam2/sam2/modeling/sam2_utils.py +4 -12
- {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/supervoxel/src/sam2/sam2/sam2_image_predictor.py +16 -49
- {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/supervoxel/src/sam2/sam2/sam2_video_predictor.py +28 -80
- {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/supervoxel/src/sam2/sam2/sam2_video_predictor_legacy.py +24 -71
- {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/supervoxel/src/sam2/sam2/utils/amg.py +6 -22
- {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/supervoxel/src/sam2/sam2/utils/misc.py +10 -13
- {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/supervoxel/src/sam2/sam2/utils/transforms.py +5 -15
- {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/supervoxel/src/sam2/training/dataset/sam2_datasets.py +2 -6
- {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/supervoxel/src/sam2/training/dataset/transforms.py +17 -64
- {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/supervoxel/src/sam2/training/dataset/utils.py +1 -3
- {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/supervoxel/src/sam2/training/dataset/vos_dataset.py +4 -12
- {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/supervoxel/src/sam2/training/dataset/vos_raw_dataset.py +9 -27
- {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/supervoxel/src/sam2/training/dataset/vos_sampler.py +1 -3
- {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/supervoxel/src/sam2/training/dataset/vos_segment_loader.py +5 -16
- {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/supervoxel/src/sam2/training/loss_fns.py +8 -25
- {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/supervoxel/src/sam2/training/model/sam2.py +13 -39
- {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/supervoxel/src/sam2/training/optimizer.py +19 -59
- {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/supervoxel/src/sam2/training/scripts/sav_frame_extraction_submitit.py +3 -9
- {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/supervoxel/src/sam2/training/train.py +17 -55
- {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/supervoxel/src/sam2/training/trainer.py +32 -94
- {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/supervoxel/src/sam2/training/utils/checkpoint_utils.py +16 -49
- {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/supervoxel/src/sam2/training/utils/data_utils.py +6 -19
- {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/supervoxel/src/sam2/training/utils/distributed.py +7 -23
- {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/supervoxel/src/sam2/training/utils/logger.py +5 -15
- {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/supervoxel/src/sam2/training/utils/train_utils.py +2 -15
- {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/supervoxel/src/supervoxel.py +40 -36
- nninteractive-2.0.0/nnInteractive/trainer/nnInteractiveTrainer.py +24 -0
- {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/utils/bboxes.py +38 -38
- {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/utils/checkpoint_cleansing.py +3 -4
- {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/utils/crop.py +101 -40
- {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/utils/erosion_dilation.py +3 -6
- nninteractive-2.0.0/nnInteractive/utils/inference_helpers.py +45 -0
- {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/utils/os_shennanigans.py +2 -1
- {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/utils/rounding.py +2 -1
- {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive.egg-info/PKG-INFO +21 -12
- {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive.egg-info/SOURCES.txt +1 -0
- {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive.egg-info/requires.txt +2 -2
- {nninteractive-1.1.3 → nninteractive-2.0.0}/pyproject.toml +3 -3
- {nninteractive-1.1.3 → nninteractive-2.0.0}/readme.md +18 -9
- nninteractive-1.1.3/nnInteractive/inference/inference_session.py +0 -787
- nninteractive-1.1.3/nnInteractive/interaction/point.py +0 -114
- nninteractive-1.1.3/nnInteractive/supervoxel/src/metadata.py +0 -107
- nninteractive-1.1.3/nnInteractive/trainer/nnInteractiveTrainer.py +0 -25
- {nninteractive-1.1.3 → nninteractive-2.0.0}/LICENSE +0 -0
- {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/inference/__init__.py +0 -0
- {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/inference/cvpr2025_challenge_baseline/__init__.py +0 -0
- {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/interaction/__init__.py +0 -0
- {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/supervoxel/setup.py +0 -0
- {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/supervoxel/src/sam2/sam2/__init__.py +0 -0
- {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/supervoxel/src/sam2/sam2/modeling/__init__.py +0 -0
- {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/__init__.py +0 -0
- {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/__init__.py +0 -0
- {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/supervoxel/src/sam2/sam2/utils/__init__.py +0 -0
- {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/supervoxel/src/sam2/setup.py +0 -0
- {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/supervoxel/src/sam2/training/__init__.py +0 -0
- {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/supervoxel/src/sam2/training/dataset/__init__.py +0 -0
- {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/supervoxel/src/sam2/training/model/__init__.py +0 -0
- {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/supervoxel/src/sam2/training/utils/__init__.py +0 -0
- {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/trainer/__init__.py +0 -0
- {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/utils/__init__.py +0 -0
- {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive.egg-info/dependency_links.txt +0 -0
- {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive.egg-info/top_level.txt +0 -0
- {nninteractive-1.1.3 → nninteractive-2.0.0}/setup.cfg +0 -0
- {nninteractive-1.1.3 → nninteractive-2.0.0}/setup.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: nnInteractive
|
|
3
|
-
Version:
|
|
3
|
+
Version: 2.0.0
|
|
4
4
|
Summary: Inference code for nnInteractive
|
|
5
5
|
Author: Helmholtz Imaging Applied Computer Vision Lab
|
|
6
6
|
Author-email: Fabian Isensee <f.isensee@dkfz-heidelberg.de>
|
|
@@ -219,8 +219,8 @@ Classifier: Topic :: Scientific/Engineering :: Medical Science Apps.
|
|
|
219
219
|
Requires-Python: >=3.10
|
|
220
220
|
Description-Content-Type: text/markdown
|
|
221
221
|
License-File: LICENSE
|
|
222
|
-
Requires-Dist: nnunetv2>=2.
|
|
223
|
-
Requires-Dist: torch
|
|
222
|
+
Requires-Dist: nnunetv2>=2.7.0
|
|
223
|
+
Requires-Dist: torch!=2.9.*,>=2.1.2
|
|
224
224
|
Requires-Dist: acvl-utils<0.3,>=0.2.3
|
|
225
225
|
Requires-Dist: batchgenerators>=0.25.1
|
|
226
226
|
Provides-Extra: dev
|
|
@@ -236,19 +236,27 @@ Dynamic: license-file
|
|
|
236
236
|
|
|
237
237
|
This repository provides the official Python backend for `nnInteractive`, a state-of-the-art framework for 3D promptable segmentation. It is designed for seamless integration into Python-based workflows—ideal for researchers, developers, and power users working directly with code.
|
|
238
238
|
|
|
239
|
-
|
|
240
|
-
> There is a known issue with **PyTorch 2.9.0** causing **OOM errors during inference** in `nnInteractive` (related to 3D convolutions — see the PyTorch issue [here](https://github.com/pytorch/pytorch/issues/166122)).
|
|
241
|
-
> **Until this is resolved, please use PyTorch 2.8.0 or earlier.**
|
|
239
|
+
`nnInteractive` is also available through graphical viewers (GUI) for those who prefer a visual workflow.
|
|
242
240
|
|
|
241
|
+
### Recommended integrations (developed and maintained by our team)
|
|
243
242
|
|
|
244
|
-
|
|
243
|
+
<div align="center">
|
|
244
|
+
|
|
245
|
+
| **<div align="center">[napari plugin](https://github.com/MIC-DKFZ/napari-nninteractive)</div>** | **<div align="center">[MITK integration](https://www.mitk.org/)</div>** |
|
|
246
|
+
|-------------------|----------------------|
|
|
247
|
+
| [<img src="imgs/Logos/napari.jpg" height="200">](https://github.com/MIC-DKFZ/napari-nninteractive) | [<img src="imgs/Logos/mitk.jpg" height="200">](https://www.mitk.org/) |
|
|
248
|
+
|
|
249
|
+
</div>
|
|
250
|
+
|
|
251
|
+
### Community-driven integrations
|
|
245
252
|
|
|
253
|
+
Huge thanks to the community for contributing these integrations!
|
|
246
254
|
|
|
247
255
|
<div align="center">
|
|
248
256
|
|
|
249
|
-
| **<div align="center">[
|
|
250
|
-
|
|
251
|
-
| [<img src="imgs/Logos/
|
|
257
|
+
| **<div align="center">[3D Slicer extension](https://github.com/coendevente/SlicerNNInteractive)</div>** | **<div align="center">[ITK-SNAP extension](https://itksnap-dls.readthedocs.io/en/latest/quick_start.html)</div>** | **<div align="center">[OHIF integration](https://github.com/CCI-Bonn/OHIF-AI)</div>** |
|
|
258
|
+
|-------------------------|-------------------------|-------------------------|
|
|
259
|
+
| [<img src="imgs/Logos/3DSlicer.png" height="200">](https://github.com/coendevente/SlicerNNInteractive) | [<img src="imgs/Logos/snaplogo_sq.png" height="200">](https://itksnap-dls.readthedocs.io/en/latest/quick_start.html) | [<img src="imgs/Logos/ohif.png" height="200">](https://github.com/CCI-Bonn/OHIF-AI) |
|
|
252
260
|
|
|
253
261
|
</div>
|
|
254
262
|
|
|
@@ -264,7 +272,6 @@ This repository provides the official Python backend for `nnInteractive`, a stat
|
|
|
264
272
|
|
|
265
273
|
---
|
|
266
274
|
|
|
267
|
-
|
|
268
275
|
## What is nnInteractive?
|
|
269
276
|
|
|
270
277
|
> Isensee, F.\*, Rokuss, M.\*, Krämer, L.\*, Dinkelacker, S., Ravindran, A., Stritzke, F., Hamm, B., Wald, T., Langenberg, M., Ulrich, C., Deissler, J., Floca, R., & Maier-Hein, K. (2025). nnInteractive: Redefining 3D Promptable Segmentation. https://arxiv.org/abs/2503.08373 \
|
|
@@ -339,6 +346,9 @@ import SimpleITK as sitk
|
|
|
339
346
|
from huggingface_hub import snapshot_download # Install huggingface_hub if not already installed
|
|
340
347
|
|
|
341
348
|
# --- Download Trained Model Weights (~400MB) ---
|
|
349
|
+
# License reminder: The official nnInteractive checkpoint is licensed under
|
|
350
|
+
# Creative Commons Attribution Non Commercial Share Alike 4.0 (CC BY-NC-SA 4.0).
|
|
351
|
+
# See the License section of this readme!.
|
|
342
352
|
REPO_ID = "nnInteractive/nnInteractive"
|
|
343
353
|
MODEL_NAME = "nnInteractive_v1.0" # Updated models may be available in the future
|
|
344
354
|
DOWNLOAD_DIR = "/home/isensee/temp" # Specify the download directory
|
|
@@ -360,7 +370,6 @@ session = nnInteractiveInferenceSession(
|
|
|
360
370
|
verbose=False,
|
|
361
371
|
torch_n_threads=os.cpu_count(), # Use available CPU cores
|
|
362
372
|
do_autozoom=True, # Enables AutoZoom for better patching
|
|
363
|
-
use_pinned_memory=True, # Optimizes GPU memory transfers
|
|
364
373
|
)
|
|
365
374
|
|
|
366
375
|
# Load the trained model
|
|
@@ -14,6 +14,7 @@ step (bbox prediction + 5 click refinements). The evaluator will overwrite
|
|
|
14
14
|
the same input file between calls, injecting updated clicks and the previous
|
|
15
15
|
prediction (`prev_pred`).
|
|
16
16
|
"""
|
|
17
|
+
|
|
17
18
|
from __future__ import annotations
|
|
18
19
|
|
|
19
20
|
import argparse
|
|
@@ -28,11 +29,11 @@ from nnInteractive.inference.inference_session import nnInteractiveInferenceSess
|
|
|
28
29
|
|
|
29
30
|
from nnunetv2.utilities.helpers import empty_cache
|
|
30
31
|
|
|
31
|
-
|
|
32
32
|
# --------------------------------------------------------------------------- #
|
|
33
33
|
# === EDIT BELOW === #
|
|
34
34
|
# --------------------------------------------------------------------------- #
|
|
35
35
|
|
|
36
|
+
|
|
36
37
|
def run_inference(
|
|
37
38
|
image: np.ndarray,
|
|
38
39
|
spacing: tuple[float, float, float],
|
|
@@ -66,27 +67,28 @@ def run_inference(
|
|
|
66
67
|
classes start from 1 … N. Make sure dtype is `np.uint8`.
|
|
67
68
|
"""
|
|
68
69
|
session = nnInteractiveInferenceSession(
|
|
69
|
-
device=torch.device(
|
|
70
|
+
device=torch.device("cuda", 0),
|
|
70
71
|
use_torch_compile=False,
|
|
71
72
|
verbose=False,
|
|
72
73
|
torch_n_threads=os.cpu_count(),
|
|
73
74
|
do_autozoom=True,
|
|
74
|
-
use_pinned_memory=True
|
|
75
75
|
)
|
|
76
76
|
session.initialize_from_trained_model_folder(
|
|
77
77
|
model_training_output_dir=CHECKPOINT_DIR,
|
|
78
|
-
use_fold=
|
|
78
|
+
use_fold="all",
|
|
79
79
|
)
|
|
80
80
|
session.set_image(image[None].astype(np.float32))
|
|
81
|
-
target_buffer = torch.zeros(image.shape, dtype=torch.uint8, device=
|
|
81
|
+
target_buffer = torch.zeros(image.shape, dtype=torch.uint8, device="cpu")
|
|
82
82
|
session.set_target_buffer(target_buffer)
|
|
83
83
|
result = torch.zeros(image.shape, dtype=torch.uint8)
|
|
84
84
|
num_objects = len(bbox) if bbox is not None else len(clicks)
|
|
85
85
|
if bbox is not None and clicks is not None:
|
|
86
|
-
assert len(bbox) == len(clicks), (
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
86
|
+
assert len(bbox) == len(clicks), (
|
|
87
|
+
"Both bboxs and clicks lists are provided but with different length "
|
|
88
|
+
"suggesting different number of objects. This is not supported by this script "
|
|
89
|
+
"and it was not communicated by the organizing team that such cases exist "
|
|
90
|
+
"or how they are supposed to be handled."
|
|
91
|
+
)
|
|
90
92
|
for oid in range(1, num_objects + 1):
|
|
91
93
|
# place previous segmentation
|
|
92
94
|
if prev_pred is not None:
|
|
@@ -96,46 +98,50 @@ def run_inference(
|
|
|
96
98
|
if bbox is not None:
|
|
97
99
|
bbox_here = bbox[oid - 1]
|
|
98
100
|
bbox_here = [
|
|
99
|
-
[bbox_here[
|
|
100
|
-
[bbox_here[
|
|
101
|
-
[bbox_here[
|
|
102
|
-
|
|
101
|
+
[bbox_here["z_min"], bbox_here["z_max"] + 1],
|
|
102
|
+
[bbox_here["z_mid_y_min"], bbox_here["z_mid_y_max"] + 1],
|
|
103
|
+
[bbox_here["z_mid_x_min"], bbox_here["z_mid_x_max"] + 1],
|
|
104
|
+
]
|
|
103
105
|
session.add_bbox_interaction(bbox_here, include_interaction=True, run_prediction=False)
|
|
104
106
|
if clicks is not None:
|
|
105
107
|
clicks_here = clicks[oid - 1]
|
|
106
108
|
clicks_order_here = clicks_order[oid - 1]
|
|
107
109
|
fg_ptr = bg_ptr = 0
|
|
108
110
|
for kind in clicks_order_here:
|
|
109
|
-
if kind ==
|
|
110
|
-
click = clicks_here[
|
|
111
|
+
if kind == "fg":
|
|
112
|
+
click = clicks_here["fg"][fg_ptr]
|
|
111
113
|
fg_ptr += 1
|
|
112
114
|
else:
|
|
113
|
-
click = clicks_here[
|
|
115
|
+
click = clicks_here["bg"][bg_ptr]
|
|
114
116
|
bg_ptr += 1
|
|
115
117
|
|
|
116
118
|
print(f"Class {oid}: {kind} click at {click}")
|
|
117
|
-
session.add_point_interaction(click, include_interaction=kind ==
|
|
119
|
+
session.add_point_interaction(click, include_interaction=kind == "fg", run_prediction=False)
|
|
118
120
|
# now run inference on the last interaction center
|
|
119
121
|
session.new_interaction_centers = [session.new_interaction_centers[-1]]
|
|
120
122
|
session.new_interaction_zoom_out_factors = [session.new_interaction_zoom_out_factors[-1]]
|
|
121
123
|
session._predict()
|
|
122
124
|
result[session.target_buffer > 0] = oid
|
|
123
125
|
del session
|
|
124
|
-
empty_cache(torch.device(
|
|
126
|
+
empty_cache(torch.device("cuda", 0))
|
|
125
127
|
return result.cpu().numpy()
|
|
126
128
|
|
|
129
|
+
|
|
127
130
|
# --------------------------------------------------------------------------- #
|
|
128
131
|
# === DO NOT EDIT BELOW === #
|
|
129
132
|
# --------------------------------------------------------------------------- #
|
|
130
133
|
|
|
134
|
+
|
|
131
135
|
def parse_args() -> argparse.Namespace:
|
|
132
136
|
p = argparse.ArgumentParser()
|
|
133
137
|
p.add_argument("--case_path", required=True, help="Path to the input *.npz")
|
|
134
138
|
p.add_argument("--save_path", required=True, help="Path to write output *.npz")
|
|
135
139
|
return p.parse_args()
|
|
136
140
|
|
|
141
|
+
|
|
137
142
|
# Adapt this to your checkpoint directory (relative to the script)
|
|
138
|
-
CHECKPOINT_DIR =
|
|
143
|
+
CHECKPOINT_DIR = "checkpoint_folder"
|
|
144
|
+
|
|
139
145
|
|
|
140
146
|
def main() -> None:
|
|
141
147
|
args = parse_args()
|
|
@@ -147,12 +153,12 @@ def main() -> None:
|
|
|
147
153
|
|
|
148
154
|
# ---------------------- Load input & prompts -------------------------- #
|
|
149
155
|
data = np.load(case_path, allow_pickle=True)
|
|
150
|
-
image
|
|
151
|
-
spacing
|
|
152
|
-
bbox
|
|
153
|
-
clicks
|
|
156
|
+
image = data["imgs"]
|
|
157
|
+
spacing = tuple(data["spacing"])
|
|
158
|
+
bbox = data.get("boxes") # bounding boxes
|
|
159
|
+
clicks = data.get("clicks") # fg/bg clicks per class
|
|
154
160
|
clicks_order = data.get("clicks_order") # order of click types
|
|
155
|
-
prev_pred
|
|
161
|
+
prev_pred = data.get("prev_pred") # from last iteration
|
|
156
162
|
|
|
157
163
|
# --------------------------- Inference -------------------------------- #
|
|
158
164
|
seg = run_inference(image, spacing, bbox, clicks, clicks_order, prev_pred)
|
|
@@ -162,5 +168,6 @@ def main() -> None:
|
|
|
162
168
|
np.savez_compressed(save_path, segs=seg.astype(np.uint8))
|
|
163
169
|
print(f"[predict.py] Saved prediction to {save_path}")
|
|
164
170
|
|
|
171
|
+
|
|
165
172
|
if __name__ == "__main__":
|
|
166
|
-
main()
|
|
173
|
+
main()
|