nnInteractive 2.1.0__tar.gz → 2.3.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (91)
  1. {nninteractive-2.1.0 → nninteractive-2.3.0}/PKG-INFO +7 -1
  2. {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/inference/inference_session.py +118 -17
  3. {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/inference/remote/remote_session.py +73 -7
  4. {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/inference/remote/serialization.py +65 -3
  5. {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/inference/server/app.py +160 -33
  6. nninteractive-2.3.0/nnInteractive/inference/server/main.py +233 -0
  7. {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive.egg-info/PKG-INFO +7 -1
  8. {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive.egg-info/SOURCES.txt +0 -0
  9. {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive.egg-info/dependency_links.txt +0 -0
  10. {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive.egg-info/requires.txt +5 -0
  11. {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive.egg-info/top_level.txt +0 -0
  12. {nninteractive-2.1.0 → nninteractive-2.3.0}/pyproject.toml +10 -1
  13. {nninteractive-2.1.0 → nninteractive-2.3.0}/readme.md +2 -0
  14. nninteractive-2.1.0/nnInteractive/inference/server/main.py +0 -149
  15. {nninteractive-2.1.0 → nninteractive-2.3.0}/LICENSE +0 -0
  16. {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/__init__.py +0 -0
  17. {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/inference/__init__.py +0 -0
  18. {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/inference/cvpr2025_challenge_baseline/__init__.py +0 -0
  19. {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/inference/cvpr2025_challenge_baseline/predict.py +0 -0
  20. {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/inference/remote/__init__.py +0 -0
  21. {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/inference/remote/_protocol.py +0 -0
  22. {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/inference/server/__init__.py +0 -0
  23. {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/interaction/__init__.py +0 -0
  24. {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/interaction/point.py +0 -0
  25. {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/supervoxel/setup.py +0 -0
  26. {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/supervoxel/src/metadata.py +0 -0
  27. {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/supervoxel/src/reader.py +0 -0
  28. {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/supervoxel/src/run.py +0 -0
  29. {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/supervoxel/src/sam2/__init__.py +0 -0
  30. {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/supervoxel/src/sam2/sam2/__init__.py +0 -0
  31. {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/supervoxel/src/sam2/sam2/automatic_mask_generator.py +0 -0
  32. {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/supervoxel/src/sam2/sam2/benchmark.py +0 -0
  33. {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/supervoxel/src/sam2/sam2/build_sam.py +0 -0
  34. {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/supervoxel/src/sam2/sam2/modeling/__init__.py +0 -0
  35. {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/__init__.py +0 -0
  36. {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/hieradet.py +0 -0
  37. {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/image_encoder.py +0 -0
  38. {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/utils.py +0 -0
  39. {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/supervoxel/src/sam2/sam2/modeling/memory_attention.py +0 -0
  40. {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/supervoxel/src/sam2/sam2/modeling/memory_encoder.py +0 -0
  41. {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/supervoxel/src/sam2/sam2/modeling/position_encoding.py +0 -0
  42. {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/__init__.py +0 -0
  43. {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/mask_decoder.py +0 -0
  44. {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/prompt_encoder.py +0 -0
  45. {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/transformer.py +0 -0
  46. {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/supervoxel/src/sam2/sam2/modeling/sam2_base.py +0 -0
  47. {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/supervoxel/src/sam2/sam2/modeling/sam2_utils.py +0 -0
  48. {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/supervoxel/src/sam2/sam2/sam2_image_predictor.py +0 -0
  49. {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/supervoxel/src/sam2/sam2/sam2_video_predictor.py +0 -0
  50. {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/supervoxel/src/sam2/sam2/sam2_video_predictor_legacy.py +0 -0
  51. {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/supervoxel/src/sam2/sam2/utils/__init__.py +0 -0
  52. {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/supervoxel/src/sam2/sam2/utils/amg.py +0 -0
  53. {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/supervoxel/src/sam2/sam2/utils/misc.py +0 -0
  54. {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/supervoxel/src/sam2/sam2/utils/transforms.py +0 -0
  55. {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/supervoxel/src/sam2/setup.py +0 -0
  56. {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/supervoxel/src/sam2/training/__init__.py +0 -0
  57. {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/supervoxel/src/sam2/training/dataset/__init__.py +0 -0
  58. {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/supervoxel/src/sam2/training/dataset/sam2_datasets.py +0 -0
  59. {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/supervoxel/src/sam2/training/dataset/transforms.py +0 -0
  60. {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/supervoxel/src/sam2/training/dataset/utils.py +0 -0
  61. {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/supervoxel/src/sam2/training/dataset/vos_dataset.py +0 -0
  62. {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/supervoxel/src/sam2/training/dataset/vos_raw_dataset.py +0 -0
  63. {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/supervoxel/src/sam2/training/dataset/vos_sampler.py +0 -0
  64. {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/supervoxel/src/sam2/training/dataset/vos_segment_loader.py +0 -0
  65. {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/supervoxel/src/sam2/training/loss_fns.py +0 -0
  66. {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/supervoxel/src/sam2/training/model/__init__.py +0 -0
  67. {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/supervoxel/src/sam2/training/model/sam2.py +0 -0
  68. {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/supervoxel/src/sam2/training/optimizer.py +0 -0
  69. {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/supervoxel/src/sam2/training/scripts/sav_frame_extraction_submitit.py +0 -0
  70. {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/supervoxel/src/sam2/training/train.py +0 -0
  71. {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/supervoxel/src/sam2/training/trainer.py +0 -0
  72. {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/supervoxel/src/sam2/training/utils/__init__.py +0 -0
  73. {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/supervoxel/src/sam2/training/utils/checkpoint_utils.py +0 -0
  74. {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/supervoxel/src/sam2/training/utils/data_utils.py +0 -0
  75. {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/supervoxel/src/sam2/training/utils/distributed.py +0 -0
  76. {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/supervoxel/src/sam2/training/utils/logger.py +0 -0
  77. {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/supervoxel/src/sam2/training/utils/train_utils.py +0 -0
  78. {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/supervoxel/src/supervoxel.py +0 -0
  79. {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/trainer/__init__.py +0 -0
  80. {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/trainer/nnInteractiveTrainer.py +0 -0
  81. {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/utils/__init__.py +0 -0
  82. {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/utils/bboxes.py +0 -0
  83. {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/utils/checkpoint_cleansing.py +0 -0
  84. {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/utils/crop.py +0 -0
  85. {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/utils/erosion_dilation.py +0 -0
  86. {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/utils/inference_helpers.py +0 -0
  87. {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/utils/os_shennanigans.py +0 -0
  88. {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/utils/rounding.py +0 -0
  89. {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive.egg-info/entry_points.txt +0 -0
  90. {nninteractive-2.1.0 → nninteractive-2.3.0}/setup.cfg +0 -0
  91. {nninteractive-2.1.0 → nninteractive-2.3.0}/setup.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: nnInteractive
3
- Version: 2.1.0
3
+ Version: 2.3.0
4
4
  Summary: Inference code for nnInteractive
5
5
  Author: Helmholtz Imaging Applied Computer Vision Lab
6
6
  Author-email: Fabian Isensee <f.isensee@dkfz-heidelberg.de>
@@ -226,10 +226,14 @@ Requires-Dist: batchgenerators>=0.25.1
226
226
  Requires-Dist: fastapi>=0.110
227
227
  Requires-Dist: uvicorn[standard]>=0.27
228
228
  Requires-Dist: httpx>=0.27
229
+ Requires-Dist: blosc2
229
230
  Provides-Extra: dev
230
231
  Requires-Dist: black; extra == "dev"
231
232
  Requires-Dist: ruff; extra == "dev"
232
233
  Requires-Dist: pre-commit; extra == "dev"
234
+ Provides-Extra: client
235
+ Requires-Dist: httpx>=0.27; extra == "client"
236
+ Requires-Dist: blosc2; extra == "client"
233
237
  Dynamic: license-file
234
238
 
235
239
  <img src="imgs/nnInteractive_header_white.png">
@@ -545,6 +549,8 @@ Link: [![arXiv](https://img.shields.io/badge/arXiv-2503.08373-b31b1b.svg)](https
545
549
  # License
546
550
  Note that while this repository is available under Apache-2.0 license (see [LICENSE](./LICENSE)), the [model checkpoint](https://huggingface.co/nnInteractive/nnInteractive) is `Creative Commons Attribution Non Commercial Share Alike 4.0`!
547
551
 
552
+ Release model folders ship their own `LICENSE` file whose **first line is the license identifier** (e.g. `CC BY-NC-SA 4.0`); any following lines (such as a link to the full license) are ignored by the tool. At load time this first line is read and exposed as `session.license` so applications can display the model's license prominently. If a checkpoint folder has no `LICENSE` file, the official v1 checkpoint is assumed to be `CC BY-NC-SA 4.0` and any other checkpoint reports `!!MISSING!!`.
553
+
548
554
  # Changelog
549
555
 
550
556
  ### 1.1.2 - 2025-08-02
@@ -50,14 +50,16 @@ class nnInteractiveInferenceSession:
50
50
  ):
51
51
  """
52
52
  Only intended to work with nnInteractiveTrainerV2 and its derivatives
53
+
54
+ ``use_torch_compile``: compile the network with ``torch.compile``. The
55
+ first prediction after enabling this is slow (compilation happens lazily
56
+ on the first forward pass), but every subsequent prediction is faster.
57
+ This is recommended for the persistent inference server, where the
58
+ process is long-lived so the one-time compile cost is paid only once and
59
+ amortized across the whole session lifetime.
53
60
  """
54
61
  print("session initialized")
55
62
 
56
- # set as part of initialization
57
- assert use_torch_compile is False, (
58
- "torch.compile is not supported. The blosc2-backed interaction tensor "
59
- "requires numpy↔torch round-trips that break compile tracing."
60
- )
61
63
  self.network = None
62
64
  self.label_manager = None
63
65
  self.dataset_json = None
@@ -77,6 +79,11 @@ class nnInteractiveInferenceSession:
77
79
  self.channel_mapping: dict = {}
78
80
  self.supports_initial_label: bool = True
79
81
  self.supports_zero_shot_label_refinement: bool = True
82
+ # License of the loaded model checkpoint. Set when the model is loaded
83
+ # (read from the LICENSE file in the checkpoint folder, or derived for
84
+ # legacy checkpoints without one). Exposed so GUIs can display it once
85
+ # the session is initialized. "!!MISSING!!" means the license is unknown.
86
+ self.license: Optional[str] = None
80
87
 
81
88
  # image specific
82
89
  self.interactions = None # blosc2.NDArray once initialized
@@ -116,6 +123,31 @@ class nnInteractiveInferenceSession:
116
123
  and checkpoint.get("init_args", {}).get("configuration") == "3d_fullres_ps192_bs24"
117
124
  )
118
125
 
126
+ @classmethod
127
+ def _load_license(cls, model_training_output_dir: str, plans: dict, checkpoint: dict) -> str:
128
+ """Determine the license of the model being loaded.
129
+
130
+ Reads the ``LICENSE`` file from the checkpoint folder if present.
131
+ Expected format: the FIRST line is a short license identifier (e.g.
132
+ ``CC BY-NC-SA 4.0``); any following lines (URL, full text, …) are for
133
+ human readers and are ignored. Only the first non-empty line is
134
+ returned, so ``self.license`` stays a short, displayable string.
135
+
136
+ If the folder has no ``LICENSE`` file it is most likely a legacy model:
137
+ the official v1 checkpoint is CC BY-NC-SA 4.0, anything else is reported
138
+ as ``"!!MISSING!!"`` so callers (e.g. GUIs) can flag the unknown license.
139
+ """
140
+ license_file = join(model_training_output_dir, "LICENSE")
141
+ if isfile(license_file):
142
+ with open(license_file, "r", encoding="utf-8") as f:
143
+ for line in f:
144
+ line = line.strip()
145
+ if line:
146
+ return line
147
+ if cls._is_official_checkpoint(plans, checkpoint):
148
+ return "CC BY-NC-SA 4.0"
149
+ return "!!MISSING!!"
150
+
119
151
  def _legacy_default_capability(self) -> dict:
120
152
  return {
121
153
  "supported_interactions": {
@@ -533,7 +565,13 @@ class nnInteractiveInferenceSession:
533
565
  dtype=np.float16,
534
566
  chunks=(1, *[min(64, s) for s in shape[1:]]),
535
567
  blocks=(1, *[min(32, s) for s in shape[1:]]),
536
- cparams={"codec": blosc2.Codec.LZ4, "clevel": 5, "nthreads": min(self.torch_n_threads, os.cpu_count())},
568
+ # Interactions compress better with NOFILTER, which is also faster than SHUFFLE.
569
+ cparams={
570
+ "codec": blosc2.Codec.LZ4,
571
+ "clevel": 5,
572
+ "filters": [blosc2.Filter.NOFILTER],
573
+ "nthreads": min(self.torch_n_threads, os.cpu_count()),
574
+ },
537
575
  dparams={"nthreads": 4},
538
576
  )
539
577
  self._interactions_shape = shape
@@ -602,7 +640,13 @@ class nnInteractiveInferenceSession:
602
640
  dtype=np.float16,
603
641
  chunks=(1, *[min(64, s) for s in self._interactions_shape[1:]]),
604
642
  blocks=(1, *[min(32, s) for s in self._interactions_shape[1:]]),
605
- cparams={"codec": blosc2.Codec.LZ4, "clevel": 5, "nthreads": os.cpu_count()},
643
+ # Interactions compress better with NOFILTER, which is also faster than SHUFFLE.
644
+ cparams={
645
+ "codec": blosc2.Codec.LZ4,
646
+ "clevel": 5,
647
+ "filters": [blosc2.Filter.NOFILTER],
648
+ "nthreads": os.cpu_count(),
649
+ },
606
650
  dparams={"nthreads": 4},
607
651
  )
608
652
  self.current_interaction_intensity = 1.0
@@ -883,6 +927,42 @@ class nnInteractiveInferenceSession:
883
927
  else:
884
928
  del initial_seg
885
929
 
930
+ @torch.inference_mode()
931
+ def warmup(self) -> bool:
932
+ """Run a single dummy forward pass to trigger lazy ``torch.compile`` compilation up front.
933
+
934
+ With ``torch.compile`` enabled the network is compiled lazily on its first
935
+ forward pass, which would otherwise make the user's *first* real prediction
936
+ slow. Every prediction path — the initial coarse pass, the zoom-out
937
+ iterations, and the refinement patches — feeds the network an input of
938
+ identical shape ``[1, num_input_channels + num_interaction_channels,
939
+ *patch_size]`` (``_build_network_input`` always resizes the crop to
940
+ ``patch_size``, and refinement crops at exactly ``patch_size``). So a single
941
+ dummy pass at that shape populates the compile cache and every subsequent
942
+ real prediction is fast.
943
+
944
+ Returns ``True`` if a warmup pass was run, ``False`` if it was a no-op
945
+ (network not compiled — there is nothing to pre-compile, so a dummy pass
946
+ would not save the user any time). Mirrors ``_predict``'s autocast/
947
+ inference-mode context and the float32 input dtype that ``torch.cat``
948
+ produces when concatenating the float32 image with the fp16 interactions.
949
+ """
950
+ if self.network is None or self.configuration_manager is None:
951
+ raise RuntimeError("warmup() requires an initialized network; call initialize_* first")
952
+ if not isinstance(self.network, OptimizedModule):
953
+ return False
954
+ num_input_channels = (
955
+ determine_num_input_channels(self.plans_manager, self.configuration_manager, self.dataset_json)
956
+ + self.num_interaction_channels
957
+ )
958
+ patch_size = self.configuration_manager.patch_size
959
+ dummy = torch.zeros((1, num_input_channels, *patch_size), dtype=torch.float32, device=self.device)
960
+ with torch.autocast(self.device.type, enabled=True) if self.device.type == "cuda" else dummy_context():
961
+ self.network(dummy)
962
+ del dummy
963
+ empty_cache(self.device)
964
+ return True
965
+
886
966
  @torch.inference_mode()
887
967
  def _predict(self, force_full_refine: bool = False):
888
968
  """
@@ -1296,6 +1376,16 @@ class nnInteractiveInferenceSession:
1296
1376
  """
1297
1377
  artifacts = self._load_model_artifacts_from_disk(model_training_output_dir, use_fold, checkpoint_name)
1298
1378
  self.initialize_from_loaded_artifacts(artifacts)
1379
+ # With torch.compile the network is compiled lazily on the first forward pass. For a
1380
+ # locally hosted model that lag would otherwise surface on the user's first real
1381
+ # prediction, where it is far more noticeable than during initialization. Trigger the
1382
+ # compilation now with a dummy forward pass so the cost is paid here instead. warmup()
1383
+ # is a no-op when the network is not compiled. The server takes care of its own warmup
1384
+ # explicitly (it shares one compiled network across sessions via
1385
+ # initialize_from_loaded_artifacts), so we only do this on the direct, local entry point.
1386
+ if self.use_torch_compile:
1387
+ print("torch.compile enabled; warming up (compiling) the network now (this is slow once)...")
1388
+ self.warmup()
1299
1389
 
1300
1390
  def _load_model_artifacts_from_disk(
1301
1391
  self,
@@ -1307,9 +1397,16 @@ class nnInteractiveInferenceSession:
1307
1397
 
1308
1398
  Returns an artifact dict that can be applied to this or any other freshly
1309
1399
  constructed session via :meth:`initialize_from_loaded_artifacts`. The
1310
- returned references (network, plans_manager, configuration_manager, ...)
1311
- are safe to share by reference across sessions — they are treated as
1312
- read-only after construction.
1400
+ returned values are the actual objects (the ``nn.Module`` with its
1401
+ weights and buffers, the plans/configuration managers, the dataset
1402
+ json, the label manager) — not copies. Multiple sessions calling
1403
+ :meth:`initialize_from_loaded_artifacts` with the same dict will all
1404
+ end up with ``self.network`` pointing at the same module instance and
1405
+ the same weight tensors on the GPU. This is safe as long as callers
1406
+ treat these objects as read-only after construction; in the multi-
1407
+ session server that is enforced by running inference under
1408
+ ``@torch.inference_mode()`` and serializing predict calls with a
1409
+ global GPU lock.
1313
1410
 
1314
1411
  Note: this also mutates ``self`` (applies capability, sets pad/decay/
1315
1412
  thickness) because ``num_interaction_channels`` is required to build the
@@ -1344,12 +1441,11 @@ class nnInteractiveInferenceSession:
1344
1441
  checkpoint = torch.load(
1345
1442
  join(model_training_output_dir, fold_folder, checkpoint_name), map_location=self.device, weights_only=False
1346
1443
  )
1347
- if self._is_official_checkpoint(plans, checkpoint):
1348
- print(
1349
- "License reminder: The official nnInteractive checkpoint is licensed under "
1350
- "Creative Commons Attribution Non Commercial Share Alike 4.0 (CC BY-NC-SA 4.0). "
1351
- "See the license note in readme.md (# License)."
1352
- )
1444
+ self.license = self._load_license(model_training_output_dir, plans, checkpoint)
1445
+ print("=" * 80)
1446
+ print("Model license:")
1447
+ print(self.license)
1448
+ print("=" * 80)
1353
1449
  trainer_name = checkpoint["trainer_name"]
1354
1450
  configuration_name = checkpoint["init_args"]["configuration"]
1355
1451
 
@@ -1395,6 +1491,7 @@ class nnInteractiveInferenceSession:
1395
1491
  "dataset_json": dataset_json,
1396
1492
  "trainer_name": trainer_name,
1397
1493
  "label_manager": plans_manager.get_label_manager(dataset_json),
1494
+ "license": self.license,
1398
1495
  }
1399
1496
 
1400
1497
  def initialize_from_loaded_artifacts(self, artifacts: dict):
@@ -1402,7 +1499,10 @@ class nnInteractiveInferenceSession:
1402
1499
 
1403
1500
  ``artifacts`` is the dict returned by :meth:`_load_model_artifacts_from_disk`.
1404
1501
  Useful for spawning multiple sessions that share one loaded model (e.g.
1405
- the multi-session inference server).
1502
+ the multi-session inference server). All artifact entries — including
1503
+ ``self.network`` — are stored by reference; passing the same dict to
1504
+ multiple sessions does not duplicate the network or its weights in
1505
+ memory.
1406
1506
  """
1407
1507
  self.preferred_scribble_thickness = artifacts["preferred_scribble_thickness"]
1408
1508
  self.interaction_decay = artifacts["interaction_decay"]
@@ -1415,6 +1515,7 @@ class nnInteractiveInferenceSession:
1415
1515
  self.dataset_json = artifacts["dataset_json"]
1416
1516
  self.trainer_name = artifacts["trainer_name"]
1417
1517
  self.label_manager = artifacts["label_manager"]
1518
+ self.license = artifacts["license"]
1418
1519
  if self.use_torch_compile and not isinstance(self.network, OptimizedModule):
1419
1520
  print("Using torch.compile")
1420
1521
  self.network = torch.compile(self.network)
@@ -10,9 +10,11 @@ from __future__ import annotations
10
10
 
11
11
  import json
12
12
  import os
13
+ import threading
13
14
  import warnings
14
15
  from typing import List, Optional, Tuple, Union
15
16
 
17
+ import blosc2
16
18
  import httpx
17
19
  import numpy as np
18
20
  import torch
@@ -165,8 +167,16 @@ class nnInteractiveRemoteInferenceSession:
165
167
  claim_info = claim_resp.json()
166
168
  self._lease_token = claim_info["lease_token"]
167
169
  self.idle_timeout_seconds: float = float(claim_info.get("idle_timeout_seconds", 0.0))
170
+ self.liveness_timeout_seconds: float = float(claim_info.get("liveness_timeout_seconds", 0.0))
168
171
  self._http.headers[LEASE_HEADER] = self._lease_token
169
172
 
173
+ # Background liveness heartbeat bookkeeping. Defined before any code that
174
+ # might raise so close()/__del__ can always reference them safely. The
175
+ # thread itself is started at the end of __init__, once construction has
176
+ # fully succeeded.
177
+ self._stop_heartbeat = threading.Event()
178
+ self._heartbeat_thread: Optional[threading.Thread] = None
179
+
170
180
  caps = self._get_json(PATH_CAPABILITIES)
171
181
 
172
182
  # Attributes that mirror the local session so the GUI can introspect them
@@ -182,11 +192,27 @@ class nnInteractiveRemoteInferenceSession:
182
192
  self.preferred_scribble_thickness = caps["preferred_scribble_thickness"]
183
193
  self.interaction_decay = caps["interaction_decay"]
184
194
  self.INFERENCE_SESSION_VERSION = caps["inference_session_version"]
195
+ # License of the model loaded on the server. Mirrors
196
+ # nnInteractiveInferenceSession.license so a GUI can display it
197
+ # regardless of whether it holds a local or remote session.
198
+ # "!!MISSING!!" means the server could not determine the license.
199
+ self.license: Optional[str] = caps.get("license")
185
200
 
186
201
  self.original_image_shape: Optional[Tuple[int, ...]] = None
187
202
  self.target_buffer: Union[np.ndarray, torch.Tensor, None] = None
188
203
  self.do_autozoom: bool = bool(caps.get("do_autozoom", True))
189
204
 
205
+ # Construction succeeded — start auto-heartbeating to keep the server
206
+ # from reaping us as a dead client. Beat at half the liveness timeout so
207
+ # one dropped request still leaves margin. Daemon thread: it never blocks
208
+ # interpreter exit, and close() joins it cleanly.
209
+ if self.liveness_timeout_seconds > 0:
210
+ self._heartbeat_interval = max(5.0, self.liveness_timeout_seconds / 2.0)
211
+ self._heartbeat_thread = threading.Thread(
212
+ target=self._heartbeat_loop, name="nnInteractive-heartbeat", daemon=True
213
+ )
214
+ self._heartbeat_thread.start()
215
+
190
216
  # ------------------------------------------------------------------ #
191
217
  # HTTP helpers (private) #
192
218
  # ------------------------------------------------------------------ #
@@ -272,7 +298,10 @@ class nnInteractiveRemoteInferenceSession:
272
298
  self.target_buffer = target_buffer
273
299
  self._post_json(
274
300
  PATH_SET_TARGET_BUFFER,
275
- {"shape": list(target_buffer.shape), "dtype": _buffer_dtype_str(target_buffer)},
301
+ {
302
+ "shape": list(target_buffer.shape),
303
+ "dtype": _buffer_dtype_str(target_buffer),
304
+ },
276
305
  )
277
306
 
278
307
  def set_do_autozoom(self, do_autozoom: bool) -> None:
@@ -372,7 +401,8 @@ class nnInteractiveRemoteInferenceSession:
372
401
  "override_capability_checks": bool(override_capability_checks),
373
402
  "interaction_bbox": ([list(b) for b in interaction_bbox] if interaction_bbox is not None else None),
374
403
  }
375
- resp = self._post_binary(path, meta, pack_array(mask_image))
404
+ # Interactions (scribble/lasso masks) compress best with NOFILTER; skip auto-selection.
405
+ resp = self._post_binary(path, meta, pack_array(mask_image, filters=[blosc2.Filter.NOFILTER]))
376
406
  self._apply_prediction_response(resp)
377
407
 
378
408
  def add_initial_seg_interaction(
@@ -397,7 +427,8 @@ class nnInteractiveRemoteInferenceSession:
397
427
  "run_prediction": bool(run_prediction),
398
428
  "override_capability_checks": bool(override_capability_checks),
399
429
  }
400
- resp = self._post_binary(PATH_ADD_INITIAL_SEG, meta, pack_array(initial_seg))
430
+ # Segmentations compress best with NOFILTER; skip auto-selection.
431
+ resp = self._post_binary(PATH_ADD_INITIAL_SEG, meta, pack_array(initial_seg, filters=[blosc2.Filter.NOFILTER]))
401
432
  self._apply_prediction_response(resp)
402
433
 
403
434
  # ------------------------------------------------------------------ #
@@ -420,17 +451,44 @@ class nnInteractiveRemoteInferenceSession:
420
451
  return False
421
452
 
422
453
  def heartbeat(self) -> float:
423
- """Extend this session's idle timeout. Returns remaining seconds.
454
+ """Tell the server this client is still alive. Returns remaining seconds
455
+ until the *idle* timeout.
456
+
457
+ This proves liveness only: it stops the server from reaping the session
458
+ as a crashed/dead client, but it does NOT postpone the idle timeout —
459
+ that is refreshed solely by real user actions (``set_image``,
460
+ ``add_*_interaction``, …). A session left untouched will therefore still
461
+ be reaped at the idle timeout even while heartbeats keep arriving.
424
462
 
425
- For GUIs that keep the app open across long idle stretches: drive
426
- this from a timer (e.g. every 60 s) to avoid SessionExpiredError
427
- after the configured idle timeout.
463
+ You normally never call this yourself: the session auto-heartbeats from
464
+ a background thread for the lifetime of the object.
428
465
  """
429
466
  resp = self._http.post(PATH_HEARTBEAT)
430
467
  _raise_for_lease_errors(resp)
431
468
  resp.raise_for_status()
432
469
  return float(resp.json().get("remaining_seconds", 0.0))
433
470
 
471
+ def _heartbeat_loop(self) -> None:
472
+ """Background daemon: prove liveness every ``_heartbeat_interval`` seconds.
473
+
474
+ Stops when the session is closed or once the lease is gone. Transient
475
+ network errors are swallowed so a brief blip doesn't kill the heartbeat;
476
+ the server's liveness timeout tolerates a few missed beats. Lease expiry
477
+ (idle reap or server restart) is surfaced to the user on their next real
478
+ call, not from this thread.
479
+ """
480
+ while not self._stop_heartbeat.wait(self._heartbeat_interval):
481
+ try:
482
+ self.heartbeat()
483
+ except SessionExpiredError:
484
+ break
485
+ except httpx.HTTPError:
486
+ continue
487
+ except Exception:
488
+ # Never let the daemon thread die noisily (e.g. client closing
489
+ # concurrently). Bail out quietly.
490
+ break
491
+
434
492
  def lease_status(self) -> dict:
435
493
  """Read-only probe: how much time is left before this session is reaped?
436
494
 
@@ -444,6 +502,14 @@ class nnInteractiveRemoteInferenceSession:
444
502
  return resp.json()
445
503
 
446
504
  def close(self) -> None:
505
+ # Stop the heartbeat thread first so it can't use self._http after we
506
+ # close it. join() with a short timeout: the thread spends almost all
507
+ # its time in Event.wait(), which the set() interrupts immediately.
508
+ self._stop_heartbeat.set()
509
+ if self._heartbeat_thread is not None:
510
+ self._heartbeat_thread.join(timeout=5.0)
511
+ self._heartbeat_thread = None
512
+
447
513
  # Best-effort release so the server can free our slot for other users
448
514
  # without waiting for the idle reaper. Swallow errors: the server may
449
515
  # already be gone, our lease may already be expired, etc. close()
@@ -47,8 +47,64 @@ _CODEC_ID = {
47
47
  _ID_CODEC = {v: k for k, v in _CODEC_ID.items()}
48
48
 
49
49
 
50
- def pack_array(arr: np.ndarray, codec: blosc2.Codec = blosc2.Codec.ZSTD, clevel: int = 3) -> bytes:
51
- """Serialize a numpy array to a self-describing compressed byte string."""
50
+ # Fraction of each axis used for the center crop that the filter heuristic compresses.
51
+ _SELECT_FILTER_CROP_FRACTION = 0.25
52
+
53
+
54
+ def _compress_all(raw: memoryview, total: int, codec: blosc2.Codec, clevel: int, filters: list) -> int:
55
+ """Compressed byte length of ``raw`` under ``filters``, chunked exactly as pack_array does."""
56
+ size = 0
57
+ nchunks = (total + _CHUNK_SIZE - 1) // _CHUNK_SIZE
58
+ for i in range(nchunks):
59
+ start = i * _CHUNK_SIZE
60
+ end = min(start + _CHUNK_SIZE, total)
61
+ size += len(blosc2.compress2(raw[start:end], codec=codec, clevel=clevel, filters=filters))
62
+ return size
63
+
64
+
65
+ def _select_filter(arr: np.ndarray, codec: blosc2.Codec, clevel: int) -> "blosc2.Filter":
66
+ """Pick NOFILTER vs SHUFFLE for ``arr`` by trial-compressing a small centered crop.
67
+
68
+ Uses ``compress2`` on the raw bytes — exactly the path pack_array takes — so the decision
69
+ is consistent with how the whole array is actually compressed. The crop is
70
+ ``_SELECT_FILTER_CROP_FRACTION`` of each axis (centered), keeping the trial cheap and
71
+ representative (lands on foreground). Ties go to NOFILTER; any failure falls back to it.
72
+ """
73
+ try:
74
+ crop_shape = [max(1, int(s * _SELECT_FILTER_CROP_FRACTION)) for s in arr.shape]
75
+ slices = tuple(slice((s - cs) // 2, (s - cs) // 2 + cs) for s, cs in zip(arr.shape, crop_shape))
76
+ crop = np.ascontiguousarray(arr[slices])
77
+ raw = memoryview(crop).cast("B")
78
+ total = raw.nbytes
79
+
80
+ best_filter, best_bytes = blosc2.Filter.NOFILTER, None
81
+ for f in (blosc2.Filter.NOFILTER, blosc2.Filter.SHUFFLE):
82
+ cb = _compress_all(raw, total, codec, clevel, [f])
83
+ if best_bytes is None or cb < best_bytes:
84
+ best_bytes, best_filter = cb, f
85
+ return best_filter
86
+ except Exception as e:
87
+ from warnings import warn
88
+
89
+ warn(f"_select_filter failed ({e!r}); falling back to NOFILTER.")
90
+ return blosc2.Filter.NOFILTER
91
+
92
+
93
+ def pack_array(
94
+ arr: np.ndarray,
95
+ codec: blosc2.Codec = blosc2.Codec.ZSTD,
96
+ clevel: int = 3,
97
+ filters: Optional[list] = None,
98
+ ) -> bytes:
99
+ """Serialize a numpy array to a self-describing compressed byte string.
100
+
101
+ ``filters`` is the blosc2 filter pipeline to apply. If ``None`` (the default), the
102
+ better of NOFILTER/SHUFFLE is auto-selected by trial-compressing a cheap, representative
103
+ slab — appropriate for images, whose optimum depends on the data. Callers that already
104
+ know the optimum (interactions and segmentations compress best with NOFILTER) should pass
105
+ ``[blosc2.Filter.NOFILTER]`` to skip the selection. The chosen filter is self-describing
106
+ inside the blosc2 frame, so unpack_array (decompress2) needs no changes.
107
+ """
52
108
  arr = np.ascontiguousarray(arr)
53
109
  dtype_str = arr.dtype.str.lstrip("<>|=").encode("ascii")
54
110
  if arr.dtype.byteorder not in ("=", "|", "<"):
@@ -77,11 +133,17 @@ def pack_array(arr: np.ndarray, codec: blosc2.Codec = blosc2.Codec.ZSTD, clevel:
77
133
  raw = memoryview(arr).cast("B")
78
134
  total = raw.nbytes
79
135
  nchunks = (total + _CHUNK_SIZE - 1) // _CHUNK_SIZE
136
+
137
+ if filters is None:
138
+ # Auto-select the better filter from a small centered crop, using the same
139
+ # compress2 path as below for consistency.
140
+ filters = [_select_filter(arr, codec, clevel)]
141
+
80
142
  parts = [header, shape_bytes, struct.pack("<I", nchunks)]
81
143
  for i in range(nchunks):
82
144
  start = i * _CHUNK_SIZE
83
145
  end = min(start + _CHUNK_SIZE, total)
84
- chunk = blosc2.compress2(raw[start:end], codec=codec, clevel=clevel)
146
+ chunk = blosc2.compress2(raw[start:end], codec=codec, clevel=clevel, filters=filters)
85
147
  parts.append(struct.pack("<QQ", end - start, len(chunk)))
86
148
  parts.append(chunk)
87
149
  return b"".join(parts)