nnInteractive 2.1.0__tar.gz → 2.2.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (91)
  1. {nninteractive-2.1.0 → nninteractive-2.2.0}/PKG-INFO +5 -1
  2. {nninteractive-2.1.0 → nninteractive-2.2.0}/nnInteractive/inference/inference_session.py +57 -9
  3. {nninteractive-2.1.0 → nninteractive-2.2.0}/nnInteractive/inference/remote/remote_session.py +63 -5
  4. {nninteractive-2.1.0 → nninteractive-2.2.0}/nnInteractive/inference/server/app.py +105 -30
  5. nninteractive-2.2.0/nnInteractive/inference/server/main.py +233 -0
  6. {nninteractive-2.1.0 → nninteractive-2.2.0}/nnInteractive.egg-info/PKG-INFO +5 -1
  7. {nninteractive-2.1.0 → nninteractive-2.2.0}/nnInteractive.egg-info/SOURCES.txt +0 -0
  8. {nninteractive-2.1.0 → nninteractive-2.2.0}/nnInteractive.egg-info/dependency_links.txt +0 -0
  9. {nninteractive-2.1.0 → nninteractive-2.2.0}/nnInteractive.egg-info/requires.txt +5 -0
  10. {nninteractive-2.1.0 → nninteractive-2.2.0}/nnInteractive.egg-info/top_level.txt +0 -0
  11. {nninteractive-2.1.0 → nninteractive-2.2.0}/pyproject.toml +10 -1
  12. nninteractive-2.1.0/nnInteractive/inference/server/main.py +0 -149
  13. {nninteractive-2.1.0 → nninteractive-2.2.0}/LICENSE +0 -0
  14. {nninteractive-2.1.0 → nninteractive-2.2.0}/nnInteractive/__init__.py +0 -0
  15. {nninteractive-2.1.0 → nninteractive-2.2.0}/nnInteractive/inference/__init__.py +0 -0
  16. {nninteractive-2.1.0 → nninteractive-2.2.0}/nnInteractive/inference/cvpr2025_challenge_baseline/__init__.py +0 -0
  17. {nninteractive-2.1.0 → nninteractive-2.2.0}/nnInteractive/inference/cvpr2025_challenge_baseline/predict.py +0 -0
  18. {nninteractive-2.1.0 → nninteractive-2.2.0}/nnInteractive/inference/remote/__init__.py +0 -0
  19. {nninteractive-2.1.0 → nninteractive-2.2.0}/nnInteractive/inference/remote/_protocol.py +0 -0
  20. {nninteractive-2.1.0 → nninteractive-2.2.0}/nnInteractive/inference/remote/serialization.py +0 -0
  21. {nninteractive-2.1.0 → nninteractive-2.2.0}/nnInteractive/inference/server/__init__.py +0 -0
  22. {nninteractive-2.1.0 → nninteractive-2.2.0}/nnInteractive/interaction/__init__.py +0 -0
  23. {nninteractive-2.1.0 → nninteractive-2.2.0}/nnInteractive/interaction/point.py +0 -0
  24. {nninteractive-2.1.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/setup.py +0 -0
  25. {nninteractive-2.1.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/metadata.py +0 -0
  26. {nninteractive-2.1.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/reader.py +0 -0
  27. {nninteractive-2.1.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/run.py +0 -0
  28. {nninteractive-2.1.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/__init__.py +0 -0
  29. {nninteractive-2.1.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/sam2/__init__.py +0 -0
  30. {nninteractive-2.1.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/sam2/automatic_mask_generator.py +0 -0
  31. {nninteractive-2.1.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/sam2/benchmark.py +0 -0
  32. {nninteractive-2.1.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/sam2/build_sam.py +0 -0
  33. {nninteractive-2.1.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/sam2/modeling/__init__.py +0 -0
  34. {nninteractive-2.1.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/__init__.py +0 -0
  35. {nninteractive-2.1.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/hieradet.py +0 -0
  36. {nninteractive-2.1.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/image_encoder.py +0 -0
  37. {nninteractive-2.1.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/utils.py +0 -0
  38. {nninteractive-2.1.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/sam2/modeling/memory_attention.py +0 -0
  39. {nninteractive-2.1.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/sam2/modeling/memory_encoder.py +0 -0
  40. {nninteractive-2.1.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/sam2/modeling/position_encoding.py +0 -0
  41. {nninteractive-2.1.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/__init__.py +0 -0
  42. {nninteractive-2.1.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/mask_decoder.py +0 -0
  43. {nninteractive-2.1.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/prompt_encoder.py +0 -0
  44. {nninteractive-2.1.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/transformer.py +0 -0
  45. {nninteractive-2.1.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/sam2/modeling/sam2_base.py +0 -0
  46. {nninteractive-2.1.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/sam2/modeling/sam2_utils.py +0 -0
  47. {nninteractive-2.1.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/sam2/sam2_image_predictor.py +0 -0
  48. {nninteractive-2.1.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/sam2/sam2_video_predictor.py +0 -0
  49. {nninteractive-2.1.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/sam2/sam2_video_predictor_legacy.py +0 -0
  50. {nninteractive-2.1.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/sam2/utils/__init__.py +0 -0
  51. {nninteractive-2.1.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/sam2/utils/amg.py +0 -0
  52. {nninteractive-2.1.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/sam2/utils/misc.py +0 -0
  53. {nninteractive-2.1.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/sam2/utils/transforms.py +0 -0
  54. {nninteractive-2.1.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/setup.py +0 -0
  55. {nninteractive-2.1.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/training/__init__.py +0 -0
  56. {nninteractive-2.1.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/training/dataset/__init__.py +0 -0
  57. {nninteractive-2.1.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/training/dataset/sam2_datasets.py +0 -0
  58. {nninteractive-2.1.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/training/dataset/transforms.py +0 -0
  59. {nninteractive-2.1.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/training/dataset/utils.py +0 -0
  60. {nninteractive-2.1.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/training/dataset/vos_dataset.py +0 -0
  61. {nninteractive-2.1.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/training/dataset/vos_raw_dataset.py +0 -0
  62. {nninteractive-2.1.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/training/dataset/vos_sampler.py +0 -0
  63. {nninteractive-2.1.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/training/dataset/vos_segment_loader.py +0 -0
  64. {nninteractive-2.1.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/training/loss_fns.py +0 -0
  65. {nninteractive-2.1.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/training/model/__init__.py +0 -0
  66. {nninteractive-2.1.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/training/model/sam2.py +0 -0
  67. {nninteractive-2.1.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/training/optimizer.py +0 -0
  68. {nninteractive-2.1.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/training/scripts/sav_frame_extraction_submitit.py +0 -0
  69. {nninteractive-2.1.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/training/train.py +0 -0
  70. {nninteractive-2.1.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/training/trainer.py +0 -0
  71. {nninteractive-2.1.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/training/utils/__init__.py +0 -0
  72. {nninteractive-2.1.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/training/utils/checkpoint_utils.py +0 -0
  73. {nninteractive-2.1.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/training/utils/data_utils.py +0 -0
  74. {nninteractive-2.1.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/training/utils/distributed.py +0 -0
  75. {nninteractive-2.1.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/training/utils/logger.py +0 -0
  76. {nninteractive-2.1.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/sam2/training/utils/train_utils.py +0 -0
  77. {nninteractive-2.1.0 → nninteractive-2.2.0}/nnInteractive/supervoxel/src/supervoxel.py +0 -0
  78. {nninteractive-2.1.0 → nninteractive-2.2.0}/nnInteractive/trainer/__init__.py +0 -0
  79. {nninteractive-2.1.0 → nninteractive-2.2.0}/nnInteractive/trainer/nnInteractiveTrainer.py +0 -0
  80. {nninteractive-2.1.0 → nninteractive-2.2.0}/nnInteractive/utils/__init__.py +0 -0
  81. {nninteractive-2.1.0 → nninteractive-2.2.0}/nnInteractive/utils/bboxes.py +0 -0
  82. {nninteractive-2.1.0 → nninteractive-2.2.0}/nnInteractive/utils/checkpoint_cleansing.py +0 -0
  83. {nninteractive-2.1.0 → nninteractive-2.2.0}/nnInteractive/utils/crop.py +0 -0
  84. {nninteractive-2.1.0 → nninteractive-2.2.0}/nnInteractive/utils/erosion_dilation.py +0 -0
  85. {nninteractive-2.1.0 → nninteractive-2.2.0}/nnInteractive/utils/inference_helpers.py +0 -0
  86. {nninteractive-2.1.0 → nninteractive-2.2.0}/nnInteractive/utils/os_shennanigans.py +0 -0
  87. {nninteractive-2.1.0 → nninteractive-2.2.0}/nnInteractive/utils/rounding.py +0 -0
  88. {nninteractive-2.1.0 → nninteractive-2.2.0}/nnInteractive.egg-info/entry_points.txt +0 -0
  89. {nninteractive-2.1.0 → nninteractive-2.2.0}/readme.md +0 -0
  90. {nninteractive-2.1.0 → nninteractive-2.2.0}/setup.cfg +0 -0
  91. {nninteractive-2.1.0 → nninteractive-2.2.0}/setup.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: nnInteractive
3
- Version: 2.1.0
3
+ Version: 2.2.0
4
4
  Summary: Inference code for nnInteractive
5
5
  Author: Helmholtz Imaging Applied Computer Vision Lab
6
6
  Author-email: Fabian Isensee <f.isensee@dkfz-heidelberg.de>
@@ -226,10 +226,14 @@ Requires-Dist: batchgenerators>=0.25.1
226
226
  Requires-Dist: fastapi>=0.110
227
227
  Requires-Dist: uvicorn[standard]>=0.27
228
228
  Requires-Dist: httpx>=0.27
229
+ Requires-Dist: blosc2
229
230
  Provides-Extra: dev
230
231
  Requires-Dist: black; extra == "dev"
231
232
  Requires-Dist: ruff; extra == "dev"
232
233
  Requires-Dist: pre-commit; extra == "dev"
234
+ Provides-Extra: client
235
+ Requires-Dist: httpx>=0.27; extra == "client"
236
+ Requires-Dist: blosc2; extra == "client"
233
237
  Dynamic: license-file
234
238
 
235
239
  <img src="imgs/nnInteractive_header_white.png">
@@ -50,14 +50,16 @@ class nnInteractiveInferenceSession:
50
50
  ):
51
51
  """
52
52
  Only intended to work with nnInteractiveTrainerV2 and its derivatives
53
+
54
+ ``use_torch_compile``: compile the network with ``torch.compile``. The
55
+ first prediction after enabling this is slow (compilation happens lazily
56
+ on the first forward pass), but every subsequent prediction is faster.
57
+ This is recommended for the persistent inference server, where the
58
+ process is long-lived so the one-time compile cost is paid only once and
59
+ amortized across the whole session lifetime.
53
60
  """
54
61
  print("session initialized")
55
62
 
56
- # set as part of initialization
57
- assert use_torch_compile is False, (
58
- "torch.compile is not supported. The blosc2-backed interaction tensor "
59
- "requires numpy↔torch round-trips that break compile tracing."
60
- )
61
63
  self.network = None
62
64
  self.label_manager = None
63
65
  self.dataset_json = None
@@ -883,6 +885,42 @@ class nnInteractiveInferenceSession:
883
885
  else:
884
886
  del initial_seg
885
887
 
888
+ @torch.inference_mode()
889
+ def warmup(self) -> bool:
890
+ """Run a single dummy forward pass to trigger lazy ``torch.compile`` compilation up front.
891
+
892
+ With ``torch.compile`` enabled the network is compiled lazily on its first
893
+ forward pass, which would otherwise make the user's *first* real prediction
894
+ slow. Every prediction path — the initial coarse pass, the zoom-out
895
+ iterations, and the refinement patches — feeds the network an input of
896
+ identical shape ``[1, num_input_channels + num_interaction_channels,
897
+ *patch_size]`` (``_build_network_input`` always resizes the crop to
898
+ ``patch_size``, and refinement crops at exactly ``patch_size``). So a single
899
+ dummy pass at that shape populates the compile cache and every subsequent
900
+ real prediction is fast.
901
+
902
+ Returns ``True`` if a warmup pass was run, ``False`` if it was a no-op
903
+ (network not compiled — there is nothing to pre-compile, so a dummy pass
904
+ would not save the user any time). Mirrors ``_predict``'s autocast/
905
+ inference-mode context and the float32 input dtype that ``torch.cat``
906
+ produces when concatenating the float32 image with the fp16 interactions.
907
+ """
908
+ if self.network is None or self.configuration_manager is None:
909
+ raise RuntimeError("warmup() requires an initialized network; call initialize_* first")
910
+ if not isinstance(self.network, OptimizedModule):
911
+ return False
912
+ num_input_channels = (
913
+ determine_num_input_channels(self.plans_manager, self.configuration_manager, self.dataset_json)
914
+ + self.num_interaction_channels
915
+ )
916
+ patch_size = self.configuration_manager.patch_size
917
+ dummy = torch.zeros((1, num_input_channels, *patch_size), dtype=torch.float32, device=self.device)
918
+ with torch.autocast(self.device.type, enabled=True) if self.device.type == "cuda" else dummy_context():
919
+ self.network(dummy)
920
+ del dummy
921
+ empty_cache(self.device)
922
+ return True
923
+
886
924
  @torch.inference_mode()
887
925
  def _predict(self, force_full_refine: bool = False):
888
926
  """
@@ -1307,9 +1345,16 @@ class nnInteractiveInferenceSession:
1307
1345
 
1308
1346
  Returns an artifact dict that can be applied to this or any other freshly
1309
1347
  constructed session via :meth:`initialize_from_loaded_artifacts`. The
1310
- returned references (network, plans_manager, configuration_manager, ...)
1311
- are safe to share by reference across sessions — they are treated as
1312
- read-only after construction.
1348
+ returned values are the actual objects (the ``nn.Module`` with its
1349
+ weights and buffers, the plans/configuration managers, the dataset
1350
+ json, the label manager) — not copies. Multiple sessions calling
1351
+ :meth:`initialize_from_loaded_artifacts` with the same dict will all
1352
+ end up with ``self.network`` pointing at the same module instance and
1353
+ the same weight tensors on the GPU. This is safe as long as callers
1354
+ treat these objects as read-only after construction; in the multi-
1355
+ session server that is enforced by running inference under
1356
+ ``@torch.inference_mode()`` and serializing predict calls with a
1357
+ global GPU lock.
1313
1358
 
1314
1359
  Note: this also mutates ``self`` (applies capability, sets pad/decay/
1315
1360
  thickness) because ``num_interaction_channels`` is required to build the
@@ -1402,7 +1447,10 @@ class nnInteractiveInferenceSession:
1402
1447
 
1403
1448
  ``artifacts`` is the dict returned by :meth:`_load_model_artifacts_from_disk`.
1404
1449
  Useful for spawning multiple sessions that share one loaded model (e.g.
1405
- the multi-session inference server).
1450
+ the multi-session inference server). All artifact entries — including
1451
+ ``self.network`` — are stored by reference; passing the same dict to
1452
+ multiple sessions does not duplicate the network or its weights in
1453
+ memory.
1406
1454
  """
1407
1455
  self.preferred_scribble_thickness = artifacts["preferred_scribble_thickness"]
1408
1456
  self.interaction_decay = artifacts["interaction_decay"]
@@ -10,6 +10,7 @@ from __future__ import annotations
10
10
 
11
11
  import json
12
12
  import os
13
+ import threading
13
14
  import warnings
14
15
  from typing import List, Optional, Tuple, Union
15
16
 
@@ -165,8 +166,16 @@ class nnInteractiveRemoteInferenceSession:
165
166
  claim_info = claim_resp.json()
166
167
  self._lease_token = claim_info["lease_token"]
167
168
  self.idle_timeout_seconds: float = float(claim_info.get("idle_timeout_seconds", 0.0))
169
+ self.liveness_timeout_seconds: float = float(claim_info.get("liveness_timeout_seconds", 0.0))
168
170
  self._http.headers[LEASE_HEADER] = self._lease_token
169
171
 
172
+ # Background liveness heartbeat bookkeeping. Defined before any code that
173
+ # might raise so close()/__del__ can always reference them safely. The
174
+ # thread itself is started at the end of __init__, once construction has
175
+ # fully succeeded.
176
+ self._stop_heartbeat = threading.Event()
177
+ self._heartbeat_thread: Optional[threading.Thread] = None
178
+
170
179
  caps = self._get_json(PATH_CAPABILITIES)
171
180
 
172
181
  # Attributes that mirror the local session so the GUI can introspect them
@@ -187,6 +196,17 @@ class nnInteractiveRemoteInferenceSession:
187
196
  self.target_buffer: Union[np.ndarray, torch.Tensor, None] = None
188
197
  self.do_autozoom: bool = bool(caps.get("do_autozoom", True))
189
198
 
199
+ # Construction succeeded — start auto-heartbeating to keep the server
200
+ # from reaping us as a dead client. Beat at half the liveness timeout so
201
+ # one dropped request still leaves margin. Daemon thread: it never blocks
202
+ # interpreter exit, and close() joins it cleanly.
203
+ if self.liveness_timeout_seconds > 0:
204
+ self._heartbeat_interval = max(5.0, self.liveness_timeout_seconds / 2.0)
205
+ self._heartbeat_thread = threading.Thread(
206
+ target=self._heartbeat_loop, name="nnInteractive-heartbeat", daemon=True
207
+ )
208
+ self._heartbeat_thread.start()
209
+
190
210
  # ------------------------------------------------------------------ #
191
211
  # HTTP helpers (private) #
192
212
  # ------------------------------------------------------------------ #
@@ -272,7 +292,10 @@ class nnInteractiveRemoteInferenceSession:
272
292
  self.target_buffer = target_buffer
273
293
  self._post_json(
274
294
  PATH_SET_TARGET_BUFFER,
275
- {"shape": list(target_buffer.shape), "dtype": _buffer_dtype_str(target_buffer)},
295
+ {
296
+ "shape": list(target_buffer.shape),
297
+ "dtype": _buffer_dtype_str(target_buffer),
298
+ },
276
299
  )
277
300
 
278
301
  def set_do_autozoom(self, do_autozoom: bool) -> None:
@@ -420,17 +443,44 @@ class nnInteractiveRemoteInferenceSession:
420
443
  return False
421
444
 
422
445
  def heartbeat(self) -> float:
423
- """Extend this session's idle timeout. Returns remaining seconds.
446
+ """Tell the server this client is still alive. Returns remaining seconds
447
+ until the *idle* timeout.
448
+
449
+ This proves liveness only: it stops the server from reaping the session
450
+ as a crashed/dead client, but it does NOT postpone the idle timeout —
451
+ that is refreshed solely by real user actions (``set_image``,
452
+ ``add_*_interaction``, …). A session left untouched will therefore still
453
+ be reaped at the idle timeout even while heartbeats keep arriving.
424
454
 
425
- For GUIs that keep the app open across long idle stretches: drive
426
- this from a timer (e.g. every 60 s) to avoid SessionExpiredError
427
- after the configured idle timeout.
455
+ You normally never call this yourself: the session auto-heartbeats from
456
+ a background thread for the lifetime of the object.
428
457
  """
429
458
  resp = self._http.post(PATH_HEARTBEAT)
430
459
  _raise_for_lease_errors(resp)
431
460
  resp.raise_for_status()
432
461
  return float(resp.json().get("remaining_seconds", 0.0))
433
462
 
463
+ def _heartbeat_loop(self) -> None:
464
+ """Background daemon: prove liveness every ``_heartbeat_interval`` seconds.
465
+
466
+ Stops when the session is closed or once the lease is gone. Transient
467
+ network errors are swallowed so a brief blip doesn't kill the heartbeat;
468
+ the server's liveness timeout tolerates a few missed beats. Lease expiry
469
+ (idle reap or server restart) is surfaced to the user on their next real
470
+ call, not from this thread.
471
+ """
472
+ while not self._stop_heartbeat.wait(self._heartbeat_interval):
473
+ try:
474
+ self.heartbeat()
475
+ except SessionExpiredError:
476
+ break
477
+ except httpx.HTTPError:
478
+ continue
479
+ except Exception:
480
+ # Never let the daemon thread die noisily (e.g. client closing
481
+ # concurrently). Bail out quietly.
482
+ break
483
+
434
484
  def lease_status(self) -> dict:
435
485
  """Read-only probe: how much time is left before this session is reaped?
436
486
 
@@ -444,6 +494,14 @@ class nnInteractiveRemoteInferenceSession:
444
494
  return resp.json()
445
495
 
446
496
  def close(self) -> None:
497
+ # Stop the heartbeat thread first so it can't use self._http after we
498
+ # close it. join() with a short timeout: the thread spends almost all
499
+ # its time in Event.wait(), which the set() interrupts immediately.
500
+ self._stop_heartbeat.set()
501
+ if self._heartbeat_thread is not None:
502
+ self._heartbeat_thread.join(timeout=5.0)
503
+ self._heartbeat_thread = None
504
+
447
505
  # Best-effort release so the server can free our slot for other users
448
506
  # without waiting for the idle reaper. Swallow errors: the server may
449
507
  # already be gone, our lease may already be expired, etc. close()
@@ -2,15 +2,24 @@
2
2
 
3
3
  The server hosts up to ``max_sessions`` concurrent
4
4
  :class:`nnInteractiveInferenceSession` instances, one per connected client. The
5
- model weights are loaded once at startup and shared by reference across all
6
- sessions — per-session state (image, target buffer, interactions tensor) is
7
- isolated.
5
+ model artifacts (``nn.Module`` network with its weights and buffers, plans/
6
+ configuration managers, dataset json, label manager) are loaded once at startup;
7
+ each session's ``self.network`` is a plain Python reference to that single
8
+ module — there is exactly one network and one copy of the weights on the GPU
9
+ regardless of session count. Per-session state (image, target buffer,
10
+ interactions tensor) is isolated. Safety of sharing relies on (a) inference
11
+ running under ``@torch.inference_mode()`` and (b) a global ``gpu_lock``
12
+ serializing predict-capable endpoints, so no two sessions ever touch the
13
+ module concurrently and nothing mutates it after construction.
8
14
 
9
15
  Each client identifies itself via a lease token issued by ``POST /claim``. The
10
16
  token rides along on every subsequent request in the ``X-Lease-Token`` header.
11
- Sessions idle for longer than ``idle_timeout_seconds`` are reaped automatically;
12
- subsequent requests bearing a reaped lease receive HTTP 410 Gone so the client
13
- can surface a "session expired" message.
17
+ Sessions are reaped automatically for either of two reasons: the user went idle
18
+ (no real interaction for longer than ``idle_timeout_seconds``) or the client
19
+ went dead (no request of any kind — not even a heartbeat for longer than the
20
+ much shorter ``liveness_timeout_seconds``, which reclaims slots held by crashed
21
+ clients quickly). Subsequent requests bearing a reaped lease receive HTTP 410
22
+ Gone so the client can surface a "session expired" message.
14
23
 
15
24
  Concurrency model:
16
25
  - Each session has its own ``threading.Lock`` that serializes the per-session
@@ -72,13 +81,28 @@ class SessionEntry:
72
81
  self.session = session
73
82
  self.lock = threading.Lock()
74
83
  self.created_at = time.monotonic()
84
+ # Two independent clocks. ``last_active_at`` tracks real user actions
85
+ # (drives the idle timeout); ``last_seen_at`` tracks any sign of life
86
+ # including heartbeats (drives the much shorter liveness timeout used to
87
+ # reclaim slots held by crashed/disconnected clients).
75
88
  self.last_active_at = self.created_at
89
+ self.last_seen_at = self.created_at
76
90
 
77
- def touch(self) -> None:
78
- self.last_active_at = time.monotonic()
91
+ def mark_seen(self) -> None:
92
+ """Record that the client is still alive (liveness only)."""
93
+ self.last_seen_at = time.monotonic()
94
+
95
+ def mark_active(self) -> None:
96
+ """Record real user activity. Activity implies liveness, so this bumps
97
+ both clocks."""
98
+ now = time.monotonic()
99
+ self.last_active_at = now
100
+ self.last_seen_at = now
79
101
 
80
102
  def close(self) -> None:
81
- """Free the session's per-instance state. Shared model artifacts are NOT freed.
103
+ """Free the session's per-instance state. The shared network module and
104
+ other model artifacts are NOT freed — they live in the registry and are
105
+ reused by future sessions.
82
106
 
83
107
  Best-effort: any exception here is logged but not re-raised; cleanup must
84
108
  not block reaping or shutdown.
@@ -100,8 +124,10 @@ class SessionFull(Exception):
100
124
  class SessionRegistry:
101
125
  """Threadsafe lease-keyed dict of :class:`SessionEntry`.
102
126
 
103
- The model artifacts loaded at server startup are stashed here and reused
104
- whenever a new session is created.
127
+ The model artifacts loaded at server startup are stashed here and handed to
128
+ every newly created session by reference (the ``nn.Module`` instance, its
129
+ weights, and the plans/configuration/label-manager objects are not copied
130
+ or re-instantiated per session).
105
131
  """
106
132
 
107
133
  def __init__(
@@ -109,17 +135,21 @@ class SessionRegistry:
109
135
  artifacts: dict,
110
136
  max_sessions: int,
111
137
  idle_timeout_seconds: float,
138
+ liveness_timeout_seconds: float,
112
139
  device: torch.device,
113
140
  torch_n_threads: int,
114
141
  do_autozoom: bool,
142
+ use_torch_compile: bool,
115
143
  verbose: bool,
116
144
  ) -> None:
117
145
  self._artifacts = artifacts
118
146
  self._max_sessions = int(max_sessions)
119
147
  self._idle_timeout_seconds = float(idle_timeout_seconds)
148
+ self._liveness_timeout_seconds = float(liveness_timeout_seconds)
120
149
  self._device = device
121
150
  self._torch_n_threads = torch_n_threads
122
151
  self._do_autozoom = do_autozoom
152
+ self._use_torch_compile = use_torch_compile
123
153
  self._verbose = verbose
124
154
  self._entries: dict[str, SessionEntry] = {}
125
155
  self._mu = threading.Lock()
@@ -132,6 +162,10 @@ class SessionRegistry:
132
162
  def idle_timeout_seconds(self) -> float:
133
163
  return self._idle_timeout_seconds
134
164
 
165
+ @property
166
+ def liveness_timeout_seconds(self) -> float:
167
+ return self._liveness_timeout_seconds
168
+
135
169
  def claim(self) -> str:
136
170
  """Create a new session and return its lease token. Raises SessionFull if at cap."""
137
171
  with self._mu:
@@ -140,7 +174,7 @@ class SessionRegistry:
140
174
  token = uuid.uuid4().hex
141
175
  session = nnInteractiveInferenceSession(
142
176
  device=self._device,
143
- use_torch_compile=False,
177
+ use_torch_compile=self._use_torch_compile,
144
178
  verbose=self._verbose,
145
179
  torch_n_threads=self._torch_n_threads,
146
180
  do_autozoom=self._do_autozoom,
@@ -148,18 +182,28 @@ class SessionRegistry:
148
182
  session.initialize_from_loaded_artifacts(self._artifacts)
149
183
  entry = SessionEntry(session)
150
184
  self._entries[token] = entry
151
- logger.info("claimed session %s (%d/%d active)", token, len(self._entries), self._max_sessions)
185
+ logger.info(
186
+ "claimed session %s (%d/%d active)",
187
+ token,
188
+ len(self._entries),
189
+ self._max_sessions,
190
+ )
152
191
  return token
153
192
 
154
193
  def get(self, token: Optional[str]) -> SessionEntry:
155
- """Look up a session by lease token, touching last_active_at on success."""
194
+ """Look up a session by lease token, marking it seen (liveness) on success.
195
+
196
+ Note this only refreshes the liveness clock, not activity: a bare
197
+ ``/heartbeat`` keeps the session from being reaped as dead but does not
198
+ postpone the idle timeout. Endpoints that represent real user actions
199
+ call ``entry.mark_active()`` explicitly (see the lock helpers)."""
156
200
  if not token:
157
201
  raise HTTPException(status.HTTP_410_GONE, detail="lease token missing")
158
202
  with self._mu:
159
203
  entry = self._entries.get(token)
160
204
  if entry is None:
161
205
  raise HTTPException(status.HTTP_410_GONE, detail="lease expired or unknown")
162
- entry.touch()
206
+ entry.mark_seen()
163
207
  return entry
164
208
 
165
209
  def peek(self, token: Optional[str]) -> SessionEntry:
@@ -181,21 +225,37 @@ class SessionRegistry:
181
225
  if entry is None:
182
226
  return False
183
227
  entry.close()
184
- logger.info("released session %s (%d/%d active)", token, len(self._entries), self._max_sessions)
228
+ logger.info(
229
+ "released session %s (%d/%d active)",
230
+ token,
231
+ len(self._entries),
232
+ self._max_sessions,
233
+ )
185
234
  return True
186
235
 
187
236
  def sweep(self) -> int:
188
- """Drop sessions idle for more than idle_timeout_seconds. Returns the number reaped."""
237
+ """Drop sessions that have either gone idle or stopped showing signs of life.
238
+
239
+ A session is reaped if it has seen no real user activity for longer than
240
+ ``idle_timeout_seconds`` (the user walked away) OR if it has not been
241
+ seen at all for longer than ``liveness_timeout_seconds`` (the client
242
+ process crashed/disconnected and stopped heartbeating). Returns the
243
+ number reaped."""
189
244
  now = time.monotonic()
190
- reaped: list[tuple[str, SessionEntry]] = []
245
+ reaped: list[tuple[str, SessionEntry, str]] = []
191
246
  with self._mu:
192
247
  for token, entry in list(self._entries.items()):
193
- if (now - entry.last_active_at) > self._idle_timeout_seconds:
194
- self._entries.pop(token, None)
195
- reaped.append((token, entry))
196
- for token, entry in reaped:
248
+ if (now - entry.last_seen_at) > self._liveness_timeout_seconds:
249
+ reason = "dead"
250
+ elif (now - entry.last_active_at) > self._idle_timeout_seconds:
251
+ reason = "idle"
252
+ else:
253
+ continue
254
+ self._entries.pop(token, None)
255
+ reaped.append((token, entry, reason))
256
+ for token, entry, reason in reaped:
197
257
  entry.close()
198
- logger.info("reaped idle session %s", token)
258
+ logger.info("reaped %s session %s", reason, token)
199
259
  return len(reaped)
200
260
 
201
261
  def close_all(self) -> None:
@@ -215,19 +275,23 @@ def make_app(
215
275
  device: torch.device,
216
276
  max_sessions: int = 1,
217
277
  idle_timeout_seconds: float = 600.0,
278
+ liveness_timeout_seconds: float = 60.0,
218
279
  torch_n_threads: int = 8,
219
280
  do_autozoom: bool = True,
281
+ use_torch_compile: bool = False,
220
282
  verbose: bool = False,
221
283
  api_key: Optional[str] = None,
222
- sweep_interval_seconds: float = 30.0,
284
+ sweep_interval_seconds: float = 15.0,
223
285
  ) -> FastAPI:
224
286
  registry = SessionRegistry(
225
287
  artifacts=artifacts,
226
288
  max_sessions=max_sessions,
227
289
  idle_timeout_seconds=idle_timeout_seconds,
290
+ liveness_timeout_seconds=liveness_timeout_seconds,
228
291
  device=device,
229
292
  torch_n_threads=torch_n_threads,
230
293
  do_autozoom=do_autozoom,
294
+ use_torch_compile=use_torch_compile,
231
295
  verbose=verbose,
232
296
  )
233
297
  gpu_lock = threading.Lock()
@@ -266,9 +330,12 @@ def make_app(
266
330
 
267
331
  app = FastAPI(title="nnInteractive Inference Server", lifespan=lifespan)
268
332
 
269
- # Capability snapshot is computed once and never changes (model is loaded
270
- # once at startup and shared by all sessions). We build it from a fresh
271
- # session initialized off the artifacts.
333
+ # Capability snapshot is computed once and never changes (the network
334
+ # module is loaded once at startup and the same instance is referenced by
335
+ # every session). We build it from a fresh session initialized off the
336
+ # artifacts.
337
+ # This throwaway session only reads static capability metadata and never runs
338
+ # a forward pass, so there is no point compiling its network reference.
272
339
  _capability_session = nnInteractiveInferenceSession(
273
340
  device=device,
274
341
  use_torch_compile=False,
@@ -370,7 +437,11 @@ def make_app(
370
437
  )
371
438
 
372
439
  def _under_session_lock(entry: SessionEntry, fn):
373
- """Run ``fn(session)`` under the session's lock, converting known errors to HTTP 400."""
440
+ """Run ``fn(session)`` under the session's lock, converting known errors to HTTP 400.
441
+
442
+ Every endpoint routed through here represents a real user action, so we
443
+ also refresh the activity clock (postponing the idle timeout)."""
444
+ entry.mark_active()
374
445
  with entry.lock:
375
446
  try:
376
447
  return fn(entry.session)
@@ -379,7 +450,10 @@ def make_app(
379
450
 
380
451
  def _under_session_and_gpu_lock(entry: SessionEntry, fn):
381
452
  """Run ``fn(session)`` under session lock + global GPU lock. Acquisition order
382
- is always session-then-gpu to avoid deadlocks."""
453
+ is always session-then-gpu to avoid deadlocks.
454
+
455
+ Like ``_under_session_lock``, this marks real user activity."""
456
+ entry.mark_active()
383
457
  with entry.lock:
384
458
  with gpu_lock:
385
459
  try:
@@ -410,6 +484,7 @@ def make_app(
410
484
  body = {
411
485
  "lease_token": token,
412
486
  "idle_timeout_seconds": registry.idle_timeout_seconds,
487
+ "liveness_timeout_seconds": registry.liveness_timeout_seconds,
413
488
  "max_sessions": registry.max_sessions,
414
489
  }
415
490
  return Response(
@@ -580,7 +655,7 @@ def _build_capability_snapshot(session: nnInteractiveInferenceSession) -> dict:
580
655
  "supports_initial_label": bool(session.supports_initial_label),
581
656
  "supports_zero_shot_label_refinement": bool(session.supports_zero_shot_label_refinement),
582
657
  "preferred_scribble_thickness": session.preferred_scribble_thickness,
583
- "interaction_decay": float(session.interaction_decay) if session.interaction_decay is not None else None,
658
+ "interaction_decay": (float(session.interaction_decay) if session.interaction_decay is not None else None),
584
659
  "patch_size": list(cfg.patch_size) if cfg is not None else None,
585
660
  "do_autozoom": bool(session.do_autozoom),
586
661
  "inference_session_version": session.INFERENCE_SESSION_VERSION,