nnInteractive 1.1.3__tar.gz → 2.0.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (86)
  1. {nninteractive-1.1.3 → nninteractive-2.0.0}/PKG-INFO +21 -12
  2. {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/__init__.py +1 -1
  3. {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/inference/cvpr2025_challenge_baseline/predict.py +32 -25
  4. nninteractive-2.0.0/nnInteractive/inference/inference_session.py +1400 -0
  5. nninteractive-2.0.0/nnInteractive/interaction/point.py +166 -0
  6. nninteractive-2.0.0/nnInteractive/supervoxel/src/metadata.py +118 -0
  7. {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/supervoxel/src/reader.py +27 -23
  8. {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/supervoxel/src/run.py +32 -14
  9. {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/supervoxel/src/sam2/__init__.py +1 -3
  10. {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/supervoxel/src/sam2/sam2/automatic_mask_generator.py +10 -30
  11. {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/supervoxel/src/sam2/sam2/benchmark.py +2 -8
  12. {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/supervoxel/src/sam2/sam2/build_sam.py +1 -3
  13. {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/hieradet.py +6 -18
  14. {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/image_encoder.py +1 -3
  15. {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/utils.py +2 -6
  16. {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/supervoxel/src/sam2/sam2/modeling/memory_attention.py +1 -3
  17. {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/supervoxel/src/sam2/sam2/modeling/memory_encoder.py +1 -3
  18. {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/supervoxel/src/sam2/sam2/modeling/position_encoding.py +8 -30
  19. {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/mask_decoder.py +10 -31
  20. {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/prompt_encoder.py +4 -12
  21. {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/transformer.py +9 -27
  22. {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/supervoxel/src/sam2/sam2/modeling/sam2_base.py +17 -49
  23. {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/supervoxel/src/sam2/sam2/modeling/sam2_utils.py +4 -12
  24. {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/supervoxel/src/sam2/sam2/sam2_image_predictor.py +16 -49
  25. {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/supervoxel/src/sam2/sam2/sam2_video_predictor.py +28 -80
  26. {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/supervoxel/src/sam2/sam2/sam2_video_predictor_legacy.py +24 -71
  27. {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/supervoxel/src/sam2/sam2/utils/amg.py +6 -22
  28. {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/supervoxel/src/sam2/sam2/utils/misc.py +10 -13
  29. {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/supervoxel/src/sam2/sam2/utils/transforms.py +5 -15
  30. {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/supervoxel/src/sam2/training/dataset/sam2_datasets.py +2 -6
  31. {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/supervoxel/src/sam2/training/dataset/transforms.py +17 -64
  32. {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/supervoxel/src/sam2/training/dataset/utils.py +1 -3
  33. {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/supervoxel/src/sam2/training/dataset/vos_dataset.py +4 -12
  34. {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/supervoxel/src/sam2/training/dataset/vos_raw_dataset.py +9 -27
  35. {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/supervoxel/src/sam2/training/dataset/vos_sampler.py +1 -3
  36. {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/supervoxel/src/sam2/training/dataset/vos_segment_loader.py +5 -16
  37. {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/supervoxel/src/sam2/training/loss_fns.py +8 -25
  38. {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/supervoxel/src/sam2/training/model/sam2.py +13 -39
  39. {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/supervoxel/src/sam2/training/optimizer.py +19 -59
  40. {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/supervoxel/src/sam2/training/scripts/sav_frame_extraction_submitit.py +3 -9
  41. {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/supervoxel/src/sam2/training/train.py +17 -55
  42. {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/supervoxel/src/sam2/training/trainer.py +32 -94
  43. {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/supervoxel/src/sam2/training/utils/checkpoint_utils.py +16 -49
  44. {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/supervoxel/src/sam2/training/utils/data_utils.py +6 -19
  45. {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/supervoxel/src/sam2/training/utils/distributed.py +7 -23
  46. {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/supervoxel/src/sam2/training/utils/logger.py +5 -15
  47. {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/supervoxel/src/sam2/training/utils/train_utils.py +2 -15
  48. {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/supervoxel/src/supervoxel.py +40 -36
  49. nninteractive-2.0.0/nnInteractive/trainer/nnInteractiveTrainer.py +24 -0
  50. {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/utils/bboxes.py +38 -38
  51. {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/utils/checkpoint_cleansing.py +3 -4
  52. {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/utils/crop.py +101 -40
  53. {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/utils/erosion_dilation.py +3 -6
  54. nninteractive-2.0.0/nnInteractive/utils/inference_helpers.py +45 -0
  55. {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/utils/os_shennanigans.py +2 -1
  56. {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/utils/rounding.py +2 -1
  57. {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive.egg-info/PKG-INFO +21 -12
  58. {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive.egg-info/SOURCES.txt +1 -0
  59. {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive.egg-info/requires.txt +2 -2
  60. {nninteractive-1.1.3 → nninteractive-2.0.0}/pyproject.toml +3 -3
  61. {nninteractive-1.1.3 → nninteractive-2.0.0}/readme.md +18 -9
  62. nninteractive-1.1.3/nnInteractive/inference/inference_session.py +0 -787
  63. nninteractive-1.1.3/nnInteractive/interaction/point.py +0 -114
  64. nninteractive-1.1.3/nnInteractive/supervoxel/src/metadata.py +0 -107
  65. nninteractive-1.1.3/nnInteractive/trainer/nnInteractiveTrainer.py +0 -25
  66. {nninteractive-1.1.3 → nninteractive-2.0.0}/LICENSE +0 -0
  67. {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/inference/__init__.py +0 -0
  68. {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/inference/cvpr2025_challenge_baseline/__init__.py +0 -0
  69. {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/interaction/__init__.py +0 -0
  70. {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/supervoxel/setup.py +0 -0
  71. {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/supervoxel/src/sam2/sam2/__init__.py +0 -0
  72. {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/supervoxel/src/sam2/sam2/modeling/__init__.py +0 -0
  73. {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/__init__.py +0 -0
  74. {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/__init__.py +0 -0
  75. {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/supervoxel/src/sam2/sam2/utils/__init__.py +0 -0
  76. {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/supervoxel/src/sam2/setup.py +0 -0
  77. {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/supervoxel/src/sam2/training/__init__.py +0 -0
  78. {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/supervoxel/src/sam2/training/dataset/__init__.py +0 -0
  79. {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/supervoxel/src/sam2/training/model/__init__.py +0 -0
  80. {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/supervoxel/src/sam2/training/utils/__init__.py +0 -0
  81. {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/trainer/__init__.py +0 -0
  82. {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive/utils/__init__.py +0 -0
  83. {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive.egg-info/dependency_links.txt +0 -0
  84. {nninteractive-1.1.3 → nninteractive-2.0.0}/nnInteractive.egg-info/top_level.txt +0 -0
  85. {nninteractive-1.1.3 → nninteractive-2.0.0}/setup.cfg +0 -0
  86. {nninteractive-1.1.3 → nninteractive-2.0.0}/setup.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: nnInteractive
3
- Version: 1.1.3
3
+ Version: 2.0.0
4
4
  Summary: Inference code for nnInteractive
5
5
  Author: Helmholtz Imaging Applied Computer Vision Lab
6
6
  Author-email: Fabian Isensee <f.isensee@dkfz-heidelberg.de>
@@ -219,8 +219,8 @@ Classifier: Topic :: Scientific/Engineering :: Medical Science Apps.
219
219
  Requires-Python: >=3.10
220
220
  Description-Content-Type: text/markdown
221
221
  License-File: LICENSE
222
- Requires-Dist: nnunetv2>=2.6
223
- Requires-Dist: torch<2.9.0,>=2.6
222
+ Requires-Dist: nnunetv2>=2.7.0
223
+ Requires-Dist: torch!=2.9.*,>=2.1.2
224
224
  Requires-Dist: acvl-utils<0.3,>=0.2.3
225
225
  Requires-Dist: batchgenerators>=0.25.1
226
226
  Provides-Extra: dev
@@ -236,19 +236,27 @@ Dynamic: license-file
236
236
 
237
237
  This repository provides the official Python backend for `nnInteractive`, a state-of-the-art framework for 3D promptable segmentation. It is designed for seamless integration into Python-based workflows—ideal for researchers, developers, and power users working directly with code.
238
238
 
239
- > ⚠️ **Temporary Compatibility Warning**
240
- > There is a known issue with **PyTorch 2.9.0** causing **OOM errors during inference** in `nnInteractive` (related to 3D convolutions — see the PyTorch issue [here](https://github.com/pytorch/pytorch/issues/166122)).
241
- > **Until this is resolved, please use PyTorch 2.8.0 or earlier.**
239
+ `nnInteractive` is also available through graphical viewers (GUI) for those who prefer a visual workflow.
242
240
 
241
+ ### Recommended integrations (developed and maintained by our team)
243
242
 
244
- `nnInteractive` is also available through graphical viewers (GUI) for those who prefer a visual workflow. The napari and MITK integrations are developed and maintained by our team. Thanks to the community for contributing the 3D Slicer extension!
243
+ <div align="center">
244
+
245
+ | **<div align="center">[napari plugin](https://github.com/MIC-DKFZ/napari-nninteractive)</div>** | **<div align="center">[MITK integration](https://www.mitk.org/)</div>** |
246
+ |-------------------|----------------------|
247
+ | [<img src="imgs/Logos/napari.jpg" height="200">](https://github.com/MIC-DKFZ/napari-nninteractive) | [<img src="imgs/Logos/mitk.jpg" height="200">](https://www.mitk.org/) |
248
+
249
+ </div>
250
+
251
+ ### Community-driven integrations
245
252
 
253
+ Huge thanks to the community for contributing these integrations!
246
254
 
247
255
  <div align="center">
248
256
 
249
- | **<div align="center">[napari plugin](https://github.com/MIC-DKFZ/napari-nninteractive)</div>** | **<div align="center">[MITK integration](https://www.mitk.org/wiki/MITK-nnInteractive)</div>** | **<div align="center">[3D Slicer extension](https://github.com/coendevente/SlicerNNInteractive)</div>** | **<div align="center">[ITK-SNAP extension](https://itksnap-dls.readthedocs.io/en/latest/quick_start.html)</div>** | **<div align="center">[OHIF integration](https://github.com/CCI-Bonn/OHIF-AI)</div>** |
250
- |-------------------|----------------------|-------------------------|-------------------------|-------------------------|
251
- | [<img src="imgs/Logos/napari.jpg" height="200">](https://github.com/MIC-DKFZ/napari-nninteractive) | [<img src="imgs/Logos/mitk.jpg" height="200">](https://www.mitk.org/wiki/MITK-nnInteractive) | [<img src="imgs/Logos/3DSlicer.png" height="200">](https://github.com/coendevente/SlicerNNInteractive) | [<img src="imgs/Logos/snaplogo_sq.png" height="200">](https://itksnap-dls.readthedocs.io/en/latest/quick_start.html) | [<img src="imgs/Logos/ohif.png" height="200">](https://github.com/CCI-Bonn/OHIF-AI) |
257
+ | **<div align="center">[3D Slicer extension](https://github.com/coendevente/SlicerNNInteractive)</div>** | **<div align="center">[ITK-SNAP extension](https://itksnap-dls.readthedocs.io/en/latest/quick_start.html)</div>** | **<div align="center">[OHIF integration](https://github.com/CCI-Bonn/OHIF-AI)</div>** |
258
+ |-------------------------|-------------------------|-------------------------|
259
+ | [<img src="imgs/Logos/3DSlicer.png" height="200">](https://github.com/coendevente/SlicerNNInteractive) | [<img src="imgs/Logos/snaplogo_sq.png" height="200">](https://itksnap-dls.readthedocs.io/en/latest/quick_start.html) | [<img src="imgs/Logos/ohif.png" height="200">](https://github.com/CCI-Bonn/OHIF-AI) |
252
260
 
253
261
  </div>
254
262
 
@@ -264,7 +272,6 @@ This repository provides the official Python backend for `nnInteractive`, a stat
264
272
 
265
273
  ---
266
274
 
267
-
268
275
  ## What is nnInteractive?
269
276
 
270
277
  > Isensee, F.\*, Rokuss, M.\*, Krämer, L.\*, Dinkelacker, S., Ravindran, A., Stritzke, F., Hamm, B., Wald, T., Langenberg, M., Ulrich, C., Deissler, J., Floca, R., & Maier-Hein, K. (2025). nnInteractive: Redefining 3D Promptable Segmentation. https://arxiv.org/abs/2503.08373 \
@@ -339,6 +346,9 @@ import SimpleITK as sitk
339
346
  from huggingface_hub import snapshot_download # Install huggingface_hub if not already installed
340
347
 
341
348
  # --- Download Trained Model Weights (~400MB) ---
349
+ # License reminder: The official nnInteractive checkpoint is licensed under
350
+ # Creative Commons Attribution Non Commercial Share Alike 4.0 (CC BY-NC-SA 4.0).
351
+ # See the License section of this readme!.
342
352
  REPO_ID = "nnInteractive/nnInteractive"
343
353
  MODEL_NAME = "nnInteractive_v1.0" # Updated models may be available in the future
344
354
  DOWNLOAD_DIR = "/home/isensee/temp" # Specify the download directory
@@ -360,7 +370,6 @@ session = nnInteractiveInferenceSession(
360
370
  verbose=False,
361
371
  torch_n_threads=os.cpu_count(), # Use available CPU cores
362
372
  do_autozoom=True, # Enables AutoZoom for better patching
363
- use_pinned_memory=True, # Optimizes GPU memory transfers
364
373
  )
365
374
 
366
375
  # Load the trained model
@@ -1,3 +1,3 @@
1
1
  from importlib.metadata import version as _version
2
2
 
3
- __version__ = _version("nnInteractive")
3
+ __version__ = _version("nnInteractive")
@@ -14,6 +14,7 @@ step (bbox prediction + 5 click refinements). The evaluator will overwrite
14
14
  the same input file between calls, injecting updated clicks and the previous
15
15
  prediction (`prev_pred`).
16
16
  """
17
+
17
18
  from __future__ import annotations
18
19
 
19
20
  import argparse
@@ -28,11 +29,11 @@ from nnInteractive.inference.inference_session import nnInteractiveInferenceSess
28
29
 
29
30
  from nnunetv2.utilities.helpers import empty_cache
30
31
 
31
-
32
32
  # --------------------------------------------------------------------------- #
33
33
  # === EDIT BELOW === #
34
34
  # --------------------------------------------------------------------------- #
35
35
 
36
+
36
37
  def run_inference(
37
38
  image: np.ndarray,
38
39
  spacing: tuple[float, float, float],
@@ -66,27 +67,28 @@ def run_inference(
66
67
  classes start from 1 … N. Make sure dtype is `np.uint8`.
67
68
  """
68
69
  session = nnInteractiveInferenceSession(
69
- device=torch.device('cuda', 0),
70
+ device=torch.device("cuda", 0),
70
71
  use_torch_compile=False,
71
72
  verbose=False,
72
73
  torch_n_threads=os.cpu_count(),
73
74
  do_autozoom=True,
74
- use_pinned_memory=True
75
75
  )
76
76
  session.initialize_from_trained_model_folder(
77
77
  model_training_output_dir=CHECKPOINT_DIR,
78
- use_fold='all',
78
+ use_fold="all",
79
79
  )
80
80
  session.set_image(image[None].astype(np.float32))
81
- target_buffer = torch.zeros(image.shape, dtype=torch.uint8, device='cpu')
81
+ target_buffer = torch.zeros(image.shape, dtype=torch.uint8, device="cpu")
82
82
  session.set_target_buffer(target_buffer)
83
83
  result = torch.zeros(image.shape, dtype=torch.uint8)
84
84
  num_objects = len(bbox) if bbox is not None else len(clicks)
85
85
  if bbox is not None and clicks is not None:
86
- assert len(bbox) == len(clicks), ('Both bboxs and clicks lists are provided but with different length '
87
- 'suggesting different number of objects. This is not supported by this script '
88
- 'and it was not communicated by the organizing team that such cases exist '
89
- 'or how they are supposed to be handled.')
86
+ assert len(bbox) == len(clicks), (
87
+ "Both bboxs and clicks lists are provided but with different length "
88
+ "suggesting different number of objects. This is not supported by this script "
89
+ "and it was not communicated by the organizing team that such cases exist "
90
+ "or how they are supposed to be handled."
91
+ )
90
92
  for oid in range(1, num_objects + 1):
91
93
  # place previous segmentation
92
94
  if prev_pred is not None:
@@ -96,46 +98,50 @@ def run_inference(
96
98
  if bbox is not None:
97
99
  bbox_here = bbox[oid - 1]
98
100
  bbox_here = [
99
- [bbox_here['z_min'], bbox_here['z_max'] + 1],
100
- [bbox_here['z_mid_y_min'], bbox_here['z_mid_y_max'] + 1],
101
- [bbox_here['z_mid_x_min'], bbox_here['z_mid_x_max'] + 1]
102
- ]
101
+ [bbox_here["z_min"], bbox_here["z_max"] + 1],
102
+ [bbox_here["z_mid_y_min"], bbox_here["z_mid_y_max"] + 1],
103
+ [bbox_here["z_mid_x_min"], bbox_here["z_mid_x_max"] + 1],
104
+ ]
103
105
  session.add_bbox_interaction(bbox_here, include_interaction=True, run_prediction=False)
104
106
  if clicks is not None:
105
107
  clicks_here = clicks[oid - 1]
106
108
  clicks_order_here = clicks_order[oid - 1]
107
109
  fg_ptr = bg_ptr = 0
108
110
  for kind in clicks_order_here:
109
- if kind == 'fg':
110
- click = clicks_here['fg'][fg_ptr]
111
+ if kind == "fg":
112
+ click = clicks_here["fg"][fg_ptr]
111
113
  fg_ptr += 1
112
114
  else:
113
- click = clicks_here['bg'][bg_ptr]
115
+ click = clicks_here["bg"][bg_ptr]
114
116
  bg_ptr += 1
115
117
 
116
118
  print(f"Class {oid}: {kind} click at {click}")
117
- session.add_point_interaction(click, include_interaction=kind == 'fg', run_prediction=False)
119
+ session.add_point_interaction(click, include_interaction=kind == "fg", run_prediction=False)
118
120
  # now run inference on the last interaction center
119
121
  session.new_interaction_centers = [session.new_interaction_centers[-1]]
120
122
  session.new_interaction_zoom_out_factors = [session.new_interaction_zoom_out_factors[-1]]
121
123
  session._predict()
122
124
  result[session.target_buffer > 0] = oid
123
125
  del session
124
- empty_cache(torch.device('cuda', 0))
126
+ empty_cache(torch.device("cuda", 0))
125
127
  return result.cpu().numpy()
126
128
 
129
+
127
130
  # --------------------------------------------------------------------------- #
128
131
  # === DO NOT EDIT BELOW === #
129
132
  # --------------------------------------------------------------------------- #
130
133
 
134
+
131
135
  def parse_args() -> argparse.Namespace:
132
136
  p = argparse.ArgumentParser()
133
137
  p.add_argument("--case_path", required=True, help="Path to the input *.npz")
134
138
  p.add_argument("--save_path", required=True, help="Path to write output *.npz")
135
139
  return p.parse_args()
136
140
 
141
+
137
142
  # Adapt this to your checkpoint directory (relative to the script)
138
- CHECKPOINT_DIR = 'checkpoint_folder'
143
+ CHECKPOINT_DIR = "checkpoint_folder"
144
+
139
145
 
140
146
  def main() -> None:
141
147
  args = parse_args()
@@ -147,12 +153,12 @@ def main() -> None:
147
153
 
148
154
  # ---------------------- Load input & prompts -------------------------- #
149
155
  data = np.load(case_path, allow_pickle=True)
150
- image = data["imgs"]
151
- spacing = tuple(data["spacing"])
152
- bbox = data.get("boxes") # bounding boxes
153
- clicks = data.get("clicks") # fg/bg clicks per class
156
+ image = data["imgs"]
157
+ spacing = tuple(data["spacing"])
158
+ bbox = data.get("boxes") # bounding boxes
159
+ clicks = data.get("clicks") # fg/bg clicks per class
154
160
  clicks_order = data.get("clicks_order") # order of click types
155
- prev_pred = data.get("prev_pred") # from last iteration
161
+ prev_pred = data.get("prev_pred") # from last iteration
156
162
 
157
163
  # --------------------------- Inference -------------------------------- #
158
164
  seg = run_inference(image, spacing, bbox, clicks, clicks_order, prev_pred)
@@ -162,5 +168,6 @@ def main() -> None:
162
168
  np.savez_compressed(save_path, segs=seg.astype(np.uint8))
163
169
  print(f"[predict.py] Saved prediction to {save_path}")
164
170
 
171
+
165
172
  if __name__ == "__main__":
166
- main()
173
+ main()