nnInteractive 2.0.0__tar.gz → 2.2.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (90)
  1. {nninteractive-2.0.0 → nninteractive-2.2.0}/PKG-INFO +75 -5
  2. {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/inference/inference_session.py +115 -11
  3. nninteractive-2.2.0/nnInteractive/inference/remote/__init__.py +11 -0
  4. nninteractive-2.2.0/nnInteractive/inference/remote/_protocol.py +27 -0
  5. nninteractive-2.2.0/nnInteractive/inference/remote/remote_session.py +528 -0
  6. nninteractive-2.2.0/nnInteractive/inference/remote/serialization.py +142 -0
  7. nninteractive-2.2.0/nnInteractive/inference/server/app.py +662 -0
  8. nninteractive-2.2.0/nnInteractive/inference/server/main.py +233 -0
  9. nninteractive-2.2.0/nnInteractive/utils/__init__.py +0 -0
  10. {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive.egg-info/PKG-INFO +75 -5
  11. {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive.egg-info/SOURCES.txt +8 -0
  12. {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive.egg-info/dependency_links.txt +0 -0
  13. nninteractive-2.2.0/nnInteractive.egg-info/entry_points.txt +2 -0
  14. {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive.egg-info/requires.txt +8 -0
  15. {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive.egg-info/top_level.txt +0 -0
  16. {nninteractive-2.0.0 → nninteractive-2.2.0}/pyproject.toml +16 -1
  17. {nninteractive-2.0.0 → nninteractive-2.2.0}/readme.md +67 -4
  18. {nninteractive-2.0.0 → nninteractive-2.2.0}/LICENSE +0 -0
  19. {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/__init__.py +0 -0
  20. {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/inference/__init__.py +0 -0
  21. {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/inference/cvpr2025_challenge_baseline/__init__.py +0 -0
  22. {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/inference/cvpr2025_challenge_baseline/predict.py +0 -0
  23. {nninteractive-2.0.0/nnInteractive/interaction → nninteractive-2.2.0/nnInteractive/inference/server}/__init__.py +0 -0
  24. {nninteractive-2.0.0/nnInteractive/trainer → nninteractive-2.2.0/nnInteractive/interaction}/__init__.py +0 -0
  25. {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/interaction/point.py +0 -0
  26. {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/setup.py +0 -0
  27. {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/metadata.py +0 -0
  28. {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/reader.py +0 -0
  29. {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/run.py +0 -0
  30. {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/__init__.py +0 -0
  31. {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/sam2/__init__.py +0 -0
  32. {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/sam2/automatic_mask_generator.py +0 -0
  33. {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/sam2/benchmark.py +0 -0
  34. {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/sam2/build_sam.py +0 -0
  35. {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/sam2/modeling/__init__.py +0 -0
  36. {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/__init__.py +0 -0
  37. {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/hieradet.py +0 -0
  38. {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/image_encoder.py +0 -0
  39. {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/utils.py +0 -0
  40. {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/sam2/modeling/memory_attention.py +0 -0
  41. {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/sam2/modeling/memory_encoder.py +0 -0
  42. {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/sam2/modeling/position_encoding.py +0 -0
  43. {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/__init__.py +0 -0
  44. {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/mask_decoder.py +0 -0
  45. {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/prompt_encoder.py +0 -0
  46. {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/transformer.py +0 -0
  47. {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/sam2/modeling/sam2_base.py +0 -0
  48. {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/sam2/modeling/sam2_utils.py +0 -0
  49. {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/sam2/sam2_image_predictor.py +0 -0
  50. {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/sam2/sam2_video_predictor.py +0 -0
  51. {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/sam2/sam2_video_predictor_legacy.py +0 -0
  52. {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/sam2/utils/__init__.py +0 -0
  53. {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/sam2/utils/amg.py +0 -0
  54. {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/sam2/utils/misc.py +0 -0
  55. {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/sam2/utils/transforms.py +0 -0
  56. {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/setup.py +0 -0
  57. {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/training/__init__.py +0 -0
  58. {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/training/dataset/__init__.py +0 -0
  59. {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/training/dataset/sam2_datasets.py +0 -0
  60. {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/training/dataset/transforms.py +0 -0
  61. {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/training/dataset/utils.py +0 -0
  62. {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/training/dataset/vos_dataset.py +0 -0
  63. {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/training/dataset/vos_raw_dataset.py +0 -0
  64. {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/training/dataset/vos_sampler.py +0 -0
  65. {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/training/dataset/vos_segment_loader.py +0 -0
  66. {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/training/loss_fns.py +0 -0
  67. {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/training/model/__init__.py +0 -0
  68. {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/training/model/sam2.py +0 -0
  69. {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/training/optimizer.py +0 -0
  70. {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/training/scripts/sav_frame_extraction_submitit.py +0 -0
  71. {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/training/train.py +0 -0
  72. {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/training/trainer.py +0 -0
  73. {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/training/utils/__init__.py +0 -0
  74. {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/training/utils/checkpoint_utils.py +0 -0
  75. {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/training/utils/data_utils.py +0 -0
  76. {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/training/utils/distributed.py +0 -0
  77. {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/training/utils/logger.py +0 -0
  78. {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/training/utils/train_utils.py +0 -0
  79. {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/supervoxel.py +0 -0
  80. {nninteractive-2.0.0/nnInteractive/utils → nninteractive-2.2.0/nnInteractive/trainer}/__init__.py +0 -0
  81. {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/trainer/nnInteractiveTrainer.py +0 -0
  82. {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/utils/bboxes.py +0 -0
  83. {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/utils/checkpoint_cleansing.py +0 -0
  84. {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/utils/crop.py +0 -0
  85. {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/utils/erosion_dilation.py +0 -0
  86. {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/utils/inference_helpers.py +0 -0
  87. {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/utils/os_shennanigans.py +0 -0
  88. {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/utils/rounding.py +0 -0
  89. {nninteractive-2.0.0 → nninteractive-2.2.0}/setup.cfg +0 -0
  90. {nninteractive-2.0.0 → nninteractive-2.2.0}/setup.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: nnInteractive
3
- Version: 2.0.0
3
+ Version: 2.2.0
4
4
  Summary: Inference code for nnInteractive
5
5
  Author: Helmholtz Imaging Applied Computer Vision Lab
6
6
  Author-email: Fabian Isensee <f.isensee@dkfz-heidelberg.de>
@@ -223,10 +223,17 @@ Requires-Dist: nnunetv2>=2.7.0
223
223
  Requires-Dist: torch!=2.9.*,>=2.1.2
224
224
  Requires-Dist: acvl-utils<0.3,>=0.2.3
225
225
  Requires-Dist: batchgenerators>=0.25.1
226
+ Requires-Dist: fastapi>=0.110
227
+ Requires-Dist: uvicorn[standard]>=0.27
228
+ Requires-Dist: httpx>=0.27
229
+ Requires-Dist: blosc2
226
230
  Provides-Extra: dev
227
231
  Requires-Dist: black; extra == "dev"
228
232
  Requires-Dist: ruff; extra == "dev"
229
233
  Requires-Dist: pre-commit; extra == "dev"
234
+ Provides-Extra: client
235
+ Requires-Dist: httpx>=0.27; extra == "client"
236
+ Requires-Dist: blosc2; extra == "client"
230
237
  Dynamic: license-file
231
238
 
232
239
  <img src="imgs/nnInteractive_header_white.png">
@@ -417,14 +424,42 @@ session.add_point_interaction(POINT_COORDINATES, include_interaction=False)
417
424
  session.add_bbox_interaction(BBOX_COORDINATES, include_interaction=True)
418
425
 
419
426
  # Example: Add a scribble interaction
420
- # - A 3D image of the same shape as img where one slice (any axis-aligned orientation) contains a hand-drawn scribble.
421
427
  # - Background must be 0, and scribble must be 1.
422
428
  # - Use session.preferred_scribble_thickness for optimal results.
423
- session.add_scribble_interaction(SCRIBBLE_IMAGE, include_interaction=True)
429
+ #
430
+ # ✅ RECOMMENDED (v2): pass a small 2D crop plus its location.
431
+ # Scribbles live on a single axis-aligned slice, so one of the three bbox
432
+ # dimensions is always size 1 and the in-plane extent typically covers only
433
+ # a small region. The cropped array is ORDERS OF MAGNITUDE
434
+ # smaller than a full-volume mask for typical annotations, which makes this
435
+ # path dramatically faster. Please prefer this
436
+ # form in new integrations.
437
+ #
438
+ # SCRIBBLE_CROP.shape must equal the bbox size. INTERACTION_BBOX uses
439
+ # half-open intervals [[x1,x2],[y1,y2],[z1,z2]] in original-image coordinates.
440
+ # Example: a scribble drawn on axial slice z=64, covering x∈[100,140), y∈[80,150):
441
+ # SCRIBBLE_CROP = <ndarray of shape (40, 70, 1), values 0 or 1>
442
+ # INTERACTION_BBOX = [[100, 140], [80, 150], [64, 65]]
443
+ session.add_scribble_interaction(
444
+ SCRIBBLE_CROP,
445
+ include_interaction=True,
446
+ interaction_bbox=INTERACTION_BBOX,
447
+ )
448
+
449
+ # Legacy form (still supported, but discouraged): a 3D array matching the
450
+ # full original image shape with the scribble baked into one slice.
451
+ # session.add_scribble_interaction(SCRIBBLE_IMAGE, include_interaction=True)
424
452
 
425
453
  # Example: Add a lasso interaction
426
- # - Similarly to scribble a 3D image with a single slice containing a **closed contour** representing the selection.
427
- session.add_lasso_interaction(LASSO_IMAGE, include_interaction=True)
454
+ # - Like scribble but the single slice contains a **closed contour** for the selection.
455
+ # - Same recommendation applies: pass a 2D crop + interaction_bbox for a large speedup.
456
+ session.add_lasso_interaction(
457
+ LASSO_CROP,
458
+ include_interaction=True,
459
+ interaction_bbox=INTERACTION_BBOX,
460
+ )
461
+ # Legacy full-volume form (discouraged):
462
+ # session.add_lasso_interaction(LASSO_IMAGE, include_interaction=True)
428
463
 
429
464
  # You can combine any number of interactions as needed.
430
465
  # The model refines the segmentation result incrementally with each new interaction.
@@ -452,6 +487,41 @@ session.set_target_buffer(torch.zeros(NEW_IMAGE.shape[1:], dtype=torch.uint8))
452
487
  # Enjoy!
453
488
  ```
454
489
 
490
+ ## Running inference on a remote GPU (client / server)
491
+
492
+ If the machine running your GUI does not have a powerful GPU, you can run the
493
+ model on a remote box and drive it over HTTP with
494
+ **`nnInteractiveRemoteInferenceSession`** — a drop-in replacement with the same
495
+ public API as the local session. The server loads the model once at startup and
496
+ hosts multiple concurrent client sessions; each client keeps its own image,
497
+ target buffer, and interaction state.
498
+
499
+ Start the server on the GPU box:
500
+
501
+ ```bash
502
+ nninteractive-server \
503
+ --model-dir /path/to/checkpoint_folder --fold all \
504
+ --host 0.0.0.0 --port 1527 \
505
+ --api-key "$(openssl rand -hex 32)"
506
+ ```
507
+
508
+ And in the client code, swap the local session for the remote one:
509
+
510
+ ```python
511
+ from nnInteractive.inference.remote import nnInteractiveRemoteInferenceSession
512
+
513
+ session = nnInteractiveRemoteInferenceSession(
514
+ server_url="http://gpu-box.lab:1527",
515
+ api_key="…",
516
+ )
517
+ # From here on, the API is identical to nnInteractiveInferenceSession.
518
+ ```
519
+
520
+ For full details — installation, authentication, single-user SSH-tunnel setup,
521
+ multi-user deployment behind a reverse proxy, concurrency/session model, idle
522
+ expiry and heartbeats, GUI integration notes, and troubleshooting — see
523
+ [`SERVER_CLIENT.md`](SERVER_CLIENT.md).
524
+
455
525
  ## nnInteractive SuperVoxels
456
526
 
457
527
  As part of the `nnInteractive` framework, we provide a dedicated module for **supervoxel generation** based on [SAM](https://github.com/facebookresearch/segment-anything) and [SAM2](https://github.com/facebookresearch/sam2). This replaces traditional superpixel methods (e.g., SLIC) with **foundation model–powered 3D pseudo-labels**.
@@ -50,14 +50,16 @@ class nnInteractiveInferenceSession:
50
50
  ):
51
51
  """
52
52
  Only intended to work with nnInteractiveTrainerV2 and its derivatives
53
+
54
+ ``use_torch_compile``: compile the network with ``torch.compile``. The
55
+ first prediction after enabling this is slow (compilation happens lazily
56
+ on the first forward pass), but every subsequent prediction is faster.
57
+ This is recommended for the persistent inference server, where the
58
+ process is long-lived so the one-time compile cost is paid only once and
59
+ amortized across the whole session lifetime.
53
60
  """
54
61
  print("session initialized")
55
62
 
56
- # set as part of initialization
57
- assert use_torch_compile is False, (
58
- "torch.compile is not supported. The blosc2-backed interaction tensor "
59
- "requires numpy↔torch round-trips that break compile tracing."
60
- )
61
63
  self.network = None
62
64
  self.label_manager = None
63
65
  self.dataset_json = None
@@ -83,6 +85,10 @@ class nnInteractiveInferenceSession:
83
85
  self.preprocessed_image: torch.Tensor = None
84
86
  self.preprocessed_props = None
85
87
  self.target_buffer: Union[np.ndarray, torch.Tensor] = None
88
+ # Bbox (in original-image coordinates) of the most recent target_buffer write.
89
+ # Captured inside _paste_prediction_to_target_buffer so remote callers can
90
+ # fetch just the touched region without diffing.
91
+ self._last_paste_bbox: Optional[List[List[int]]] = None
86
92
 
87
93
  # this will be set when loading the model (initialize_from_trained_model_folder)
88
94
  self.pad_mode_data = self.preferred_scribble_thickness = self.point_interaction = None
@@ -287,6 +293,7 @@ class nnInteractiveInferenceSession:
287
293
  else:
288
294
  pred_for_target = prediction.to("cpu")
289
295
  paste_tensor(self.target_buffer, pred_for_target, target_bbox)
296
+ self._last_paste_bbox = target_bbox
290
297
 
291
298
  def _estimate_refinement_cache_nbytes(self, cache_bbox: List[List[int]]) -> int:
292
299
  cache_voxels = int(np.prod(self._bbox_size(cache_bbox), dtype=np.int64))
@@ -517,6 +524,7 @@ class nnInteractiveInferenceSession:
517
524
  self.current_interaction_intensity = 1.0
518
525
  empty_cache(self.device)
519
526
  self.original_image_shape = None
527
+ self._last_paste_bbox = None
520
528
 
521
529
  def _initialize_interactions(self, image_torch: torch.Tensor):
522
530
  shape = (self.num_interaction_channels, *image_torch.shape[1:])
@@ -606,6 +614,7 @@ class nnInteractiveInferenceSession:
606
614
  self.target_buffer.fill(0)
607
615
  elif isinstance(self.target_buffer, torch.Tensor):
608
616
  self.target_buffer.zero_()
617
+ self._last_paste_bbox = None
609
618
  empty_cache(self.device)
610
619
 
611
620
  def add_bbox_interaction(
@@ -876,6 +885,42 @@ class nnInteractiveInferenceSession:
876
885
  else:
877
886
  del initial_seg
878
887
 
888
+ @torch.inference_mode()
889
+ def warmup(self) -> bool:
890
+ """Run a single dummy forward pass to trigger lazy ``torch.compile`` compilation up front.
891
+
892
+ With ``torch.compile`` enabled the network is compiled lazily on its first
893
+ forward pass, which would otherwise make the user's *first* real prediction
894
+ slow. Every prediction path — the initial coarse pass, the zoom-out
895
+ iterations, and the refinement patches — feeds the network an input of
896
+ identical shape ``[1, num_input_channels + num_interaction_channels,
897
+ *patch_size]`` (``_build_network_input`` always resizes the crop to
898
+ ``patch_size``, and refinement crops at exactly ``patch_size``). So a single
899
+ dummy pass at that shape populates the compile cache and every subsequent
900
+ real prediction is fast.
901
+
902
+ Returns ``True`` if a warmup pass was run, ``False`` if it was a no-op
903
+ (network not compiled — there is nothing to pre-compile, so a dummy pass
904
+ would not save the user any time). Mirrors ``_predict``'s autocast/
905
+ inference-mode context and the float32 input dtype that ``torch.cat``
906
+ produces when concatenating the float32 image with the fp16 interactions.
907
+ """
908
+ if self.network is None or self.configuration_manager is None:
909
+ raise RuntimeError("warmup() requires an initialized network; call initialize_* first")
910
+ if not isinstance(self.network, OptimizedModule):
911
+ return False
912
+ num_input_channels = (
913
+ determine_num_input_channels(self.plans_manager, self.configuration_manager, self.dataset_json)
914
+ + self.num_interaction_channels
915
+ )
916
+ patch_size = self.configuration_manager.patch_size
917
+ dummy = torch.zeros((1, num_input_channels, *patch_size), dtype=torch.float32, device=self.device)
918
+ with torch.autocast(self.device.type, enabled=True) if self.device.type == "cuda" else dummy_context():
919
+ self.network(dummy)
920
+ del dummy
921
+ empty_cache(self.device)
922
+ return True
923
+
879
924
  @torch.inference_mode()
880
925
  def _predict(self, force_full_refine: bool = False):
881
926
  """
@@ -1287,6 +1332,36 @@ class nnInteractiveInferenceSession:
1287
1332
  """
1288
1333
  This is used when making predictions with a trained model
1289
1334
  """
1335
+ artifacts = self._load_model_artifacts_from_disk(model_training_output_dir, use_fold, checkpoint_name)
1336
+ self.initialize_from_loaded_artifacts(artifacts)
1337
+
1338
+ def _load_model_artifacts_from_disk(
1339
+ self,
1340
+ model_training_output_dir: str,
1341
+ use_fold: Union[int, str] = None,
1342
+ checkpoint_name: str = "checkpoint_final.pth",
1343
+ ) -> dict:
1344
+ """Read all model artifacts from disk and build the network on ``self.device``.
1345
+
1346
+ Returns an artifact dict that can be applied to this or any other freshly
1347
+ constructed session via :meth:`initialize_from_loaded_artifacts`. The
1348
+ returned values are the actual objects (the ``nn.Module`` with its
1349
+ weights and buffers, the plans/configuration managers, the dataset
1350
+ json, the label manager) — not copies. Multiple sessions calling
1351
+ :meth:`initialize_from_loaded_artifacts` with the same dict will all
1352
+ end up with ``self.network`` pointing at the same module instance and
1353
+ the same weight tensors on the GPU. This is safe as long as callers
1354
+ treat these objects as read-only after construction; in the multi-
1355
+ session server that is enforced by running inference under
1356
+ ``@torch.inference_mode()`` and serializing predict calls with a
1357
+ global GPU lock.
1358
+
1359
+ Note: this also mutates ``self`` (applies capability, sets pad/decay/
1360
+ thickness) because ``num_interaction_channels`` is required to build the
1361
+ network. The caller should follow up with
1362
+ :meth:`initialize_from_loaded_artifacts` (this is what
1363
+ :meth:`initialize_from_trained_model_folder` does).
1364
+ """
1290
1365
  point_interaction_use_etd = True
1291
1366
  (
1292
1367
  capability_content,
@@ -1353,12 +1428,41 @@ class nnInteractiveInferenceSession:
1353
1428
  ).to(self.device)
1354
1429
  network.load_state_dict(parameters)
1355
1430
 
1356
- self.plans_manager = plans_manager
1357
- self.configuration_manager = configuration_manager
1358
- self.network = network
1359
- self.dataset_json = dataset_json
1360
- self.trainer_name = trainer_name
1361
- self.label_manager = plans_manager.get_label_manager(dataset_json)
1431
+ return {
1432
+ "capability_content": capability_content,
1433
+ "point_interaction": self.point_interaction,
1434
+ "preferred_scribble_thickness": self.preferred_scribble_thickness,
1435
+ "interaction_decay": self.interaction_decay,
1436
+ "pad_mode_data": self.pad_mode_data,
1437
+ "network": network,
1438
+ "plans_manager": plans_manager,
1439
+ "configuration_manager": configuration_manager,
1440
+ "dataset_json": dataset_json,
1441
+ "trainer_name": trainer_name,
1442
+ "label_manager": plans_manager.get_label_manager(dataset_json),
1443
+ }
1444
+
1445
+ def initialize_from_loaded_artifacts(self, artifacts: dict):
1446
+ """Apply pre-loaded artifacts to this session instance.
1447
+
1448
+ ``artifacts`` is the dict returned by :meth:`_load_model_artifacts_from_disk`.
1449
+ Useful for spawning multiple sessions that share one loaded model (e.g.
1450
+ the multi-session inference server). All artifact entries — including
1451
+ ``self.network`` — are stored by reference; passing the same dict to
1452
+ multiple sessions does not duplicate the network or its weights in
1453
+ memory.
1454
+ """
1455
+ self.preferred_scribble_thickness = artifacts["preferred_scribble_thickness"]
1456
+ self.interaction_decay = artifacts["interaction_decay"]
1457
+ self.pad_mode_data = artifacts["pad_mode_data"]
1458
+ self.point_interaction = artifacts["point_interaction"]
1459
+ self._apply_capability(artifacts["capability_content"])
1460
+ self.plans_manager = artifacts["plans_manager"]
1461
+ self.configuration_manager = artifacts["configuration_manager"]
1462
+ self.network = artifacts["network"]
1463
+ self.dataset_json = artifacts["dataset_json"]
1464
+ self.trainer_name = artifacts["trainer_name"]
1465
+ self.label_manager = artifacts["label_manager"]
1362
1466
  if self.use_torch_compile and not isinstance(self.network, OptimizedModule):
1363
1467
  print("Using torch.compile")
1364
1468
  self.network = torch.compile(self.network)
@@ -0,0 +1,11 @@
1
+ from nnInteractive.inference.remote.remote_session import (
2
+ ServerAtCapacityError,
3
+ SessionExpiredError,
4
+ nnInteractiveRemoteInferenceSession,
5
+ )
6
+
7
+ __all__ = [
8
+ "nnInteractiveRemoteInferenceSession",
9
+ "SessionExpiredError",
10
+ "ServerAtCapacityError",
11
+ ]
@@ -0,0 +1,27 @@
1
+ """Shared constants for the nnInteractive client/server HTTP protocol."""
2
+
3
+ # HTTP header used to carry JSON-encoded metadata alongside a binary array body.
4
+ META_HEADER = "X-Meta"
5
+ # HTTP header used to carry a per-client lease token identifying which session
6
+ # on the (multi-session) server the request applies to.
7
+ LEASE_HEADER = "X-Lease-Token"
8
+
9
+ # Endpoint paths.
10
+ PATH_HEALTHZ = "/healthz"
11
+ PATH_CAPABILITIES = "/capabilities"
12
+ PATH_CLAIM = "/claim"
13
+ PATH_RELEASE = "/release"
14
+ PATH_HEARTBEAT = "/heartbeat"
15
+ PATH_LEASE_STATUS = "/lease_status"
16
+ PATH_SET_IMAGE = "/set_image"
17
+ PATH_SET_TARGET_BUFFER = "/set_target_buffer"
18
+ PATH_RESET_INTERACTIONS = "/reset_interactions"
19
+ PATH_SET_DO_AUTOZOOM = "/set_do_autozoom"
20
+ PATH_ADD_BBOX = "/add_bbox_interaction"
21
+ PATH_ADD_POINT = "/add_point_interaction"
22
+ PATH_ADD_SCRIBBLE = "/add_scribble_interaction"
23
+ PATH_ADD_LASSO = "/add_lasso_interaction"
24
+ PATH_ADD_INITIAL_SEG = "/add_initial_seg_interaction"
25
+
26
+ # Body content type for endpoints that ship a packed numpy array.
27
+ CONTENT_TYPE_OCTET_STREAM = "application/octet-stream"