nnInteractive 2.1.0__tar.gz → 2.3.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {nninteractive-2.1.0 → nninteractive-2.3.0}/PKG-INFO +7 -1
- {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/inference/inference_session.py +118 -17
- {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/inference/remote/remote_session.py +73 -7
- {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/inference/remote/serialization.py +65 -3
- {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/inference/server/app.py +160 -33
- nninteractive-2.3.0/nnInteractive/inference/server/main.py +233 -0
- {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive.egg-info/PKG-INFO +7 -1
- {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive.egg-info/SOURCES.txt +0 -0
- {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive.egg-info/dependency_links.txt +0 -0
- {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive.egg-info/requires.txt +5 -0
- {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive.egg-info/top_level.txt +0 -0
- {nninteractive-2.1.0 → nninteractive-2.3.0}/pyproject.toml +10 -1
- {nninteractive-2.1.0 → nninteractive-2.3.0}/readme.md +2 -0
- nninteractive-2.1.0/nnInteractive/inference/server/main.py +0 -149
- {nninteractive-2.1.0 → nninteractive-2.3.0}/LICENSE +0 -0
- {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/__init__.py +0 -0
- {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/inference/__init__.py +0 -0
- {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/inference/cvpr2025_challenge_baseline/__init__.py +0 -0
- {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/inference/cvpr2025_challenge_baseline/predict.py +0 -0
- {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/inference/remote/__init__.py +0 -0
- {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/inference/remote/_protocol.py +0 -0
- {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/inference/server/__init__.py +0 -0
- {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/interaction/__init__.py +0 -0
- {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/interaction/point.py +0 -0
- {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/supervoxel/setup.py +0 -0
- {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/supervoxel/src/metadata.py +0 -0
- {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/supervoxel/src/reader.py +0 -0
- {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/supervoxel/src/run.py +0 -0
- {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/supervoxel/src/sam2/__init__.py +0 -0
- {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/supervoxel/src/sam2/sam2/__init__.py +0 -0
- {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/supervoxel/src/sam2/sam2/automatic_mask_generator.py +0 -0
- {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/supervoxel/src/sam2/sam2/benchmark.py +0 -0
- {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/supervoxel/src/sam2/sam2/build_sam.py +0 -0
- {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/supervoxel/src/sam2/sam2/modeling/__init__.py +0 -0
- {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/__init__.py +0 -0
- {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/hieradet.py +0 -0
- {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/image_encoder.py +0 -0
- {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/utils.py +0 -0
- {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/supervoxel/src/sam2/sam2/modeling/memory_attention.py +0 -0
- {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/supervoxel/src/sam2/sam2/modeling/memory_encoder.py +0 -0
- {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/supervoxel/src/sam2/sam2/modeling/position_encoding.py +0 -0
- {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/__init__.py +0 -0
- {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/mask_decoder.py +0 -0
- {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/prompt_encoder.py +0 -0
- {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/transformer.py +0 -0
- {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/supervoxel/src/sam2/sam2/modeling/sam2_base.py +0 -0
- {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/supervoxel/src/sam2/sam2/modeling/sam2_utils.py +0 -0
- {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/supervoxel/src/sam2/sam2/sam2_image_predictor.py +0 -0
- {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/supervoxel/src/sam2/sam2/sam2_video_predictor.py +0 -0
- {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/supervoxel/src/sam2/sam2/sam2_video_predictor_legacy.py +0 -0
- {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/supervoxel/src/sam2/sam2/utils/__init__.py +0 -0
- {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/supervoxel/src/sam2/sam2/utils/amg.py +0 -0
- {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/supervoxel/src/sam2/sam2/utils/misc.py +0 -0
- {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/supervoxel/src/sam2/sam2/utils/transforms.py +0 -0
- {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/supervoxel/src/sam2/setup.py +0 -0
- {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/supervoxel/src/sam2/training/__init__.py +0 -0
- {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/supervoxel/src/sam2/training/dataset/__init__.py +0 -0
- {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/supervoxel/src/sam2/training/dataset/sam2_datasets.py +0 -0
- {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/supervoxel/src/sam2/training/dataset/transforms.py +0 -0
- {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/supervoxel/src/sam2/training/dataset/utils.py +0 -0
- {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/supervoxel/src/sam2/training/dataset/vos_dataset.py +0 -0
- {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/supervoxel/src/sam2/training/dataset/vos_raw_dataset.py +0 -0
- {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/supervoxel/src/sam2/training/dataset/vos_sampler.py +0 -0
- {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/supervoxel/src/sam2/training/dataset/vos_segment_loader.py +0 -0
- {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/supervoxel/src/sam2/training/loss_fns.py +0 -0
- {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/supervoxel/src/sam2/training/model/__init__.py +0 -0
- {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/supervoxel/src/sam2/training/model/sam2.py +0 -0
- {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/supervoxel/src/sam2/training/optimizer.py +0 -0
- {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/supervoxel/src/sam2/training/scripts/sav_frame_extraction_submitit.py +0 -0
- {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/supervoxel/src/sam2/training/train.py +0 -0
- {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/supervoxel/src/sam2/training/trainer.py +0 -0
- {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/supervoxel/src/sam2/training/utils/__init__.py +0 -0
- {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/supervoxel/src/sam2/training/utils/checkpoint_utils.py +0 -0
- {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/supervoxel/src/sam2/training/utils/data_utils.py +0 -0
- {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/supervoxel/src/sam2/training/utils/distributed.py +0 -0
- {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/supervoxel/src/sam2/training/utils/logger.py +0 -0
- {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/supervoxel/src/sam2/training/utils/train_utils.py +0 -0
- {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/supervoxel/src/supervoxel.py +0 -0
- {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/trainer/__init__.py +0 -0
- {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/trainer/nnInteractiveTrainer.py +0 -0
- {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/utils/__init__.py +0 -0
- {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/utils/bboxes.py +0 -0
- {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/utils/checkpoint_cleansing.py +0 -0
- {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/utils/crop.py +0 -0
- {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/utils/erosion_dilation.py +0 -0
- {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/utils/inference_helpers.py +0 -0
- {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/utils/os_shennanigans.py +0 -0
- {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/utils/rounding.py +0 -0
- {nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive.egg-info/entry_points.txt +0 -0
- {nninteractive-2.1.0 → nninteractive-2.3.0}/setup.cfg +0 -0
- {nninteractive-2.1.0 → nninteractive-2.3.0}/setup.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: nnInteractive
|
|
3
|
-
Version: 2.1.0
|
|
3
|
+
Version: 2.3.0
|
|
4
4
|
Summary: Inference code for nnInteractive
|
|
5
5
|
Author: Helmholtz Imaging Applied Computer Vision Lab
|
|
6
6
|
Author-email: Fabian Isensee <f.isensee@dkfz-heidelberg.de>
|
|
@@ -226,10 +226,14 @@ Requires-Dist: batchgenerators>=0.25.1
|
|
|
226
226
|
Requires-Dist: fastapi>=0.110
|
|
227
227
|
Requires-Dist: uvicorn[standard]>=0.27
|
|
228
228
|
Requires-Dist: httpx>=0.27
|
|
229
|
+
Requires-Dist: blosc2
|
|
229
230
|
Provides-Extra: dev
|
|
230
231
|
Requires-Dist: black; extra == "dev"
|
|
231
232
|
Requires-Dist: ruff; extra == "dev"
|
|
232
233
|
Requires-Dist: pre-commit; extra == "dev"
|
|
234
|
+
Provides-Extra: client
|
|
235
|
+
Requires-Dist: httpx>=0.27; extra == "client"
|
|
236
|
+
Requires-Dist: blosc2; extra == "client"
|
|
233
237
|
Dynamic: license-file
|
|
234
238
|
|
|
235
239
|
<img src="imgs/nnInteractive_header_white.png">
|
|
@@ -545,6 +549,8 @@ Link: [](https
|
|
|
545
549
|
# License
|
|
546
550
|
Note that while this repository is available under Apache-2.0 license (see [LICENSE](./LICENSE)), the [model checkpoint](https://huggingface.co/nnInteractive/nnInteractive) is `Creative Commons Attribution Non Commercial Share Alike 4.0`!
|
|
547
551
|
|
|
552
|
+
Release model folders ship their own `LICENSE` file whose **first line is the license identifier** (e.g. `CC BY-NC-SA 4.0`); any following lines (such as a link to the full license) are ignored by the tool. At load time this first line is read and exposed as `session.license` so applications can display the model's license prominently. If a checkpoint folder has no `LICENSE` file, the official v1 checkpoint is assumed to be `CC BY-NC-SA 4.0` and any other checkpoint reports `!!MISSING!!`.
|
|
553
|
+
|
|
548
554
|
# Changelog
|
|
549
555
|
|
|
550
556
|
### 1.1.2 - 2025-08-02
|
|
@@ -50,14 +50,16 @@ class nnInteractiveInferenceSession:
|
|
|
50
50
|
):
|
|
51
51
|
"""
|
|
52
52
|
Only intended to work with nnInteractiveTrainerV2 and its derivatives
|
|
53
|
+
|
|
54
|
+
``use_torch_compile``: compile the network with ``torch.compile``. The
|
|
55
|
+
first prediction after enabling this is slow (compilation happens lazily
|
|
56
|
+
on the first forward pass), but every subsequent prediction is faster.
|
|
57
|
+
This is recommended for the persistent inference server, where the
|
|
58
|
+
process is long-lived so the one-time compile cost is paid only once and
|
|
59
|
+
amortized across the whole session lifetime.
|
|
53
60
|
"""
|
|
54
61
|
print("session initialized")
|
|
55
62
|
|
|
56
|
-
# set as part of initialization
|
|
57
|
-
assert use_torch_compile is False, (
|
|
58
|
-
"torch.compile is not supported. The blosc2-backed interaction tensor "
|
|
59
|
-
"requires numpy↔torch round-trips that break compile tracing."
|
|
60
|
-
)
|
|
61
63
|
self.network = None
|
|
62
64
|
self.label_manager = None
|
|
63
65
|
self.dataset_json = None
|
|
@@ -77,6 +79,11 @@ class nnInteractiveInferenceSession:
|
|
|
77
79
|
self.channel_mapping: dict = {}
|
|
78
80
|
self.supports_initial_label: bool = True
|
|
79
81
|
self.supports_zero_shot_label_refinement: bool = True
|
|
82
|
+
# License of the loaded model checkpoint. Set when the model is loaded
|
|
83
|
+
# (read from the LICENSE file in the checkpoint folder, or derived for
|
|
84
|
+
# legacy checkpoints without one). Exposed so GUIs can display it once
|
|
85
|
+
# the session is initialized. "!!MISSING!!" means the license is unknown.
|
|
86
|
+
self.license: Optional[str] = None
|
|
80
87
|
|
|
81
88
|
# image specific
|
|
82
89
|
self.interactions = None # blosc2.NDArray once initialized
|
|
@@ -116,6 +123,31 @@ class nnInteractiveInferenceSession:
|
|
|
116
123
|
and checkpoint.get("init_args", {}).get("configuration") == "3d_fullres_ps192_bs24"
|
|
117
124
|
)
|
|
118
125
|
|
|
126
|
+
@classmethod
|
|
127
|
+
def _load_license(cls, model_training_output_dir: str, plans: dict, checkpoint: dict) -> str:
|
|
128
|
+
"""Determine the license of the model being loaded.
|
|
129
|
+
|
|
130
|
+
Reads the ``LICENSE`` file from the checkpoint folder if present.
|
|
131
|
+
Expected format: the FIRST line is a short license identifier (e.g.
|
|
132
|
+
``CC BY-NC-SA 4.0``); any following lines (URL, full text, …) are for
|
|
133
|
+
human readers and are ignored. Only the first non-empty line is
|
|
134
|
+
returned, so ``self.license`` stays a short, displayable string.
|
|
135
|
+
|
|
136
|
+
If the folder has no ``LICENSE`` file it is most likely a legacy model:
|
|
137
|
+
the official v1 checkpoint is CC BY-NC-SA 4.0, anything else is reported
|
|
138
|
+
as ``"!!MISSING!!"`` so callers (e.g. GUIs) can flag the unknown license.
|
|
139
|
+
"""
|
|
140
|
+
license_file = join(model_training_output_dir, "LICENSE")
|
|
141
|
+
if isfile(license_file):
|
|
142
|
+
with open(license_file, "r", encoding="utf-8") as f:
|
|
143
|
+
for line in f:
|
|
144
|
+
line = line.strip()
|
|
145
|
+
if line:
|
|
146
|
+
return line
|
|
147
|
+
if cls._is_official_checkpoint(plans, checkpoint):
|
|
148
|
+
return "CC BY-NC-SA 4.0"
|
|
149
|
+
return "!!MISSING!!"
|
|
150
|
+
|
|
119
151
|
def _legacy_default_capability(self) -> dict:
|
|
120
152
|
return {
|
|
121
153
|
"supported_interactions": {
|
|
@@ -533,7 +565,13 @@ class nnInteractiveInferenceSession:
|
|
|
533
565
|
dtype=np.float16,
|
|
534
566
|
chunks=(1, *[min(64, s) for s in shape[1:]]),
|
|
535
567
|
blocks=(1, *[min(32, s) for s in shape[1:]]),
|
|
536
|
-
|
|
568
|
+
# Interactions compress better with NOFILTER, which is also faster than SHUFFLE.
|
|
569
|
+
cparams={
|
|
570
|
+
"codec": blosc2.Codec.LZ4,
|
|
571
|
+
"clevel": 5,
|
|
572
|
+
"filters": [blosc2.Filter.NOFILTER],
|
|
573
|
+
"nthreads": min(self.torch_n_threads, os.cpu_count()),
|
|
574
|
+
},
|
|
537
575
|
dparams={"nthreads": 4},
|
|
538
576
|
)
|
|
539
577
|
self._interactions_shape = shape
|
|
@@ -602,7 +640,13 @@ class nnInteractiveInferenceSession:
|
|
|
602
640
|
dtype=np.float16,
|
|
603
641
|
chunks=(1, *[min(64, s) for s in self._interactions_shape[1:]]),
|
|
604
642
|
blocks=(1, *[min(32, s) for s in self._interactions_shape[1:]]),
|
|
605
|
-
|
|
643
|
+
# Interactions compress better with NOFILTER, which is also faster than SHUFFLE.
|
|
644
|
+
cparams={
|
|
645
|
+
"codec": blosc2.Codec.LZ4,
|
|
646
|
+
"clevel": 5,
|
|
647
|
+
"filters": [blosc2.Filter.NOFILTER],
|
|
648
|
+
"nthreads": os.cpu_count(),
|
|
649
|
+
},
|
|
606
650
|
dparams={"nthreads": 4},
|
|
607
651
|
)
|
|
608
652
|
self.current_interaction_intensity = 1.0
|
|
@@ -883,6 +927,42 @@ class nnInteractiveInferenceSession:
|
|
|
883
927
|
else:
|
|
884
928
|
del initial_seg
|
|
885
929
|
|
|
930
|
+
@torch.inference_mode()
|
|
931
|
+
def warmup(self) -> bool:
|
|
932
|
+
"""Run a single dummy forward pass to trigger lazy ``torch.compile`` compilation up front.
|
|
933
|
+
|
|
934
|
+
With ``torch.compile`` enabled the network is compiled lazily on its first
|
|
935
|
+
forward pass, which would otherwise make the user's *first* real prediction
|
|
936
|
+
slow. Every prediction path — the initial coarse pass, the zoom-out
|
|
937
|
+
iterations, and the refinement patches — feeds the network an input of
|
|
938
|
+
identical shape ``[1, num_input_channels + num_interaction_channels,
|
|
939
|
+
*patch_size]`` (``_build_network_input`` always resizes the crop to
|
|
940
|
+
``patch_size``, and refinement crops at exactly ``patch_size``). So a single
|
|
941
|
+
dummy pass at that shape populates the compile cache and every subsequent
|
|
942
|
+
real prediction is fast.
|
|
943
|
+
|
|
944
|
+
Returns ``True`` if a warmup pass was run, ``False`` if it was a no-op
|
|
945
|
+
(network not compiled — there is nothing to pre-compile, so a dummy pass
|
|
946
|
+
would not save the user any time). Mirrors ``_predict``'s autocast/
|
|
947
|
+
inference-mode context and the float32 input dtype that ``torch.cat``
|
|
948
|
+
produces when concatenating the float32 image with the fp16 interactions.
|
|
949
|
+
"""
|
|
950
|
+
if self.network is None or self.configuration_manager is None:
|
|
951
|
+
raise RuntimeError("warmup() requires an initialized network; call initialize_* first")
|
|
952
|
+
if not isinstance(self.network, OptimizedModule):
|
|
953
|
+
return False
|
|
954
|
+
num_input_channels = (
|
|
955
|
+
determine_num_input_channels(self.plans_manager, self.configuration_manager, self.dataset_json)
|
|
956
|
+
+ self.num_interaction_channels
|
|
957
|
+
)
|
|
958
|
+
patch_size = self.configuration_manager.patch_size
|
|
959
|
+
dummy = torch.zeros((1, num_input_channels, *patch_size), dtype=torch.float32, device=self.device)
|
|
960
|
+
with torch.autocast(self.device.type, enabled=True) if self.device.type == "cuda" else dummy_context():
|
|
961
|
+
self.network(dummy)
|
|
962
|
+
del dummy
|
|
963
|
+
empty_cache(self.device)
|
|
964
|
+
return True
|
|
965
|
+
|
|
886
966
|
@torch.inference_mode()
|
|
887
967
|
def _predict(self, force_full_refine: bool = False):
|
|
888
968
|
"""
|
|
@@ -1296,6 +1376,16 @@ class nnInteractiveInferenceSession:
|
|
|
1296
1376
|
"""
|
|
1297
1377
|
artifacts = self._load_model_artifacts_from_disk(model_training_output_dir, use_fold, checkpoint_name)
|
|
1298
1378
|
self.initialize_from_loaded_artifacts(artifacts)
|
|
1379
|
+
# With torch.compile the network is compiled lazily on the first forward pass. For a
|
|
1380
|
+
# locally hosted model that lag would otherwise surface on the user's first real
|
|
1381
|
+
# prediction, where it is far more noticeable than during initialization. Trigger the
|
|
1382
|
+
# compilation now with a dummy forward pass so the cost is paid here instead. warmup()
|
|
1383
|
+
# is a no-op when the network is not compiled. The server takes care of its own warmup
|
|
1384
|
+
# explicitly (it shares one compiled network across sessions via
|
|
1385
|
+
# initialize_from_loaded_artifacts), so we only do this on the direct, local entry point.
|
|
1386
|
+
if self.use_torch_compile:
|
|
1387
|
+
print("torch.compile enabled; warming up (compiling) the network now (this is slow once)...")
|
|
1388
|
+
self.warmup()
|
|
1299
1389
|
|
|
1300
1390
|
def _load_model_artifacts_from_disk(
|
|
1301
1391
|
self,
|
|
@@ -1307,9 +1397,16 @@ class nnInteractiveInferenceSession:
|
|
|
1307
1397
|
|
|
1308
1398
|
Returns an artifact dict that can be applied to this or any other freshly
|
|
1309
1399
|
constructed session via :meth:`initialize_from_loaded_artifacts`. The
|
|
1310
|
-
returned
|
|
1311
|
-
|
|
1312
|
-
|
|
1400
|
+
returned values are the actual objects (the ``nn.Module`` with its
|
|
1401
|
+
weights and buffers, the plans/configuration managers, the dataset
|
|
1402
|
+
json, the label manager) — not copies. Multiple sessions calling
|
|
1403
|
+
:meth:`initialize_from_loaded_artifacts` with the same dict will all
|
|
1404
|
+
end up with ``self.network`` pointing at the same module instance and
|
|
1405
|
+
the same weight tensors on the GPU. This is safe as long as callers
|
|
1406
|
+
treat these objects as read-only after construction; in the multi-
|
|
1407
|
+
session server that is enforced by running inference under
|
|
1408
|
+
``@torch.inference_mode()`` and serializing predict calls with a
|
|
1409
|
+
global GPU lock.
|
|
1313
1410
|
|
|
1314
1411
|
Note: this also mutates ``self`` (applies capability, sets pad/decay/
|
|
1315
1412
|
thickness) because ``num_interaction_channels`` is required to build the
|
|
@@ -1344,12 +1441,11 @@ class nnInteractiveInferenceSession:
|
|
|
1344
1441
|
checkpoint = torch.load(
|
|
1345
1442
|
join(model_training_output_dir, fold_folder, checkpoint_name), map_location=self.device, weights_only=False
|
|
1346
1443
|
)
|
|
1347
|
-
|
|
1348
|
-
|
|
1349
|
-
|
|
1350
|
-
|
|
1351
|
-
|
|
1352
|
-
)
|
|
1444
|
+
self.license = self._load_license(model_training_output_dir, plans, checkpoint)
|
|
1445
|
+
print("=" * 80)
|
|
1446
|
+
print("Model license:")
|
|
1447
|
+
print(self.license)
|
|
1448
|
+
print("=" * 80)
|
|
1353
1449
|
trainer_name = checkpoint["trainer_name"]
|
|
1354
1450
|
configuration_name = checkpoint["init_args"]["configuration"]
|
|
1355
1451
|
|
|
@@ -1395,6 +1491,7 @@ class nnInteractiveInferenceSession:
|
|
|
1395
1491
|
"dataset_json": dataset_json,
|
|
1396
1492
|
"trainer_name": trainer_name,
|
|
1397
1493
|
"label_manager": plans_manager.get_label_manager(dataset_json),
|
|
1494
|
+
"license": self.license,
|
|
1398
1495
|
}
|
|
1399
1496
|
|
|
1400
1497
|
def initialize_from_loaded_artifacts(self, artifacts: dict):
|
|
@@ -1402,7 +1499,10 @@ class nnInteractiveInferenceSession:
|
|
|
1402
1499
|
|
|
1403
1500
|
``artifacts`` is the dict returned by :meth:`_load_model_artifacts_from_disk`.
|
|
1404
1501
|
Useful for spawning multiple sessions that share one loaded model (e.g.
|
|
1405
|
-
the multi-session inference server).
|
|
1502
|
+
the multi-session inference server). All artifact entries — including
|
|
1503
|
+
``self.network`` — are stored by reference; passing the same dict to
|
|
1504
|
+
multiple sessions does not duplicate the network or its weights in
|
|
1505
|
+
memory.
|
|
1406
1506
|
"""
|
|
1407
1507
|
self.preferred_scribble_thickness = artifacts["preferred_scribble_thickness"]
|
|
1408
1508
|
self.interaction_decay = artifacts["interaction_decay"]
|
|
@@ -1415,6 +1515,7 @@ class nnInteractiveInferenceSession:
|
|
|
1415
1515
|
self.dataset_json = artifacts["dataset_json"]
|
|
1416
1516
|
self.trainer_name = artifacts["trainer_name"]
|
|
1417
1517
|
self.label_manager = artifacts["label_manager"]
|
|
1518
|
+
self.license = artifacts["license"]
|
|
1418
1519
|
if self.use_torch_compile and not isinstance(self.network, OptimizedModule):
|
|
1419
1520
|
print("Using torch.compile")
|
|
1420
1521
|
self.network = torch.compile(self.network)
|
{nninteractive-2.1.0 → nninteractive-2.3.0}/nnInteractive/inference/remote/remote_session.py
RENAMED
|
@@ -10,9 +10,11 @@ from __future__ import annotations
|
|
|
10
10
|
|
|
11
11
|
import json
|
|
12
12
|
import os
|
|
13
|
+
import threading
|
|
13
14
|
import warnings
|
|
14
15
|
from typing import List, Optional, Tuple, Union
|
|
15
16
|
|
|
17
|
+
import blosc2
|
|
16
18
|
import httpx
|
|
17
19
|
import numpy as np
|
|
18
20
|
import torch
|
|
@@ -165,8 +167,16 @@ class nnInteractiveRemoteInferenceSession:
|
|
|
165
167
|
claim_info = claim_resp.json()
|
|
166
168
|
self._lease_token = claim_info["lease_token"]
|
|
167
169
|
self.idle_timeout_seconds: float = float(claim_info.get("idle_timeout_seconds", 0.0))
|
|
170
|
+
self.liveness_timeout_seconds: float = float(claim_info.get("liveness_timeout_seconds", 0.0))
|
|
168
171
|
self._http.headers[LEASE_HEADER] = self._lease_token
|
|
169
172
|
|
|
173
|
+
# Background liveness heartbeat bookkeeping. Defined before any code that
|
|
174
|
+
# might raise so close()/__del__ can always reference them safely. The
|
|
175
|
+
# thread itself is started at the end of __init__, once construction has
|
|
176
|
+
# fully succeeded.
|
|
177
|
+
self._stop_heartbeat = threading.Event()
|
|
178
|
+
self._heartbeat_thread: Optional[threading.Thread] = None
|
|
179
|
+
|
|
170
180
|
caps = self._get_json(PATH_CAPABILITIES)
|
|
171
181
|
|
|
172
182
|
# Attributes that mirror the local session so the GUI can introspect them
|
|
@@ -182,11 +192,27 @@ class nnInteractiveRemoteInferenceSession:
|
|
|
182
192
|
self.preferred_scribble_thickness = caps["preferred_scribble_thickness"]
|
|
183
193
|
self.interaction_decay = caps["interaction_decay"]
|
|
184
194
|
self.INFERENCE_SESSION_VERSION = caps["inference_session_version"]
|
|
195
|
+
# License of the model loaded on the server. Mirrors
|
|
196
|
+
# nnInteractiveInferenceSession.license so a GUI can display it
|
|
197
|
+
# regardless of whether it holds a local or remote session.
|
|
198
|
+
# "!!MISSING!!" means the server could not determine the license.
|
|
199
|
+
self.license: Optional[str] = caps.get("license")
|
|
185
200
|
|
|
186
201
|
self.original_image_shape: Optional[Tuple[int, ...]] = None
|
|
187
202
|
self.target_buffer: Union[np.ndarray, torch.Tensor, None] = None
|
|
188
203
|
self.do_autozoom: bool = bool(caps.get("do_autozoom", True))
|
|
189
204
|
|
|
205
|
+
# Construction succeeded — start auto-heartbeating to keep the server
|
|
206
|
+
# from reaping us as a dead client. Beat at half the liveness timeout so
|
|
207
|
+
# one dropped request still leaves margin. Daemon thread: it never blocks
|
|
208
|
+
# interpreter exit, and close() joins it cleanly.
|
|
209
|
+
if self.liveness_timeout_seconds > 0:
|
|
210
|
+
self._heartbeat_interval = max(5.0, self.liveness_timeout_seconds / 2.0)
|
|
211
|
+
self._heartbeat_thread = threading.Thread(
|
|
212
|
+
target=self._heartbeat_loop, name="nnInteractive-heartbeat", daemon=True
|
|
213
|
+
)
|
|
214
|
+
self._heartbeat_thread.start()
|
|
215
|
+
|
|
190
216
|
# ------------------------------------------------------------------ #
|
|
191
217
|
# HTTP helpers (private) #
|
|
192
218
|
# ------------------------------------------------------------------ #
|
|
@@ -272,7 +298,10 @@ class nnInteractiveRemoteInferenceSession:
|
|
|
272
298
|
self.target_buffer = target_buffer
|
|
273
299
|
self._post_json(
|
|
274
300
|
PATH_SET_TARGET_BUFFER,
|
|
275
|
-
{
|
|
301
|
+
{
|
|
302
|
+
"shape": list(target_buffer.shape),
|
|
303
|
+
"dtype": _buffer_dtype_str(target_buffer),
|
|
304
|
+
},
|
|
276
305
|
)
|
|
277
306
|
|
|
278
307
|
def set_do_autozoom(self, do_autozoom: bool) -> None:
|
|
@@ -372,7 +401,8 @@ class nnInteractiveRemoteInferenceSession:
|
|
|
372
401
|
"override_capability_checks": bool(override_capability_checks),
|
|
373
402
|
"interaction_bbox": ([list(b) for b in interaction_bbox] if interaction_bbox is not None else None),
|
|
374
403
|
}
|
|
375
|
-
|
|
404
|
+
# Interactions (scribble/lasso masks) compress best with NOFILTER; skip auto-selection.
|
|
405
|
+
resp = self._post_binary(path, meta, pack_array(mask_image, filters=[blosc2.Filter.NOFILTER]))
|
|
376
406
|
self._apply_prediction_response(resp)
|
|
377
407
|
|
|
378
408
|
def add_initial_seg_interaction(
|
|
@@ -397,7 +427,8 @@ class nnInteractiveRemoteInferenceSession:
|
|
|
397
427
|
"run_prediction": bool(run_prediction),
|
|
398
428
|
"override_capability_checks": bool(override_capability_checks),
|
|
399
429
|
}
|
|
400
|
-
|
|
430
|
+
# Segmentations compress best with NOFILTER; skip auto-selection.
|
|
431
|
+
resp = self._post_binary(PATH_ADD_INITIAL_SEG, meta, pack_array(initial_seg, filters=[blosc2.Filter.NOFILTER]))
|
|
401
432
|
self._apply_prediction_response(resp)
|
|
402
433
|
|
|
403
434
|
# ------------------------------------------------------------------ #
|
|
@@ -420,17 +451,44 @@ class nnInteractiveRemoteInferenceSession:
|
|
|
420
451
|
return False
|
|
421
452
|
|
|
422
453
|
def heartbeat(self) -> float:
|
|
423
|
-
"""
|
|
454
|
+
"""Tell the server this client is still alive. Returns remaining seconds
|
|
455
|
+
until the *idle* timeout.
|
|
456
|
+
|
|
457
|
+
This proves liveness only: it stops the server from reaping the session
|
|
458
|
+
as a crashed/dead client, but it does NOT postpone the idle timeout —
|
|
459
|
+
that is refreshed solely by real user actions (``set_image``,
|
|
460
|
+
``add_*_interaction``, …). A session left untouched will therefore still
|
|
461
|
+
be reaped at the idle timeout even while heartbeats keep arriving.
|
|
424
462
|
|
|
425
|
-
|
|
426
|
-
|
|
427
|
-
after the configured idle timeout.
|
|
463
|
+
You normally never call this yourself: the session auto-heartbeats from
|
|
464
|
+
a background thread for the lifetime of the object.
|
|
428
465
|
"""
|
|
429
466
|
resp = self._http.post(PATH_HEARTBEAT)
|
|
430
467
|
_raise_for_lease_errors(resp)
|
|
431
468
|
resp.raise_for_status()
|
|
432
469
|
return float(resp.json().get("remaining_seconds", 0.0))
|
|
433
470
|
|
|
471
|
+
def _heartbeat_loop(self) -> None:
|
|
472
|
+
"""Background daemon: prove liveness every ``_heartbeat_interval`` seconds.
|
|
473
|
+
|
|
474
|
+
Stops when the session is closed or once the lease is gone. Transient
|
|
475
|
+
network errors are swallowed so a brief blip doesn't kill the heartbeat;
|
|
476
|
+
the server's liveness timeout tolerates a few missed beats. Lease expiry
|
|
477
|
+
(idle reap or server restart) is surfaced to the user on their next real
|
|
478
|
+
call, not from this thread.
|
|
479
|
+
"""
|
|
480
|
+
while not self._stop_heartbeat.wait(self._heartbeat_interval):
|
|
481
|
+
try:
|
|
482
|
+
self.heartbeat()
|
|
483
|
+
except SessionExpiredError:
|
|
484
|
+
break
|
|
485
|
+
except httpx.HTTPError:
|
|
486
|
+
continue
|
|
487
|
+
except Exception:
|
|
488
|
+
# Never let the daemon thread die noisily (e.g. client closing
|
|
489
|
+
# concurrently). Bail out quietly.
|
|
490
|
+
break
|
|
491
|
+
|
|
434
492
|
def lease_status(self) -> dict:
|
|
435
493
|
"""Read-only probe: how much time is left before this session is reaped?
|
|
436
494
|
|
|
@@ -444,6 +502,14 @@ class nnInteractiveRemoteInferenceSession:
|
|
|
444
502
|
return resp.json()
|
|
445
503
|
|
|
446
504
|
def close(self) -> None:
|
|
505
|
+
# Stop the heartbeat thread first so it can't use self._http after we
|
|
506
|
+
# close it. join() with a short timeout: the thread spends almost all
|
|
507
|
+
# its time in Event.wait(), which the set() interrupts immediately.
|
|
508
|
+
self._stop_heartbeat.set()
|
|
509
|
+
if self._heartbeat_thread is not None:
|
|
510
|
+
self._heartbeat_thread.join(timeout=5.0)
|
|
511
|
+
self._heartbeat_thread = None
|
|
512
|
+
|
|
447
513
|
# Best-effort release so the server can free our slot for other users
|
|
448
514
|
# without waiting for the idle reaper. Swallow errors: the server may
|
|
449
515
|
# already be gone, our lease may already be expired, etc. close()
|
|
@@ -47,8 +47,64 @@ _CODEC_ID = {
|
|
|
47
47
|
_ID_CODEC = {v: k for k, v in _CODEC_ID.items()}
|
|
48
48
|
|
|
49
49
|
|
|
50
|
-
|
|
51
|
-
|
|
50
|
+
# Fraction of each axis used for the center crop that the filter heuristic compresses.
|
|
51
|
+
_SELECT_FILTER_CROP_FRACTION = 0.25
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def _compress_all(raw: memoryview, total: int, codec: blosc2.Codec, clevel: int, filters: list) -> int:
    """Return how many bytes ``raw`` occupies once compressed with ``filters``.

    The first ``total`` bytes of ``raw`` are split into consecutive pieces of at most
    ``_CHUNK_SIZE`` bytes (the same chunk boundaries pack_array uses) and each piece is
    compressed independently with ``blosc2.compress2``. The result is the sum of the
    compressed piece lengths; nothing is retained besides that count.
    """
    compressed_bytes = 0
    # Stepping by _CHUNK_SIZE visits the same chunk starts as indexing 0..nchunks-1.
    for offset in range(0, total, _CHUNK_SIZE):
        # The last chunk may be short, so clamp its end to ``total``.
        piece = raw[offset : min(offset + _CHUNK_SIZE, total)]
        packed = blosc2.compress2(piece, codec=codec, clevel=clevel, filters=filters)
        compressed_bytes += len(packed)
    return compressed_bytes
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def _select_filter(arr: np.ndarray, codec: blosc2.Codec, clevel: int) -> "blosc2.Filter":
    """Choose between NOFILTER and SHUFFLE for ``arr`` using a cheap trial compression.

    A centered window covering ``_SELECT_FILTER_CROP_FRACTION`` of every axis (at least one
    element per axis) is copied out, viewed as raw bytes, and compressed once per candidate
    filter through ``_compress_all`` — i.e. the same chunked ``compress2`` path pack_array
    uses for the full array. SHUFFLE is returned only when it is strictly smaller; a tie
    keeps NOFILTER. If anything in the trial raises, a warning is emitted and NOFILTER is
    returned, so filter selection never prevents packing.
    """
    try:
        # Build one centered slice per axis.
        window = []
        for extent in arr.shape:
            width = max(1, int(extent * _SELECT_FILTER_CROP_FRACTION))
            lead = (extent - width) // 2
            window.append(slice(lead, lead + width))

        # Contiguous copy so the crop can be exposed as a flat byte buffer.
        sample = memoryview(np.ascontiguousarray(arr[tuple(window)])).cast("B")
        sample_len = sample.nbytes

        plain_size = _compress_all(sample, sample_len, codec, clevel, [blosc2.Filter.NOFILTER])
        shuffled_size = _compress_all(sample, sample_len, codec, clevel, [blosc2.Filter.SHUFFLE])

        # Strict comparison: equal sizes resolve to NOFILTER.
        if shuffled_size < plain_size:
            return blosc2.Filter.SHUFFLE
        return blosc2.Filter.NOFILTER
    except Exception as e:
        # Best-effort heuristic: report the failure but keep pack_array usable.
        from warnings import warn

        warn(f"_select_filter failed ({e!r}); falling back to NOFILTER.")
        return blosc2.Filter.NOFILTER
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def pack_array(
|
|
94
|
+
arr: np.ndarray,
|
|
95
|
+
codec: blosc2.Codec = blosc2.Codec.ZSTD,
|
|
96
|
+
clevel: int = 3,
|
|
97
|
+
filters: Optional[list] = None,
|
|
98
|
+
) -> bytes:
|
|
99
|
+
"""Serialize a numpy array to a self-describing compressed byte string.
|
|
100
|
+
|
|
101
|
+
``filters`` is the blosc2 filter pipeline to apply. If ``None`` (the default), the
|
|
102
|
+
better of NOFILTER/SHUFFLE is auto-selected by trial-compressing a cheap, representative
|
|
103
|
+
slab — appropriate for images, whose optimum depends on the data. Callers that already
|
|
104
|
+
know the optimum (interactions and segmentations compress best with NOFILTER) should pass
|
|
105
|
+
``[blosc2.Filter.NOFILTER]`` to skip the selection. The chosen filter is self-describing
|
|
106
|
+
inside the blosc2 frame, so unpack_array (decompress2) needs no changes.
|
|
107
|
+
"""
|
|
52
108
|
arr = np.ascontiguousarray(arr)
|
|
53
109
|
dtype_str = arr.dtype.str.lstrip("<>|=").encode("ascii")
|
|
54
110
|
if arr.dtype.byteorder not in ("=", "|", "<"):
|
|
@@ -77,11 +133,17 @@ def pack_array(arr: np.ndarray, codec: blosc2.Codec = blosc2.Codec.ZSTD, clevel:
|
|
|
77
133
|
raw = memoryview(arr).cast("B")
|
|
78
134
|
total = raw.nbytes
|
|
79
135
|
nchunks = (total + _CHUNK_SIZE - 1) // _CHUNK_SIZE
|
|
136
|
+
|
|
137
|
+
if filters is None:
|
|
138
|
+
# Auto-select the better filter from a small centered crop, using the same
|
|
139
|
+
# compress2 path as below for consistency.
|
|
140
|
+
filters = [_select_filter(arr, codec, clevel)]
|
|
141
|
+
|
|
80
142
|
parts = [header, shape_bytes, struct.pack("<I", nchunks)]
|
|
81
143
|
for i in range(nchunks):
|
|
82
144
|
start = i * _CHUNK_SIZE
|
|
83
145
|
end = min(start + _CHUNK_SIZE, total)
|
|
84
|
-
chunk = blosc2.compress2(raw[start:end], codec=codec, clevel=clevel)
|
|
146
|
+
chunk = blosc2.compress2(raw[start:end], codec=codec, clevel=clevel, filters=filters)
|
|
85
147
|
parts.append(struct.pack("<QQ", end - start, len(chunk)))
|
|
86
148
|
parts.append(chunk)
|
|
87
149
|
return b"".join(parts)
|