nnInteractive 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76) hide show
  1. nnInteractive/__init__.py +3 -0
  2. nnInteractive/inference/__init__.py +0 -0
  3. nnInteractive/inference/cvpr2025_challenge_baseline/__init__.py +0 -0
  4. nnInteractive/inference/cvpr2025_challenge_baseline/predict.py +173 -0
  5. nnInteractive/inference/inference_session.py +1400 -0
  6. nnInteractive/interaction/__init__.py +0 -0
  7. nnInteractive/interaction/point.py +166 -0
  8. nnInteractive/supervoxel/setup.py +4 -0
  9. nnInteractive/supervoxel/src/metadata.py +118 -0
  10. nnInteractive/supervoxel/src/reader.py +175 -0
  11. nnInteractive/supervoxel/src/run.py +136 -0
  12. nnInteractive/supervoxel/src/sam2/__init__.py +2 -0
  13. nnInteractive/supervoxel/src/sam2/sam2/__init__.py +11 -0
  14. nnInteractive/supervoxel/src/sam2/sam2/automatic_mask_generator.py +434 -0
  15. nnInteractive/supervoxel/src/sam2/sam2/benchmark.py +86 -0
  16. nnInteractive/supervoxel/src/sam2/sam2/build_sam.py +172 -0
  17. nnInteractive/supervoxel/src/sam2/sam2/modeling/__init__.py +5 -0
  18. nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/__init__.py +5 -0
  19. nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/hieradet.py +305 -0
  20. nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/image_encoder.py +132 -0
  21. nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/utils.py +89 -0
  22. nnInteractive/supervoxel/src/sam2/sam2/modeling/memory_attention.py +167 -0
  23. nnInteractive/supervoxel/src/sam2/sam2/modeling/memory_encoder.py +179 -0
  24. nnInteractive/supervoxel/src/sam2/sam2/modeling/position_encoding.py +217 -0
  25. nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/__init__.py +5 -0
  26. nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/mask_decoder.py +274 -0
  27. nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/prompt_encoder.py +194 -0
  28. nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/transformer.py +293 -0
  29. nnInteractive/supervoxel/src/sam2/sam2/modeling/sam2_base.py +879 -0
  30. nnInteractive/supervoxel/src/sam2/sam2/modeling/sam2_utils.py +315 -0
  31. nnInteractive/supervoxel/src/sam2/sam2/sam2_image_predictor.py +433 -0
  32. nnInteractive/supervoxel/src/sam2/sam2/sam2_video_predictor.py +1171 -0
  33. nnInteractive/supervoxel/src/sam2/sam2/sam2_video_predictor_legacy.py +1125 -0
  34. nnInteractive/supervoxel/src/sam2/sam2/utils/__init__.py +5 -0
  35. nnInteractive/supervoxel/src/sam2/sam2/utils/amg.py +332 -0
  36. nnInteractive/supervoxel/src/sam2/sam2/utils/misc.py +488 -0
  37. nnInteractive/supervoxel/src/sam2/sam2/utils/transforms.py +108 -0
  38. nnInteractive/supervoxel/src/sam2/setup.py +174 -0
  39. nnInteractive/supervoxel/src/sam2/training/__init__.py +5 -0
  40. nnInteractive/supervoxel/src/sam2/training/dataset/__init__.py +5 -0
  41. nnInteractive/supervoxel/src/sam2/training/dataset/sam2_datasets.py +176 -0
  42. nnInteractive/supervoxel/src/sam2/training/dataset/transforms.py +481 -0
  43. nnInteractive/supervoxel/src/sam2/training/dataset/utils.py +102 -0
  44. nnInteractive/supervoxel/src/sam2/training/dataset/vos_dataset.py +154 -0
  45. nnInteractive/supervoxel/src/sam2/training/dataset/vos_raw_dataset.py +290 -0
  46. nnInteractive/supervoxel/src/sam2/training/dataset/vos_sampler.py +103 -0
  47. nnInteractive/supervoxel/src/sam2/training/dataset/vos_segment_loader.py +289 -0
  48. nnInteractive/supervoxel/src/sam2/training/loss_fns.py +290 -0
  49. nnInteractive/supervoxel/src/sam2/training/model/__init__.py +5 -0
  50. nnInteractive/supervoxel/src/sam2/training/model/sam2.py +515 -0
  51. nnInteractive/supervoxel/src/sam2/training/optimizer.py +462 -0
  52. nnInteractive/supervoxel/src/sam2/training/scripts/sav_frame_extraction_submitit.py +157 -0
  53. nnInteractive/supervoxel/src/sam2/training/train.py +232 -0
  54. nnInteractive/supervoxel/src/sam2/training/trainer.py +1051 -0
  55. nnInteractive/supervoxel/src/sam2/training/utils/__init__.py +5 -0
  56. nnInteractive/supervoxel/src/sam2/training/utils/checkpoint_utils.py +328 -0
  57. nnInteractive/supervoxel/src/sam2/training/utils/data_utils.py +166 -0
  58. nnInteractive/supervoxel/src/sam2/training/utils/distributed.py +560 -0
  59. nnInteractive/supervoxel/src/sam2/training/utils/logger.py +236 -0
  60. nnInteractive/supervoxel/src/sam2/training/utils/train_utils.py +275 -0
  61. nnInteractive/supervoxel/src/supervoxel.py +198 -0
  62. nnInteractive/trainer/__init__.py +0 -0
  63. nnInteractive/trainer/nnInteractiveTrainer.py +24 -0
  64. nnInteractive/utils/__init__.py +0 -0
  65. nnInteractive/utils/bboxes.py +217 -0
  66. nnInteractive/utils/checkpoint_cleansing.py +9 -0
  67. nnInteractive/utils/crop.py +268 -0
  68. nnInteractive/utils/erosion_dilation.py +48 -0
  69. nnInteractive/utils/inference_helpers.py +45 -0
  70. nnInteractive/utils/os_shennanigans.py +16 -0
  71. nnInteractive/utils/rounding.py +13 -0
  72. nninteractive-2.0.0.dist-info/METADATA +511 -0
  73. nninteractive-2.0.0.dist-info/RECORD +76 -0
  74. nninteractive-2.0.0.dist-info/WHEEL +5 -0
  75. nninteractive-2.0.0.dist-info/licenses/LICENSE +201 -0
  76. nninteractive-2.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,3 @@
1
"""Expose the installed distribution's version as ``nnInteractive.__version__``."""

from importlib.metadata import PackageNotFoundError
from importlib.metadata import version as _version

try:
    __version__ = _version("nnInteractive")
except PackageNotFoundError:
    # The package is being imported without installed distribution metadata
    # (e.g. straight from a source checkout). Fall back to a placeholder
    # instead of making ``import nnInteractive`` fail.
    __version__ = "0+unknown"
File without changes
@@ -0,0 +1,173 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ CVPR25 – Foundation Models for Interactive 3D Biomedical Image Segmentation
4
+ Skeleton inference script.
5
+
6
+ You only need to replace the `run_inference()` function with your model‑specific
7
+ code. Everything else takes care of
8
+ • reading the input image + prompts,
9
+ • passing the relevant information to your model,
10
+ • saving the prediction in the expected format.
11
+
12
+ During evaluation the script is called exactly once for every interaction
13
+ step (bbox prediction + 5 click refinements). The evaluator will overwrite
14
+ the same input file between calls, injecting updated clicks and the previous
15
+ prediction (`prev_pred`).
16
+ """
17
+
18
+ from __future__ import annotations
19
+
20
+ import argparse
21
+ import os
22
+ import sys
23
+ from pathlib import Path
24
+ import numpy as np
25
+ import torch
26
+ from acvl_utils.cropping_and_padding.bounding_boxes import crop_and_pad_nd
27
+
28
+ from nnInteractive.inference.inference_session import nnInteractiveInferenceSession
29
+
30
+ from nnunetv2.utilities.helpers import empty_cache
31
+
32
+ # --------------------------------------------------------------------------- #
33
+ # === EDIT BELOW === #
34
+ # --------------------------------------------------------------------------- #
35
+
36
+
37
def _bbox_to_half_open(box: dict) -> list[list[int]]:
    """Convert one bounding-box dict into ``[[z0, z1], [y0, y1], [x0, x1]]``.

    Every ``*_max`` value is incremented by one, i.e. the upper bounds stored in
    the dict are treated as inclusive and turned into exclusive end indices.
    """
    return [
        [box["z_min"], box["z_max"] + 1],
        [box["z_mid_y_min"], box["z_mid_y_max"] + 1],
        [box["z_mid_x_min"], box["z_mid_x_max"] + 1],
    ]


def _iter_clicks_in_order(clicks_for_object: dict, order):
    """Yield ``(kind, click)`` pairs in the order the clicks were recorded.

    An ``"fg"`` entry consumes the next click from ``clicks_for_object["fg"]``;
    any other entry consumes the next click from ``clicks_for_object["bg"]``.
    """
    fg_ptr = bg_ptr = 0
    for kind in order:
        if kind == "fg":
            click = clicks_for_object["fg"][fg_ptr]
            fg_ptr += 1
        else:
            click = clicks_for_object["bg"][bg_ptr]
            bg_ptr += 1
        yield kind, click


def run_inference(
    image: np.ndarray,
    spacing: tuple[float, float, float],
    bbox: list[dict] | None,
    clicks: list[dict] | None,
    clicks_order: list[list[str]] | None,
    prev_pred: np.ndarray | None,
) -> np.ndarray:
    """
    Run one interaction step of nnInteractive for every prompted object.

    Parameters
    ----------
    image : (D, H, W) np.ndarray
        Raw image volume (usually float32). *No preprocessing applied*.
    spacing : (3,) tuple of float
        Physical voxel spacing (z, y, x) in millimetres. Not used by this
        implementation.
    bbox : list of dict | None
        Bounding-box prompt(s), one dict per object with the keys ``z_min``,
        ``z_max``, ``z_mid_y_min``, ``z_mid_y_max``, ``z_mid_x_min`` and
        ``z_mid_x_max``; may be absent in refinement iterations.
    clicks : list of dict | None
        One dict per object holding the click lists under ``"fg"`` and ``"bg"``.
    clicks_order : list of list of str | None
        For every object, the sequence of click kinds (``"fg"`` or anything
        else for background) in the order they were placed. Required whenever
        ``clicks`` is given.
    prev_pred : (D, H, W) np.ndarray | None
        Segmentation from the previous iteration. May be `None` for the first
        call.

    Returns
    -------
    seg : (D, H, W) np.ndarray, dtype=uint8
        Multi-class segmentation mask. Background is 0, object ``i`` (1-based
        position in the prompt lists) is written as label ``i``. If neither
        ``bbox`` nor ``clicks`` contains an object, an all-zero mask is returned.

    Raises
    ------
    ValueError
        If ``bbox`` and ``clicks`` describe a different number of objects, or
        if ``clicks`` is given without ``clicks_order``.
    """
    # Validate the prompts first so malformed input fails before the
    # checkpoint is loaded.
    if bbox is not None and clicks is not None and len(bbox) != len(clicks):
        raise ValueError(
            "Both bboxs and clicks lists are provided but with different length "
            "suggesting different number of objects. This is not supported by this script "
            "and it was not communicated by the organizing team that such cases exist "
            "or how they are supposed to be handled."
        )

    if bbox is not None:
        num_objects = len(bbox)
    elif clicks is not None:
        num_objects = len(clicks)
    else:
        num_objects = 0

    if num_objects == 0:
        # Nothing was prompted, so there is nothing to segment.
        return np.zeros(image.shape, dtype=np.uint8)

    if clicks is not None and clicks_order is None:
        raise ValueError("clicks were provided without clicks_order; the click order is required to replay them.")

    device = torch.device("cuda", 0)
    session = nnInteractiveInferenceSession(
        device=device,
        use_torch_compile=False,
        verbose=False,
        torch_n_threads=os.cpu_count() or 1,  # cpu_count() may return None
        do_autozoom=True,
    )
    try:
        session.initialize_from_trained_model_folder(
            model_training_output_dir=CHECKPOINT_DIR,
            use_fold="all",
        )
        session.set_image(image[None].astype(np.float32))
        target_buffer = torch.zeros(image.shape, dtype=torch.uint8, device="cpu")
        session.set_target_buffer(target_buffer)
        result = torch.zeros(image.shape, dtype=torch.uint8)

        for oid in range(1, num_objects + 1):
            num_prompts = 0

            # place previous segmentation
            # NOTE(review): this branch relies on add_initial_seg_interaction
            # discarding the previous object's interactions itself - confirm.
            if prev_pred is not None:
                session.add_initial_seg_interaction((prev_pred == oid).astype(np.uint8), run_prediction=False)
            else:
                session.reset_interactions()

            if bbox is not None:
                session.add_bbox_interaction(
                    _bbox_to_half_open(bbox[oid - 1]), include_interaction=True, run_prediction=False
                )
                num_prompts += 1

            if clicks is not None:
                for kind, click in _iter_clicks_in_order(clicks[oid - 1], clicks_order[oid - 1]):
                    print(f"Class {oid}: {kind} click at {click}")
                    session.add_point_interaction(click, include_interaction=kind == "fg", run_prediction=False)
                    num_prompts += 1

            if num_prompts == 0:
                # No prompt was added for this object, so there is no interaction
                # centre to predict around. Keep its previous mask, if any.
                if prev_pred is not None:
                    result[torch.from_numpy(np.asarray(prev_pred == oid))] = oid
                continue

            # now run inference on the last interaction center only
            session.new_interaction_centers = [session.new_interaction_centers[-1]]
            session.new_interaction_zoom_out_factors = [session.new_interaction_zoom_out_factors[-1]]
            session._predict()
            result[session.target_buffer > 0] = oid
    finally:
        # Release the model and cached GPU memory even if inference failed.
        del session
        empty_cache(device)
    return result.cpu().numpy()
128
+
129
+
130
+ # --------------------------------------------------------------------------- #
131
+ # === DO NOT EDIT BELOW === #
132
+ # --------------------------------------------------------------------------- #
133
+
134
+
135
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command-line options of the prediction script.

    Parameters
    ----------
    argv : list of str | None
        Argument vector to parse. With the default ``None`` argparse reads the
        process arguments, which is the behaviour of the original zero-argument
        call; passing a list allows programmatic use.

    Returns
    -------
    argparse.Namespace
        Namespace with the string attributes ``case_path`` and ``save_path``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--case_path", required=True, help="Path to the input *.npz")
    parser.add_argument("--save_path", required=True, help="Path to write output *.npz")
    return parser.parse_args(argv)
140
+
141
+
142
# Adapt this to your checkpoint directory. The folder name is resolved relative
# to the directory containing this script (not the current working directory),
# so the checkpoint is found regardless of where the script is launched from.
CHECKPOINT_DIR = str(Path(__file__).resolve().parent / "checkpoint_folder")
144
+
145
+
146
def main() -> None:
    """Load one case, run a single interaction step and write the prediction.

    Reads the ``*.npz`` given by ``--case_path``, forwards the image and the
    prompts it contains to ``run_inference`` and stores the returned mask under
    the key ``segs`` in the compressed ``*.npz`` given by ``--save_path``.
    Exits with an error message if the input file does not exist.
    """
    args = parse_args()
    case_path = Path(args.case_path)
    save_path = Path(args.save_path)

    if not case_path.is_file():
        sys.exit(f"[predict.py] ERROR: {case_path} not found.")

    # ---------------------- Load input & prompts -------------------------- #
    # NOTE(security): allow_pickle=True unpickles object arrays and can run
    # arbitrary code; only use this on case files from a trusted source.
    # The context manager closes the underlying archive once everything needed
    # has been read out of it.
    with np.load(case_path, allow_pickle=True) as data:
        image = data["imgs"]
        spacing = tuple(data["spacing"])
        bbox = data.get("boxes")  # bounding boxes
        clicks = data.get("clicks")  # fg/bg clicks per class
        clicks_order = data.get("clicks_order")  # order of click types
        prev_pred = data.get("prev_pred")  # from last iteration

    # --------------------------- Inference -------------------------------- #
    seg = run_inference(image, spacing, bbox, clicks, clicks_order, prev_pred)

    # ------------------------- Save prediction ---------------------------- #
    save_path.parent.mkdir(parents=True, exist_ok=True)
    np.savez_compressed(save_path, segs=seg.astype(np.uint8))
    print(f"[predict.py] Saved prediction to {save_path}")


if __name__ == "__main__":
    main()