nnInteractive 2.0.0__tar.gz → 2.2.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {nninteractive-2.0.0 → nninteractive-2.2.0}/PKG-INFO +75 -5
- {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/inference/inference_session.py +115 -11
- nninteractive-2.2.0/nnInteractive/inference/remote/__init__.py +11 -0
- nninteractive-2.2.0/nnInteractive/inference/remote/_protocol.py +27 -0
- nninteractive-2.2.0/nnInteractive/inference/remote/remote_session.py +528 -0
- nninteractive-2.2.0/nnInteractive/inference/remote/serialization.py +142 -0
- nninteractive-2.2.0/nnInteractive/inference/server/app.py +662 -0
- nninteractive-2.2.0/nnInteractive/inference/server/main.py +233 -0
- nninteractive-2.2.0/nnInteractive/utils/__init__.py +0 -0
- {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive.egg-info/PKG-INFO +75 -5
- {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive.egg-info/SOURCES.txt +8 -0
- {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive.egg-info/dependency_links.txt +0 -0
- nninteractive-2.2.0/nnInteractive.egg-info/entry_points.txt +2 -0
- {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive.egg-info/requires.txt +8 -0
- {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive.egg-info/top_level.txt +0 -0
- {nninteractive-2.0.0 → nninteractive-2.2.0}/pyproject.toml +16 -1
- {nninteractive-2.0.0 → nninteractive-2.2.0}/readme.md +67 -4
- {nninteractive-2.0.0 → nninteractive-2.2.0}/LICENSE +0 -0
- {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/__init__.py +0 -0
- {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/inference/__init__.py +0 -0
- {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/inference/cvpr2025_challenge_baseline/__init__.py +0 -0
- {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/inference/cvpr2025_challenge_baseline/predict.py +0 -0
- {nninteractive-2.0.0/nnInteractive/interaction → nninteractive-2.2.0/nnInteractive/inference/server}/__init__.py +0 -0
- {nninteractive-2.0.0/nnInteractive/trainer → nninteractive-2.2.0/nnInteractive/interaction}/__init__.py +0 -0
- {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/interaction/point.py +0 -0
- {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/setup.py +0 -0
- {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/metadata.py +0 -0
- {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/reader.py +0 -0
- {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/run.py +0 -0
- {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/__init__.py +0 -0
- {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/sam2/__init__.py +0 -0
- {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/sam2/automatic_mask_generator.py +0 -0
- {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/sam2/benchmark.py +0 -0
- {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/sam2/build_sam.py +0 -0
- {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/sam2/modeling/__init__.py +0 -0
- {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/__init__.py +0 -0
- {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/hieradet.py +0 -0
- {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/image_encoder.py +0 -0
- {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/utils.py +0 -0
- {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/sam2/modeling/memory_attention.py +0 -0
- {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/sam2/modeling/memory_encoder.py +0 -0
- {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/sam2/modeling/position_encoding.py +0 -0
- {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/__init__.py +0 -0
- {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/mask_decoder.py +0 -0
- {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/prompt_encoder.py +0 -0
- {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/transformer.py +0 -0
- {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/sam2/modeling/sam2_base.py +0 -0
- {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/sam2/modeling/sam2_utils.py +0 -0
- {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/sam2/sam2_image_predictor.py +0 -0
- {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/sam2/sam2_video_predictor.py +0 -0
- {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/sam2/sam2_video_predictor_legacy.py +0 -0
- {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/sam2/utils/__init__.py +0 -0
- {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/sam2/utils/amg.py +0 -0
- {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/sam2/utils/misc.py +0 -0
- {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/sam2/utils/transforms.py +0 -0
- {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/setup.py +0 -0
- {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/training/__init__.py +0 -0
- {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/training/dataset/__init__.py +0 -0
- {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/training/dataset/sam2_datasets.py +0 -0
- {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/training/dataset/transforms.py +0 -0
- {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/training/dataset/utils.py +0 -0
- {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/training/dataset/vos_dataset.py +0 -0
- {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/training/dataset/vos_raw_dataset.py +0 -0
- {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/training/dataset/vos_sampler.py +0 -0
- {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/training/dataset/vos_segment_loader.py +0 -0
- {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/training/loss_fns.py +0 -0
- {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/training/model/__init__.py +0 -0
- {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/training/model/sam2.py +0 -0
- {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/training/optimizer.py +0 -0
- {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/training/scripts/sav_frame_extraction_submitit.py +0 -0
- {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/training/train.py +0 -0
- {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/training/trainer.py +0 -0
- {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/training/utils/__init__.py +0 -0
- {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/training/utils/checkpoint_utils.py +0 -0
- {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/training/utils/data_utils.py +0 -0
- {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/training/utils/distributed.py +0 -0
- {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/training/utils/logger.py +0 -0
- {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/training/utils/train_utils.py +0 -0
- {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/supervoxel.py +0 -0
- {nninteractive-2.0.0/nnInteractive/utils → nninteractive-2.2.0/nnInteractive/trainer}/__init__.py +0 -0
- {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/trainer/nnInteractiveTrainer.py +0 -0
- {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/utils/bboxes.py +0 -0
- {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/utils/checkpoint_cleansing.py +0 -0
- {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/utils/crop.py +0 -0
- {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/utils/erosion_dilation.py +0 -0
- {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/utils/inference_helpers.py +0 -0
- {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/utils/os_shennanigans.py +0 -0
- {nninteractive-2.0.0 → nninteractive-2.2.0}/nnInteractive/utils/rounding.py +0 -0
- {nninteractive-2.0.0 → nninteractive-2.2.0}/setup.cfg +0 -0
- {nninteractive-2.0.0 → nninteractive-2.2.0}/setup.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: nnInteractive
|
|
3
|
-
Version: 2.0.0
|
|
3
|
+
Version: 2.2.0
|
|
4
4
|
Summary: Inference code for nnInteractive
|
|
5
5
|
Author: Helmholtz Imaging Applied Computer Vision Lab
|
|
6
6
|
Author-email: Fabian Isensee <f.isensee@dkfz-heidelberg.de>
|
|
@@ -223,10 +223,17 @@ Requires-Dist: nnunetv2>=2.7.0
|
|
|
223
223
|
Requires-Dist: torch!=2.9.*,>=2.1.2
|
|
224
224
|
Requires-Dist: acvl-utils<0.3,>=0.2.3
|
|
225
225
|
Requires-Dist: batchgenerators>=0.25.1
|
|
226
|
+
Requires-Dist: fastapi>=0.110
|
|
227
|
+
Requires-Dist: uvicorn[standard]>=0.27
|
|
228
|
+
Requires-Dist: httpx>=0.27
|
|
229
|
+
Requires-Dist: blosc2
|
|
226
230
|
Provides-Extra: dev
|
|
227
231
|
Requires-Dist: black; extra == "dev"
|
|
228
232
|
Requires-Dist: ruff; extra == "dev"
|
|
229
233
|
Requires-Dist: pre-commit; extra == "dev"
|
|
234
|
+
Provides-Extra: client
|
|
235
|
+
Requires-Dist: httpx>=0.27; extra == "client"
|
|
236
|
+
Requires-Dist: blosc2; extra == "client"
|
|
230
237
|
Dynamic: license-file
|
|
231
238
|
|
|
232
239
|
<img src="imgs/nnInteractive_header_white.png">
|
|
@@ -417,14 +424,42 @@ session.add_point_interaction(POINT_COORDINATES, include_interaction=False)
|
|
|
417
424
|
session.add_bbox_interaction(BBOX_COORDINATES, include_interaction=True)
|
|
418
425
|
|
|
419
426
|
# Example: Add a scribble interaction
|
|
420
|
-
# - A 3D image of the same shape as img where one slice (any axis-aligned orientation) contains a hand-drawn scribble.
|
|
421
427
|
# - Background must be 0, and scribble must be 1.
|
|
422
428
|
# - Use session.preferred_scribble_thickness for optimal results.
|
|
423
|
-
|
|
429
|
+
#
|
|
430
|
+
# ✅ RECOMMENDED (v2): pass a small 2D crop plus its location.
|
|
431
|
+
# Scribbles live on a single axis-aligned slice, so one of the three bbox
|
|
432
|
+
# dimensions is always size 1 and the in-plane extent typically covers only
|
|
433
|
+
# a small region. The cropped array is ORDERS OF MAGNITUDE
|
|
434
|
+
# smaller than a full-volume mask for typical annotations, which makes this
|
|
435
|
+
# path dramatically faster. Please prefer this
|
|
436
|
+
# form in new integrations.
|
|
437
|
+
#
|
|
438
|
+
# SCRIBBLE_CROP.shape must equal the bbox size. INTERACTION_BBOX uses
|
|
439
|
+
# half-open intervals [[x1,x2],[y1,y2],[z1,z2]] in original-image coordinates.
|
|
440
|
+
# Example: a scribble drawn on axial slice z=64, covering x∈[100,140), y∈[80,150):
|
|
441
|
+
# SCRIBBLE_CROP = <ndarray of shape (40, 70, 1), values 0 or 1>
|
|
442
|
+
# INTERACTION_BBOX = [[100, 140], [80, 150], [64, 65]]
|
|
443
|
+
session.add_scribble_interaction(
|
|
444
|
+
SCRIBBLE_CROP,
|
|
445
|
+
include_interaction=True,
|
|
446
|
+
interaction_bbox=INTERACTION_BBOX,
|
|
447
|
+
)
|
|
448
|
+
|
|
449
|
+
# Legacy form (still supported, but discouraged): a 3D array matching the
|
|
450
|
+
# full original image shape with the scribble baked into one slice.
|
|
451
|
+
# session.add_scribble_interaction(SCRIBBLE_IMAGE, include_interaction=True)
|
|
424
452
|
|
|
425
453
|
# Example: Add a lasso interaction
|
|
426
|
-
# -
|
|
427
|
-
|
|
454
|
+
# - Like scribble but the single slice contains a **closed contour** for the selection.
|
|
455
|
+
# - Same recommendation applies: pass a 2D crop + interaction_bbox for a large speedup.
|
|
456
|
+
session.add_lasso_interaction(
|
|
457
|
+
LASSO_CROP,
|
|
458
|
+
include_interaction=True,
|
|
459
|
+
interaction_bbox=INTERACTION_BBOX,
|
|
460
|
+
)
|
|
461
|
+
# Legacy full-volume form (discouraged):
|
|
462
|
+
# session.add_lasso_interaction(LASSO_IMAGE, include_interaction=True)
|
|
428
463
|
|
|
429
464
|
# You can combine any number of interactions as needed.
|
|
430
465
|
# The model refines the segmentation result incrementally with each new interaction.
|
|
@@ -452,6 +487,41 @@ session.set_target_buffer(torch.zeros(NEW_IMAGE.shape[1:], dtype=torch.uint8))
|
|
|
452
487
|
# Enjoy!
|
|
453
488
|
```
|
|
454
489
|
|
|
490
|
+
## Running inference on a remote GPU (client / server)
|
|
491
|
+
|
|
492
|
+
If the machine running your GUI does not have a powerful GPU, you can run the
|
|
493
|
+
model on a remote box and drive it over HTTP with
|
|
494
|
+
**`nnInteractiveRemoteInferenceSession`** — a drop-in replacement with the same
|
|
495
|
+
public API as the local session. The server loads the model once at startup and
|
|
496
|
+
hosts multiple concurrent client sessions; each client keeps its own image,
|
|
497
|
+
target buffer, and interaction state.
|
|
498
|
+
|
|
499
|
+
Start the server on the GPU box:
|
|
500
|
+
|
|
501
|
+
```bash
|
|
502
|
+
nninteractive-server \
|
|
503
|
+
--model-dir /path/to/checkpoint_folder --fold all \
|
|
504
|
+
--host 0.0.0.0 --port 1527 \
|
|
505
|
+
--api-key "$(openssl rand -hex 32)"
|
|
506
|
+
```
|
|
507
|
+
|
|
508
|
+
And in the client code, swap the local session for the remote one:
|
|
509
|
+
|
|
510
|
+
```python
|
|
511
|
+
from nnInteractive.inference.remote import nnInteractiveRemoteInferenceSession
|
|
512
|
+
|
|
513
|
+
session = nnInteractiveRemoteInferenceSession(
|
|
514
|
+
server_url="http://gpu-box.lab:1527",
|
|
515
|
+
api_key="…",
|
|
516
|
+
)
|
|
517
|
+
# From here on, the API is identical to nnInteractiveInferenceSession.
|
|
518
|
+
```
|
|
519
|
+
|
|
520
|
+
For full details — installation, authentication, single-user SSH-tunnel setup,
|
|
521
|
+
multi-user deployment behind a reverse proxy, concurrency/session model, idle
|
|
522
|
+
expiry and heartbeats, GUI integration notes, and troubleshooting — see
|
|
523
|
+
[`SERVER_CLIENT.md`](SERVER_CLIENT.md).
|
|
524
|
+
|
|
455
525
|
## nnInteractive SuperVoxels
|
|
456
526
|
|
|
457
527
|
As part of the `nnInteractive` framework, we provide a dedicated module for **supervoxel generation** based on [SAM](https://github.com/facebookresearch/segment-anything) and [SAM2](https://github.com/facebookresearch/sam2). This replaces traditional superpixel methods (e.g., SLIC) with **foundation model–powered 3D pseudo-labels**.
|
|
@@ -50,14 +50,16 @@ class nnInteractiveInferenceSession:
|
|
|
50
50
|
):
|
|
51
51
|
"""
|
|
52
52
|
Only intended to work with nnInteractiveTrainerV2 and its derivatives
|
|
53
|
+
|
|
54
|
+
``use_torch_compile``: compile the network with ``torch.compile``. The
|
|
55
|
+
first prediction after enabling this is slow (compilation happens lazily
|
|
56
|
+
on the first forward pass), but every subsequent prediction is faster.
|
|
57
|
+
This is recommended for the persistent inference server, where the
|
|
58
|
+
process is long-lived so the one-time compile cost is paid only once and
|
|
59
|
+
amortized across the whole session lifetime.
|
|
53
60
|
"""
|
|
54
61
|
print("session initialized")
|
|
55
62
|
|
|
56
|
-
# set as part of initialization
|
|
57
|
-
assert use_torch_compile is False, (
|
|
58
|
-
"torch.compile is not supported. The blosc2-backed interaction tensor "
|
|
59
|
-
"requires numpy↔torch round-trips that break compile tracing."
|
|
60
|
-
)
|
|
61
63
|
self.network = None
|
|
62
64
|
self.label_manager = None
|
|
63
65
|
self.dataset_json = None
|
|
@@ -83,6 +85,10 @@ class nnInteractiveInferenceSession:
|
|
|
83
85
|
self.preprocessed_image: torch.Tensor = None
|
|
84
86
|
self.preprocessed_props = None
|
|
85
87
|
self.target_buffer: Union[np.ndarray, torch.Tensor] = None
|
|
88
|
+
# Bbox (in original-image coordinates) of the most recent target_buffer write.
|
|
89
|
+
# Captured inside _paste_prediction_to_target_buffer so remote callers can
|
|
90
|
+
# fetch just the touched region without diffing.
|
|
91
|
+
self._last_paste_bbox: Optional[List[List[int]]] = None
|
|
86
92
|
|
|
87
93
|
# this will be set when loading the model (initialize_from_trained_model_folder)
|
|
88
94
|
self.pad_mode_data = self.preferred_scribble_thickness = self.point_interaction = None
|
|
@@ -287,6 +293,7 @@ class nnInteractiveInferenceSession:
|
|
|
287
293
|
else:
|
|
288
294
|
pred_for_target = prediction.to("cpu")
|
|
289
295
|
paste_tensor(self.target_buffer, pred_for_target, target_bbox)
|
|
296
|
+
self._last_paste_bbox = target_bbox
|
|
290
297
|
|
|
291
298
|
def _estimate_refinement_cache_nbytes(self, cache_bbox: List[List[int]]) -> int:
|
|
292
299
|
cache_voxels = int(np.prod(self._bbox_size(cache_bbox), dtype=np.int64))
|
|
@@ -517,6 +524,7 @@ class nnInteractiveInferenceSession:
|
|
|
517
524
|
self.current_interaction_intensity = 1.0
|
|
518
525
|
empty_cache(self.device)
|
|
519
526
|
self.original_image_shape = None
|
|
527
|
+
self._last_paste_bbox = None
|
|
520
528
|
|
|
521
529
|
def _initialize_interactions(self, image_torch: torch.Tensor):
|
|
522
530
|
shape = (self.num_interaction_channels, *image_torch.shape[1:])
|
|
@@ -606,6 +614,7 @@ class nnInteractiveInferenceSession:
|
|
|
606
614
|
self.target_buffer.fill(0)
|
|
607
615
|
elif isinstance(self.target_buffer, torch.Tensor):
|
|
608
616
|
self.target_buffer.zero_()
|
|
617
|
+
self._last_paste_bbox = None
|
|
609
618
|
empty_cache(self.device)
|
|
610
619
|
|
|
611
620
|
def add_bbox_interaction(
|
|
@@ -876,6 +885,42 @@ class nnInteractiveInferenceSession:
|
|
|
876
885
|
else:
|
|
877
886
|
del initial_seg
|
|
878
887
|
|
|
888
|
+
@torch.inference_mode()
|
|
889
|
+
def warmup(self) -> bool:
|
|
890
|
+
"""Run a single dummy forward pass to trigger lazy ``torch.compile`` compilation up front.
|
|
891
|
+
|
|
892
|
+
With ``torch.compile`` enabled the network is compiled lazily on its first
|
|
893
|
+
forward pass, which would otherwise make the user's *first* real prediction
|
|
894
|
+
slow. Every prediction path — the initial coarse pass, the zoom-out
|
|
895
|
+
iterations, and the refinement patches — feeds the network an input of
|
|
896
|
+
identical shape ``[1, num_input_channels + num_interaction_channels,
|
|
897
|
+
*patch_size]`` (``_build_network_input`` always resizes the crop to
|
|
898
|
+
``patch_size``, and refinement crops at exactly ``patch_size``). So a single
|
|
899
|
+
dummy pass at that shape populates the compile cache and every subsequent
|
|
900
|
+
real prediction is fast.
|
|
901
|
+
|
|
902
|
+
Returns ``True`` if a warmup pass was run, ``False`` if it was a no-op
|
|
903
|
+
(network not compiled — there is nothing to pre-compile, so a dummy pass
|
|
904
|
+
would not save the user any time). Mirrors ``_predict``'s autocast/
|
|
905
|
+
inference-mode context and the float32 input dtype that ``torch.cat``
|
|
906
|
+
produces when concatenating the float32 image with the fp16 interactions.
|
|
907
|
+
"""
|
|
908
|
+
if self.network is None or self.configuration_manager is None:
|
|
909
|
+
raise RuntimeError("warmup() requires an initialized network; call initialize_* first")
|
|
910
|
+
if not isinstance(self.network, OptimizedModule):
|
|
911
|
+
return False
|
|
912
|
+
num_input_channels = (
|
|
913
|
+
determine_num_input_channels(self.plans_manager, self.configuration_manager, self.dataset_json)
|
|
914
|
+
+ self.num_interaction_channels
|
|
915
|
+
)
|
|
916
|
+
patch_size = self.configuration_manager.patch_size
|
|
917
|
+
dummy = torch.zeros((1, num_input_channels, *patch_size), dtype=torch.float32, device=self.device)
|
|
918
|
+
with torch.autocast(self.device.type, enabled=True) if self.device.type == "cuda" else dummy_context():
|
|
919
|
+
self.network(dummy)
|
|
920
|
+
del dummy
|
|
921
|
+
empty_cache(self.device)
|
|
922
|
+
return True
|
|
923
|
+
|
|
879
924
|
@torch.inference_mode()
|
|
880
925
|
def _predict(self, force_full_refine: bool = False):
|
|
881
926
|
"""
|
|
@@ -1287,6 +1332,36 @@ class nnInteractiveInferenceSession:
|
|
|
1287
1332
|
"""
|
|
1288
1333
|
This is used when making predictions with a trained model
|
|
1289
1334
|
"""
|
|
1335
|
+
artifacts = self._load_model_artifacts_from_disk(model_training_output_dir, use_fold, checkpoint_name)
|
|
1336
|
+
self.initialize_from_loaded_artifacts(artifacts)
|
|
1337
|
+
|
|
1338
|
+
def _load_model_artifacts_from_disk(
|
|
1339
|
+
self,
|
|
1340
|
+
model_training_output_dir: str,
|
|
1341
|
+
use_fold: Union[int, str] = None,
|
|
1342
|
+
checkpoint_name: str = "checkpoint_final.pth",
|
|
1343
|
+
) -> dict:
|
|
1344
|
+
"""Read all model artifacts from disk and build the network on ``self.device``.
|
|
1345
|
+
|
|
1346
|
+
Returns an artifact dict that can be applied to this or any other freshly
|
|
1347
|
+
constructed session via :meth:`initialize_from_loaded_artifacts`. The
|
|
1348
|
+
returned values are the actual objects (the ``nn.Module`` with its
|
|
1349
|
+
weights and buffers, the plans/configuration managers, the dataset
|
|
1350
|
+
json, the label manager) — not copies. Multiple sessions calling
|
|
1351
|
+
:meth:`initialize_from_loaded_artifacts` with the same dict will all
|
|
1352
|
+
end up with ``self.network`` pointing at the same module instance and
|
|
1353
|
+
the same weight tensors on the GPU. This is safe as long as callers
|
|
1354
|
+
treat these objects as read-only after construction; in the multi-
|
|
1355
|
+
session server that is enforced by running inference under
|
|
1356
|
+
``@torch.inference_mode()`` and serializing predict calls with a
|
|
1357
|
+
global GPU lock.
|
|
1358
|
+
|
|
1359
|
+
Note: this also mutates ``self`` (applies capability, sets pad/decay/
|
|
1360
|
+
thickness) because ``num_interaction_channels`` is required to build the
|
|
1361
|
+
network. The caller should follow up with
|
|
1362
|
+
:meth:`initialize_from_loaded_artifacts` (this is what
|
|
1363
|
+
:meth:`initialize_from_trained_model_folder` does).
|
|
1364
|
+
"""
|
|
1290
1365
|
point_interaction_use_etd = True
|
|
1291
1366
|
(
|
|
1292
1367
|
capability_content,
|
|
@@ -1353,12 +1428,41 @@ class nnInteractiveInferenceSession:
|
|
|
1353
1428
|
).to(self.device)
|
|
1354
1429
|
network.load_state_dict(parameters)
|
|
1355
1430
|
|
|
1356
|
-
|
|
1357
|
-
|
|
1358
|
-
|
|
1359
|
-
|
|
1360
|
-
|
|
1361
|
-
|
|
1431
|
+
return {
|
|
1432
|
+
"capability_content": capability_content,
|
|
1433
|
+
"point_interaction": self.point_interaction,
|
|
1434
|
+
"preferred_scribble_thickness": self.preferred_scribble_thickness,
|
|
1435
|
+
"interaction_decay": self.interaction_decay,
|
|
1436
|
+
"pad_mode_data": self.pad_mode_data,
|
|
1437
|
+
"network": network,
|
|
1438
|
+
"plans_manager": plans_manager,
|
|
1439
|
+
"configuration_manager": configuration_manager,
|
|
1440
|
+
"dataset_json": dataset_json,
|
|
1441
|
+
"trainer_name": trainer_name,
|
|
1442
|
+
"label_manager": plans_manager.get_label_manager(dataset_json),
|
|
1443
|
+
}
|
|
1444
|
+
|
|
1445
|
+
def initialize_from_loaded_artifacts(self, artifacts: dict):
|
|
1446
|
+
"""Apply pre-loaded artifacts to this session instance.
|
|
1447
|
+
|
|
1448
|
+
``artifacts`` is the dict returned by :meth:`_load_model_artifacts_from_disk`.
|
|
1449
|
+
Useful for spawning multiple sessions that share one loaded model (e.g.
|
|
1450
|
+
the multi-session inference server). All artifact entries — including
|
|
1451
|
+
``self.network`` — are stored by reference; passing the same dict to
|
|
1452
|
+
multiple sessions does not duplicate the network or its weights in
|
|
1453
|
+
memory.
|
|
1454
|
+
"""
|
|
1455
|
+
self.preferred_scribble_thickness = artifacts["preferred_scribble_thickness"]
|
|
1456
|
+
self.interaction_decay = artifacts["interaction_decay"]
|
|
1457
|
+
self.pad_mode_data = artifacts["pad_mode_data"]
|
|
1458
|
+
self.point_interaction = artifacts["point_interaction"]
|
|
1459
|
+
self._apply_capability(artifacts["capability_content"])
|
|
1460
|
+
self.plans_manager = artifacts["plans_manager"]
|
|
1461
|
+
self.configuration_manager = artifacts["configuration_manager"]
|
|
1462
|
+
self.network = artifacts["network"]
|
|
1463
|
+
self.dataset_json = artifacts["dataset_json"]
|
|
1464
|
+
self.trainer_name = artifacts["trainer_name"]
|
|
1465
|
+
self.label_manager = artifacts["label_manager"]
|
|
1362
1466
|
if self.use_torch_compile and not isinstance(self.network, OptimizedModule):
|
|
1363
1467
|
print("Using torch.compile")
|
|
1364
1468
|
self.network = torch.compile(self.network)
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
from nnInteractive.inference.remote.remote_session import (
|
|
2
|
+
ServerAtCapacityError,
|
|
3
|
+
SessionExpiredError,
|
|
4
|
+
nnInteractiveRemoteInferenceSession,
|
|
5
|
+
)
|
|
6
|
+
|
|
7
|
+
__all__ = [
|
|
8
|
+
"nnInteractiveRemoteInferenceSession",
|
|
9
|
+
"SessionExpiredError",
|
|
10
|
+
"ServerAtCapacityError",
|
|
11
|
+
]
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
"""Shared constants for the nnInteractive client/server HTTP protocol."""
|
|
2
|
+
|
|
3
|
+
# HTTP header used to carry JSON-encoded metadata alongside a binary array body.
|
|
4
|
+
META_HEADER = "X-Meta"
|
|
5
|
+
# HTTP header used to carry a per-client lease token identifying which session
|
|
6
|
+
# on the (multi-session) server the request applies to.
|
|
7
|
+
LEASE_HEADER = "X-Lease-Token"
|
|
8
|
+
|
|
9
|
+
# Endpoint paths.
|
|
10
|
+
PATH_HEALTHZ = "/healthz"
|
|
11
|
+
PATH_CAPABILITIES = "/capabilities"
|
|
12
|
+
PATH_CLAIM = "/claim"
|
|
13
|
+
PATH_RELEASE = "/release"
|
|
14
|
+
PATH_HEARTBEAT = "/heartbeat"
|
|
15
|
+
PATH_LEASE_STATUS = "/lease_status"
|
|
16
|
+
PATH_SET_IMAGE = "/set_image"
|
|
17
|
+
PATH_SET_TARGET_BUFFER = "/set_target_buffer"
|
|
18
|
+
PATH_RESET_INTERACTIONS = "/reset_interactions"
|
|
19
|
+
PATH_SET_DO_AUTOZOOM = "/set_do_autozoom"
|
|
20
|
+
PATH_ADD_BBOX = "/add_bbox_interaction"
|
|
21
|
+
PATH_ADD_POINT = "/add_point_interaction"
|
|
22
|
+
PATH_ADD_SCRIBBLE = "/add_scribble_interaction"
|
|
23
|
+
PATH_ADD_LASSO = "/add_lasso_interaction"
|
|
24
|
+
PATH_ADD_INITIAL_SEG = "/add_initial_seg_interaction"
|
|
25
|
+
|
|
26
|
+
# Body content type for endpoints that ship a packed numpy array.
|
|
27
|
+
CONTENT_TYPE_OCTET_STREAM = "application/octet-stream"
|