nnInteractive 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76)
  1. nnInteractive/__init__.py +3 -0
  2. nnInteractive/inference/__init__.py +0 -0
  3. nnInteractive/inference/cvpr2025_challenge_baseline/__init__.py +0 -0
  4. nnInteractive/inference/cvpr2025_challenge_baseline/predict.py +173 -0
  5. nnInteractive/inference/inference_session.py +1400 -0
  6. nnInteractive/interaction/__init__.py +0 -0
  7. nnInteractive/interaction/point.py +166 -0
  8. nnInteractive/supervoxel/setup.py +4 -0
  9. nnInteractive/supervoxel/src/metadata.py +118 -0
  10. nnInteractive/supervoxel/src/reader.py +175 -0
  11. nnInteractive/supervoxel/src/run.py +136 -0
  12. nnInteractive/supervoxel/src/sam2/__init__.py +2 -0
  13. nnInteractive/supervoxel/src/sam2/sam2/__init__.py +11 -0
  14. nnInteractive/supervoxel/src/sam2/sam2/automatic_mask_generator.py +434 -0
  15. nnInteractive/supervoxel/src/sam2/sam2/benchmark.py +86 -0
  16. nnInteractive/supervoxel/src/sam2/sam2/build_sam.py +172 -0
  17. nnInteractive/supervoxel/src/sam2/sam2/modeling/__init__.py +5 -0
  18. nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/__init__.py +5 -0
  19. nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/hieradet.py +305 -0
  20. nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/image_encoder.py +132 -0
  21. nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/utils.py +89 -0
  22. nnInteractive/supervoxel/src/sam2/sam2/modeling/memory_attention.py +167 -0
  23. nnInteractive/supervoxel/src/sam2/sam2/modeling/memory_encoder.py +179 -0
  24. nnInteractive/supervoxel/src/sam2/sam2/modeling/position_encoding.py +217 -0
  25. nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/__init__.py +5 -0
  26. nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/mask_decoder.py +274 -0
  27. nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/prompt_encoder.py +194 -0
  28. nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/transformer.py +293 -0
  29. nnInteractive/supervoxel/src/sam2/sam2/modeling/sam2_base.py +879 -0
  30. nnInteractive/supervoxel/src/sam2/sam2/modeling/sam2_utils.py +315 -0
  31. nnInteractive/supervoxel/src/sam2/sam2/sam2_image_predictor.py +433 -0
  32. nnInteractive/supervoxel/src/sam2/sam2/sam2_video_predictor.py +1171 -0
  33. nnInteractive/supervoxel/src/sam2/sam2/sam2_video_predictor_legacy.py +1125 -0
  34. nnInteractive/supervoxel/src/sam2/sam2/utils/__init__.py +5 -0
  35. nnInteractive/supervoxel/src/sam2/sam2/utils/amg.py +332 -0
  36. nnInteractive/supervoxel/src/sam2/sam2/utils/misc.py +488 -0
  37. nnInteractive/supervoxel/src/sam2/sam2/utils/transforms.py +108 -0
  38. nnInteractive/supervoxel/src/sam2/setup.py +174 -0
  39. nnInteractive/supervoxel/src/sam2/training/__init__.py +5 -0
  40. nnInteractive/supervoxel/src/sam2/training/dataset/__init__.py +5 -0
  41. nnInteractive/supervoxel/src/sam2/training/dataset/sam2_datasets.py +176 -0
  42. nnInteractive/supervoxel/src/sam2/training/dataset/transforms.py +481 -0
  43. nnInteractive/supervoxel/src/sam2/training/dataset/utils.py +102 -0
  44. nnInteractive/supervoxel/src/sam2/training/dataset/vos_dataset.py +154 -0
  45. nnInteractive/supervoxel/src/sam2/training/dataset/vos_raw_dataset.py +290 -0
  46. nnInteractive/supervoxel/src/sam2/training/dataset/vos_sampler.py +103 -0
  47. nnInteractive/supervoxel/src/sam2/training/dataset/vos_segment_loader.py +289 -0
  48. nnInteractive/supervoxel/src/sam2/training/loss_fns.py +290 -0
  49. nnInteractive/supervoxel/src/sam2/training/model/__init__.py +5 -0
  50. nnInteractive/supervoxel/src/sam2/training/model/sam2.py +515 -0
  51. nnInteractive/supervoxel/src/sam2/training/optimizer.py +462 -0
  52. nnInteractive/supervoxel/src/sam2/training/scripts/sav_frame_extraction_submitit.py +157 -0
  53. nnInteractive/supervoxel/src/sam2/training/train.py +232 -0
  54. nnInteractive/supervoxel/src/sam2/training/trainer.py +1051 -0
  55. nnInteractive/supervoxel/src/sam2/training/utils/__init__.py +5 -0
  56. nnInteractive/supervoxel/src/sam2/training/utils/checkpoint_utils.py +328 -0
  57. nnInteractive/supervoxel/src/sam2/training/utils/data_utils.py +166 -0
  58. nnInteractive/supervoxel/src/sam2/training/utils/distributed.py +560 -0
  59. nnInteractive/supervoxel/src/sam2/training/utils/logger.py +236 -0
  60. nnInteractive/supervoxel/src/sam2/training/utils/train_utils.py +275 -0
  61. nnInteractive/supervoxel/src/supervoxel.py +198 -0
  62. nnInteractive/trainer/__init__.py +0 -0
  63. nnInteractive/trainer/nnInteractiveTrainer.py +24 -0
  64. nnInteractive/utils/__init__.py +0 -0
  65. nnInteractive/utils/bboxes.py +217 -0
  66. nnInteractive/utils/checkpoint_cleansing.py +9 -0
  67. nnInteractive/utils/crop.py +268 -0
  68. nnInteractive/utils/erosion_dilation.py +48 -0
  69. nnInteractive/utils/inference_helpers.py +45 -0
  70. nnInteractive/utils/os_shennanigans.py +16 -0
  71. nnInteractive/utils/rounding.py +13 -0
  72. nninteractive-2.0.0.dist-info/METADATA +511 -0
  73. nninteractive-2.0.0.dist-info/RECORD +76 -0
  74. nninteractive-2.0.0.dist-info/WHEEL +5 -0
  75. nninteractive-2.0.0.dist-info/licenses/LICENSE +201 -0
  76. nninteractive-2.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1400 @@
1
+ from concurrent.futures import ThreadPoolExecutor
2
+ import os
3
+ from os import cpu_count
4
+ from time import time
5
+ from typing import Union, List, Tuple, Optional
6
+ import warnings
7
+
8
+ import blosc2
9
+
10
+ import numpy as np
11
+ import torch
12
+ from acvl_utils.cropping_and_padding.bounding_boxes import bounding_box_to_slice, crop_and_pad_nd
13
+ from batchgenerators.utilities.file_and_folder_operations import load_json, join, subdirs, isfile
14
+ from nnunetv2.utilities.find_class_by_name import recursive_find_python_class
15
+ from nnunetv2.utilities.helpers import dummy_context, empty_cache
16
+ from nnunetv2.utilities.label_handling.label_handling import determine_num_input_channels
17
+ from nnunetv2.utilities.plans_handling.plans_handler import PlansManager, ConfigurationManager
18
+ from torch import nn
19
+ from torch._dynamo import OptimizedModule
20
+ from torch.nn.functional import interpolate
21
+
22
+ import nnInteractive
23
+ from nnInteractive.interaction.point import PointInteraction_stub
24
+ from nnInteractive.trainer.nnInteractiveTrainer import nnInteractiveTrainer_stub
25
+ from nnInteractive.utils.bboxes import generate_bounding_boxes
26
+ from nnInteractive.utils.crop import crop_and_pad_into_buffer, paste_tensor, pad_cropped, crop_to_valid
27
+ from nnInteractive.utils.erosion_dilation import iterative_3x3_same_padding_pool3d
28
+ from nnInteractive.utils.inference_helpers import (
29
+ infer_num_interaction_channels_from_mapping,
30
+ parse_channel_pair,
31
+ transform_coordinates_noresampling,
32
+ version_to_tuple,
33
+ )
34
+ from nnInteractive.utils.rounding import round_to_nearest_odd
35
+
36
+
37
+ class nnInteractiveInferenceSession:
38
# Version of this session implementation; compared against a checkpoint's "inference_class_min_version".
INFERENCE_SESSION_VERSION = nnInteractive.__version__
# Free GPU memory (4 GiB) that must remain after a refinement cache would be placed on the GPU.
REFINEMENT_CACHE_GPU_HEADROOM_BYTES = 4 << 30
# Interaction kinds this inference session knows how to execute.
SUPPORTED_INTERACTION_KEYS = ("scribble", "lasso", "points", "bbox2d", "bbox3d")
42
+
43
def __init__(
    self,
    device: torch.device = torch.device("cuda"),
    use_torch_compile: bool = False,
    verbose: bool = False,
    torch_n_threads: int = 8,
    do_autozoom: bool = True,
):
    """
    Create an empty inference session. Only intended to work with nnInteractiveTrainerV2 and its derivatives.

    The model is loaded afterwards (initialize_from_trained_model_folder) and images are provided via set_image.

    :param device: device the network runs on.
    :param use_torch_compile: must be False (see the error message below).
    :param verbose: print progress messages.
    :param torch_n_threads: requested number of torch CPU threads; clamped to the available CPU count when
        passed to torch.set_num_threads (the unclamped value is stored in self.torch_n_threads).
    :param do_autozoom: initial value of self.do_autozoom (can be changed with set_do_autozoom).
    :raises ValueError: if use_torch_compile is requested.
    """
    if verbose:
        print("session initialized")

    # Explicit check instead of `assert` so the guard survives `python -O`.
    if use_torch_compile:
        raise ValueError(
            "torch.compile is not supported. The blosc2-backed interaction tensor "
            "requires numpy↔torch round-trips that break compile tracing."
        )

    # set as part of initialization
    self.network = None
    self.label_manager = None
    self.dataset_json = None
    self.trainer_name = None
    self.configuration_manager = None
    self.plans_manager = None
    self._interactions_shape = None
    self.device = device
    self.use_torch_compile = use_torch_compile
    self.interaction_decay = None
    self.current_interaction_intensity: float = 1.0
    self._fp16_max_value = float(torch.finfo(torch.float16).max)
    # Keep renormalized interaction magnitudes around 1/10 of fp16 max to preserve headroom.
    self._interaction_renorm_target = self._fp16_max_value / 10
    self.num_interaction_channels: Optional[int] = None
    self.supported_interactions: dict = {}
    self.channel_mapping: dict = {}
    self.supports_initial_label: bool = True
    self.supports_zero_shot_label_refinement: bool = True

    # image specific
    self.interactions = None  # blosc2.NDArray once initialized
    self.preprocessed_image: Optional[torch.Tensor] = None
    self.preprocessed_props = None
    self.target_buffer: Union[np.ndarray, torch.Tensor, None] = None

    # this will be set when loading the model (initialize_from_trained_model_folder)
    self.pad_mode_data = self.preferred_scribble_thickness = self.point_interaction = None

    self.verbose = verbose

    self.do_autozoom: bool = do_autozoom

    # cpu_count() returns None when the number of CPUs cannot be determined; min() would then raise.
    torch.set_num_threads(min(torch_n_threads, cpu_count() or 1))
    self.torch_n_threads = torch_n_threads

    self.original_image_shape = None

    self.new_interaction_zoom_out_factors: List[float] = []
    self.new_interaction_centers = []
    # Create a thread pool executor for background tasks.
    # this only takes care of preprocessing and interaction memory initialization so there is no need to give it
    # more than 2 workers
    self.executor = ThreadPoolExecutor(max_workers=2)
    self.preprocess_future = None
    self.interactions_future = None
107
+
108
@staticmethod
def _is_official_checkpoint(plans: dict, checkpoint: dict) -> bool:
    """Return True for the official nnInteractive V2 release (Dataset225, 3d_fullres_ps192_bs24 configuration)."""
    if plans.get("dataset_name") != "Dataset225_nnInteractiveV2":
        return False
    init_args = checkpoint.get("init_args", {})
    return init_args.get("configuration") == "3d_fullres_ps192_bs24"
114
+
115
def _legacy_default_capability(self) -> dict:
    """
    Capability description assumed for checkpoints without (complete) capability metadata.

    Everything except 3D bounding boxes is supported. Channel 0 is the previous segmentation; each interaction
    kind maps to a (positive, negative) channel pair, with bbox2d/bbox3d/lasso sharing channels (1, 2).
    """
    box_like_pair = (1, 2)
    channel_mapping = {
        "prev_seg": 0,
        "bbox2d": box_like_pair,
        "bbox3d": box_like_pair,
        "lasso": box_like_pair,
        "points": (3, 4),
        "scribble": (5, 6),
    }
    supported = {name: True for name in ("scribble", "lasso", "points", "bbox2d")}
    supported["bbox3d"] = False
    return {
        "supported_interactions": supported,
        "supports_initial_label": True,
        "supports_zero_shot_label_refinement": True,
        "interaction_channels": 6,
        "channel_mapping": channel_mapping,
    }
136
+
137
def _to_positive_channel_index(self, idx: int) -> int:
    """Convert a Python-style (possibly negative, counted from the end) channel index to a non-negative one."""
    if idx < 0:
        return self.num_interaction_channels + idx
    return idx
139
+
140
def _resolve_channel_pair(self, channel_name: str, override_capability_checks: bool) -> Tuple[int, int]:
    """
    Look up the (positive, negative) channel indices of an interaction in self.channel_mapping.

    :raises ValueError: if no mapping exists. With override_capability_checks a RuntimeWarning is emitted
        first, but the error is still raised because there is no channel the interaction could be written to.
    """
    if channel_name not in self.channel_mapping:
        if override_capability_checks:
            warnings.warn(
                f"Interaction '{channel_name}' was forced but no channel mapping exists in capability metadata.",
                RuntimeWarning,
            )
        raise ValueError(f"Interaction '{channel_name}' cannot be executed because no channel mapping was found.")
    return parse_channel_pair(channel_name, self.channel_mapping[channel_name])
149
+
150
def _is_interaction_supported(self, interaction_name: str) -> bool:
    """
    Whether the loaded checkpoint supports an interaction.

    Known interaction kinds are looked up in self.supported_interactions; "initial_label" uses
    self.supports_initial_label; anything else is unsupported.
    """
    if interaction_name in self.SUPPORTED_INTERACTION_KEYS:
        return bool(self.supported_interactions.get(interaction_name, False))
    return interaction_name == "initial_label" and bool(self.supports_initial_label)
156
+
157
def _get_prev_seg_channel(self) -> int:
    """Index of the interaction channel that holds the previous segmentation."""
    channel = self.channel_mapping["prev_seg"]
    return int(channel)
159
+
160
@staticmethod
def _clip_bbox_to_shape(bbox: List[List[int]], spatial_shape: Tuple[int, ...]) -> Optional[List[List[int]]]:
    """Clip a [[lb, ub], ...] box to [0, size] per dimension; None if any dimension ends up empty."""
    clipped = []
    for (lower, upper), size in zip(bbox, spatial_shape):
        clipped.append([max(0, int(lower)), min(int(upper), int(size))])
    for lower, upper in clipped:
        if upper <= lower:
            return None
    return clipped
166
+
167
@staticmethod
def _bbox_size(bbox: List[List[int]]) -> List[int]:
    """Extent (ub - lb) of a [[lb, ub], ...] box along each dimension."""
    sizes = []
    for lower, upper in bbox:
        sizes.append(int(upper - lower))
    return sizes
170
+
171
@staticmethod
def _union_bboxes(*bboxes: Optional[List[List[int]]]) -> Optional[List[List[int]]]:
    """Smallest box containing all non-None boxes; None if no box is given. Dimensionality follows the first."""
    present = [box for box in bboxes if box is not None]
    if not present:
        return None
    union = []
    for dim in range(len(present[0])):
        lowers = [box[dim][0] for box in present]
        uppers = [box[dim][1] for box in present]
        union.append([min(lowers), max(uppers)])
    return union
180
+
181
@staticmethod
def _offset_bboxes(local_bboxes: List[List[List[int]]], offset_bbox: List[List[int]]) -> List[List[List[int]]]:
    """Translate boxes given relative to offset_bbox into absolute coordinates (shift by its lower corner)."""
    shifted = []
    for box in local_bboxes:
        moved = []
        for dim, (lower, upper) in enumerate(box):
            shift = offset_bbox[dim][0]
            moved.append([lower + shift, upper + shift])
        shifted.append(moved)
    return shifted
187
+
188
def _interaction_bbox_to_target_bbox(self, bbox: List[List[int]]) -> List[List[int]]:
    """
    Translate a box from cropped interaction space into original image coordinates.

    Every dimension is shifted by the lower bound of the crop stored in preprocessed_props.
    """
    crop_box = self.preprocessed_props["bbox_used_for_cropping"]
    shifted = []
    for dim_range, crop_range in zip(bbox, crop_box):
        offset = crop_range[0]
        shifted.append([dim_range[0] + offset, dim_range[1] + offset])
    return shifted
192
+
193
def _compute_prev_seg_positive_bbox(self) -> Optional[List[List[int]]]:
    """
    Bounding box ([[lb, ub], ...], ub exclusive) of all prev_seg voxels with value > 0.5.

    The channel is read in slabs of 64 along the first spatial axis to avoid materializing the whole channel
    at once. Returns None if no voxel exceeds 0.5. Expects exactly three spatial dimensions.
    """
    channel = self._get_prev_seg_channel()
    spatial_shape = tuple(int(s) for s in self.interactions.shape[1:])
    n0, n1, n2 = spatial_shape[0], spatial_shape[1], spatial_shape[2]

    hits_axis0 = np.zeros(n0, dtype=bool)
    hits_axis1 = np.zeros(n1, dtype=bool)
    hits_axis2 = np.zeros(n2, dtype=bool)
    slab_depth = 64
    start = 0
    while start < n0:
        stop = min(n0, start + slab_depth)
        mask = np.asarray(self.interactions[(channel, slice(start, stop), slice(None), slice(None))]) > 0.5
        if mask.any():
            hits_axis0[start:stop] |= mask.any(axis=(1, 2))
            hits_axis1 |= mask.any(axis=(0, 2))
            hits_axis2 |= mask.any(axis=(0, 1))
        start = stop

    box = []
    for hits in (hits_axis0, hits_axis1, hits_axis2):
        positions = np.flatnonzero(hits)
        if positions.size == 0:
            return None
        box.append([int(positions[0]), int(positions[-1]) + 1])
    return box
218
+
219
def _get_dilation_channels_for_resample(self) -> List[int]:
    """
    Sorted channel indices (positive and negative) of supported and mapped point/scribble interactions.

    During zoom-out, area interpolation can average tiny sparse structures away, so only these "thin prompt"
    channels are dilated before resampling.
    """
    thin_prompt_channels = set()
    for key in ("points", "scribble"):
        if self.supported_interactions.get(key, False) and key in self.channel_mapping:
            positive, negative = parse_channel_pair(key, self.channel_mapping[key])
            thin_prompt_channels |= {positive, negative}
    # Sorted for deterministic execution order (easier debugging/logging).
    return sorted(thin_prompt_channels)
233
+
234
def _check_capability_or_warn(self, interaction_name: str, override_capability_checks: bool):
    """
    Ensure the checkpoint supports an interaction.

    :raises ValueError: if unsupported and override_capability_checks is False. With the override, a
        RuntimeWarning is emitted instead and execution continues.
    """
    if self._is_interaction_supported(interaction_name):
        return
    message = f"Interaction '{interaction_name}' is not supported by this checkpoint capability metadata."
    if not override_capability_checks:
        raise ValueError(message)
    warnings.warn(f"{message} Proceeding because override_capability_checks=True.", RuntimeWarning)
242
+
243
def _get_non_prev_seg_channels(self) -> List[int]:
    """All interaction channel indices except the prev_seg channel ([] before interactions exist)."""
    if self.interactions is None:
        return []
    skipped = self._get_prev_seg_channel()
    return [channel for channel in range(self.interactions.shape[0]) if channel != skipped]
251
+
252
def _renormalize_interactions_if_needed(self):
    """
    Rescale stored interactions once the running intensity exceeds the fp16 maximum.

    All channels except prev_seg are multiplied by renorm_target / intensity and the intensity is reset to the
    renorm target, so stored values keep their relative weighting while staying representable in float16.
    """
    if self.interactions is None or self.current_interaction_intensity <= self._fp16_max_value:
        return
    channels = self._get_non_prev_seg_channels()
    if not channels:
        self.current_interaction_intensity = min(self.current_interaction_intensity, self._interaction_renorm_target)
        return
    factor = self._interaction_renorm_target / self.current_interaction_intensity
    for channel in channels:
        # Explicit read-modify-write (what `x[c] *= f` does) so it also works for blosc2 arrays.
        scaled = self.interactions[channel]
        scaled *= factor
        self.interactions[channel] = scaled
    self.current_interaction_intensity = self._interaction_renorm_target
267
+
268
def _interactions_inplace_maximum(self, channel_idx: int, int_slicer, new_values) -> None:
    """
    In-place element-wise maximum of ``new_values`` into a sub-region of one interaction channel.

    :param channel_idx: channel of self.interactions to update.
    :param int_slicer: index objects (presumably spatial slices) selecting the sub-region within the channel.
    :param new_values: numpy array or torch tensor matching/broadcastable to the selected region. Tensors are
        detached, moved to the CPU and cast to float16 in torch, so bfloat16 and grad-tracking tensors work.
    """
    if isinstance(new_values, torch.Tensor):
        # Cast in torch: Tensor.numpy() rejects bfloat16 and tensors that require grad.
        new_values = new_values.detach().to(device="cpu", dtype=torch.float16).numpy()
    full_slicer = (channel_idx, *int_slicer)
    # Read-modify-write so this also works when self.interactions is a blosc2 array (indexing yields a copy).
    region = np.asarray(self.interactions[full_slicer])
    np.maximum(region, new_values, out=region)
    self.interactions[full_slicer] = region
276
+
277
def _write_interactions_channel(self, channel_idx: int, value) -> None:
    """
    Overwrite one full interaction channel.

    Torch tensors are detached, moved to the CPU and cast to float16 in torch (Tensor.numpy() rejects bfloat16
    and grad-tracking tensors) before being written; numpy input is written as-is.
    """
    if isinstance(value, torch.Tensor):
        value = value.detach().to(device="cpu", dtype=torch.float16).numpy()
    self.interactions[channel_idx] = value
282
+
283
def _paste_prediction_to_target_buffer(self, prediction: torch.Tensor, bbox: List[List[int]]) -> None:
    """
    Write ``prediction`` (computed for ``bbox`` in cropped interaction space) into self.target_buffer.

    The box is shifted back to original image coordinates, and the prediction is moved to the target buffer's
    device first (CPU when the buffer is a numpy array).
    """
    target_bbox = self._interaction_bbox_to_target_bbox(bbox)
    buffer = self.target_buffer
    destination_device = buffer.device if isinstance(buffer, torch.Tensor) else "cpu"
    paste_tensor(buffer, prediction.to(destination_device), target_bbox)
290
+
291
def _estimate_refinement_cache_nbytes(self, cache_bbox: List[List[int]]) -> int:
    """Bytes needed to cache the image crop plus all float16 interaction channels covering ``cache_bbox``."""
    voxels = int(np.prod(self._bbox_size(cache_bbox), dtype=np.int64))
    image_bytes_per_voxel = torch.empty((), dtype=self.preprocessed_image.dtype).element_size()
    interaction_bytes_per_voxel = (
        self.num_interaction_channels * torch.empty((), dtype=torch.float16).element_size()
    )
    return int(voxels * (image_bytes_per_voxel + interaction_bytes_per_voxel))
298
+
299
def _select_refinement_cache_device(self, cache_bbox: List[List[int]]) -> torch.device:
    """
    Choose where the refinement cache lives.

    The GPU (self.device) is used only if it is a CUDA device and at least
    REFINEMENT_CACHE_GPU_HEADROOM_BYTES of free memory would remain after allocating the cache; otherwise,
    including when the free-memory query fails, the CPU is used.
    """
    if self.device.type == "cuda":
        needed = self._estimate_refinement_cache_nbytes(cache_bbox)
        try:
            free_bytes, _ = torch.cuda.mem_get_info(self.device)
            fits = free_bytes - needed >= self.REFINEMENT_CACHE_GPU_HEADROOM_BYTES
        except Exception:
            # Best effort: if CUDA memory cannot be queried, fall back to the CPU.
            fits = False
        if fits:
            return self.device
    return torch.device("cpu")
312
+
313
def _build_refinement_local_cache(self, bboxes_ordered: List[List[List[int]]]):
    """
    Copy the image and interactions covering the union of ``bboxes_ordered`` into one local cache.

    The interaction cache is normalized for the network (all channels except prev_seg divided by the current
    interaction intensity).

    :returns: (cache_bbox, cache_image, cache_interactions)
    """
    cache_bbox = self._union_bboxes(*bboxes_ordered)
    target_device = self._select_refinement_cache_device(cache_bbox)
    shape = self._bbox_size(cache_bbox)

    allocation = dict(device=target_device)
    # Host-side cache while inference runs on CUDA: pin it, presumably to speed up later host->GPU copies.
    if target_device.type == "cpu" and self.device.type == "cuda":
        allocation["pin_memory"] = True

    image_cache = torch.zeros(shape, dtype=self.preprocessed_image.dtype, **allocation)
    interaction_cache = torch.zeros((self.num_interaction_channels, *shape), dtype=torch.float16, **allocation)

    crop_and_pad_into_buffer(image_cache, cache_bbox, self.preprocessed_image[0])
    crop_and_pad_into_buffer(interaction_cache, cache_bbox, self.interactions)
    self._normalize_interaction_channels_for_network_(interaction_cache)
    return cache_bbox, image_cache, interaction_cache
332
+
333
def _prepare_new_interaction_intensity(self):
    """
    Raise the intensity used for the next interaction so older interactions decay relative to it.

    The intensity grows by 1 / interaction_decay (no-op when the decay is None or exactly 1), followed by a
    renormalization check.

    :raises ValueError: if interaction_decay is not in (0, 1].
    """
    decay = self.interaction_decay
    if decay is None:
        return
    if not (0 < decay <= 1):
        raise ValueError(f"interaction_decay must be in (0, 1], got {self.interaction_decay}.")
    if decay >= 1:
        return
    self.current_interaction_intensity *= 1 / decay
    self._renormalize_interactions_if_needed()
341
+
342
def _normalize_interaction_channels_for_network_(self, interaction_tensor: torch.Tensor):
    """
    In place: divide every channel except prev_seg by the current interaction intensity.

    Nothing happens for a None tensor or when the intensity is 0 (would divide by zero) or 1 (no-op).
    """
    intensity = self.current_interaction_intensity
    if interaction_tensor is None or intensity in (0, 1):
        return
    kept = self._get_prev_seg_channel()
    channels = [c for c in range(interaction_tensor.shape[0]) if c != kept]
    for c in channels:
        interaction_tensor[c] /= intensity
351
+
352
def _load_capability_and_runtime_defaults(self, model_training_output_dir: str):
    """
    Read checkpoint metadata from ``model_training_output_dir`` and derive runtime defaults.

    Prefers ``inference_info.json`` (capability metadata, checked with _validate_capability_version) and falls
    back to the legacy ``inference_session_class.json``. A legacy file containing only a JSON string keeps the
    defaults except for interaction_decay, which becomes 0.9. A scalar scribble thickness is expanded to three
    dimensions.

    :returns: (capability_content, point_interaction_radius, preferred_scribble_thickness, interaction_decay,
        pad_mode_data); capability_content is {} for legacy checkpoints.
    :raises RuntimeError: if a metadata file has an unexpected JSON type, or the checkpoint requires a
        different/newer inference session.
    :raises FileNotFoundError: if neither metadata file exists.
    """
    capability_file = join(model_training_output_dir, "inference_info.json")
    legacy_file = join(model_training_output_dir, "inference_session_class.json")

    point_interaction_radius = 4
    preferred_scribble_thickness = [2, 2, 2]
    interaction_decay = 0.98
    pad_mode_data = "constant"
    capability_content = {}

    # Prefer modern capability metadata; fall back to legacy session metadata for older checkpoints.
    if isfile(capability_file):
        capability_content = load_json(capability_file)
        if not isinstance(capability_content, dict):
            raise RuntimeError(f"Invalid capability metadata in {capability_file}. Expected a JSON object.")
        self._validate_capability_version(capability_content)
        runtime_source = capability_content
    elif isfile(legacy_file):
        legacy_content = load_json(legacy_file)
        if isinstance(legacy_content, str):
            # Oldest metadata format: only a string is stored; it implies a stronger decay.
            interaction_decay = 0.9
            runtime_source = {}
        elif isinstance(legacy_content, dict):
            runtime_source = legacy_content
        else:
            # Previously this crashed with an AttributeError on `.get`.
            raise RuntimeError(
                f"Invalid legacy metadata in {legacy_file}. Expected a JSON object or string."
            )
    else:
        raise FileNotFoundError(
            f"Neither capability metadata ({capability_file}) nor legacy metadata ({legacy_file}) was found."
        )

    point_interaction_radius = runtime_source.get("point_radius", point_interaction_radius)
    preferred_scribble_thickness = runtime_source.get("preferred_scribble_thickness", preferred_scribble_thickness)
    interaction_decay = runtime_source.get("interaction_decay", interaction_decay)
    pad_mode_data = runtime_source.get("pad_mode_image", pad_mode_data)

    # Accept scalar thickness in metadata for backward compatibility.
    if not isinstance(preferred_scribble_thickness, (tuple, list)):
        preferred_scribble_thickness = [preferred_scribble_thickness] * 3

    return (
        capability_content,
        point_interaction_radius,
        preferred_scribble_thickness,
        interaction_decay,
        pad_mode_data,
    )
401
+
402
def _apply_capability(self, capability: dict):
    """
    Populate supported interactions, channel mapping and channel count from capability metadata.

    Entries missing from ``capability`` fall back to _legacy_default_capability(). Anything that is not a dict
    is treated as empty metadata (legacy defaults everywhere). After this call every channel index in
    self.channel_mapping is non-negative.

    :raises ValueError: if supported_interactions or channel_mapping contain unknown keys.
    """
    if not isinstance(capability, dict):
        # Apply the "no metadata" fallback consistently to every key, not only to some of them.
        capability = {}

    defaults = self._legacy_default_capability()
    supported_keys = set(self.SUPPORTED_INTERACTION_KEYS)
    mapping_keys = supported_keys | {"prev_seg"}

    raw_supported = capability.get("supported_interactions", {})
    unknown_supported = set(raw_supported.keys()) - supported_keys
    if unknown_supported:
        raise ValueError(
            f"Capability requests unsupported interactions: {sorted(unknown_supported)}. "
            f"Supported: {sorted(supported_keys)}"
        )
    self.supported_interactions = {
        **defaults["supported_interactions"],
        **{k: bool(v) for k, v in raw_supported.items()},
    }
    self.supports_initial_label = bool(capability.get("supports_initial_label", True))
    self.supports_zero_shot_label_refinement = bool(capability.get("supports_zero_shot_label_refinement", True))

    raw_mapping = capability.get("channel_mapping", {})
    unknown_mapping = set(raw_mapping.keys()) - mapping_keys
    if unknown_mapping:
        raise ValueError(
            f"Capability channel_mapping contains unsupported keys: {sorted(unknown_mapping)}. "
            f"Supported mapping keys: {sorted(mapping_keys)}"
        )
    self.channel_mapping = dict(defaults["channel_mapping"])
    for k, v in raw_mapping.items():
        self.channel_mapping[k] = int(v) if k == "prev_seg" else parse_channel_pair(k, v)

    if "interaction_channels" in capability:
        # NOTE(review): the count presumably excludes prev_seg (legacy default: 6 here, channels 0-6 in use).
        self.num_interaction_channels = int(capability["interaction_channels"]) + 1
    else:
        self.num_interaction_channels = infer_num_interaction_channels_from_mapping(self.channel_mapping)

    # Normalize all channel indices to positive indexing once at load time so downstream code can
    # use direct indexing without handling negative-offset semantics repeatedly.
    self.channel_mapping["prev_seg"] = self._to_positive_channel_index(int(self.channel_mapping["prev_seg"]))
    for k, v in list(self.channel_mapping.items()):
        if k == "prev_seg":
            continue
        pos_ch, neg_ch = parse_channel_pair(k, v)
        self.channel_mapping[k] = (
            self._to_positive_channel_index(pos_ch),
            self._to_positive_channel_index(neg_ch),
        )
451
+
452
def _validate_capability_version(self, capability: dict):
    """
    Make sure this session class may run the checkpoint.

    :raises RuntimeError: if the metadata names another inference class, or requires a newer session version
        ("inference_class_min_version") than INFERENCE_SESSION_VERSION.
    """
    my_class = self.__class__.__name__
    wanted_class = capability.get("inference_class", my_class)
    if wanted_class != my_class:
        raise RuntimeError(
            f"Checkpoint requires inference class '{wanted_class}', but current class is '{my_class}'."
        )

    min_version = capability.get("inference_class_min_version")
    if min_version is not None and version_to_tuple(min_version) > version_to_tuple(self.INFERENCE_SESSION_VERSION):
        raise RuntimeError(
            f"Checkpoint requires nnInteractiveInferenceSession>={min_version}, but current version is "
            f"{self.INFERENCE_SESSION_VERSION}. Please update nnInteractive."
        )
468
+
469
def set_image(self, image: np.ndarray, image_properties: dict = None):
    """
    Set a new image and start preprocessing it in a background thread.

    Image must be 4D to satisfy nnU-Net needs: [c, x, y, z]. The input is validated before any session state
    is touched, so a rejected image leaves the current session intact. Afterwards all state of the previous
    image is discarded and preprocessing is submitted to the executor; interaction methods (e.g.
    add_bbox_interaction) wait for it before running.

    :param image: 4D numpy array [c, x, y, z].
    :param image_properties: optional dict passed on to the background task (currently unused there).
    :raises ValueError: if ``image`` is not 4D.
    """
    # Validate first (and without `assert`, which `python -O` strips) so bad input cannot wipe the session.
    if image.ndim != 4:
        raise ValueError(f"expected a 4d image as input, got {image.ndim}d. Shape {image.shape}")
    if image_properties is None:
        image_properties = {}
    self._reset_session()
    if self.verbose:
        print(f"Initialize with raw image shape {image.shape}")

    self.original_image_shape = image.shape
    # Offload all image preprocessing to a background thread.
    self.preprocess_future = self.executor.submit(self._background_set_image, image, image_properties)
484
+
485
def _finish_preprocessing_and_initialize_interactions(self):
    """
    Block until pending background preprocessing has finished.

    The background task itself waits for the interaction array initialization, so both are done afterwards.
    If the task failed, its exception is re-raised here and the future is kept (later calls raise again).
    """
    pending = self.preprocess_future
    if pending is None:
        return
    pending.result()
    self.preprocess_future = None
495
+
496
def set_target_buffer(self, target_buffer: Union[np.ndarray, torch.Tensor]):
    """
    Register the buffer that predictions are pasted into.

    Must be a 3D numpy array or torch.Tensor; it is stored as-is (no copy, no validation).
    """
    buffer = target_buffer
    self.target_buffer = buffer
501
+
502
def set_do_autozoom(self, do_autozoom: bool):
    """Enable or disable automatic zoom for subsequent interactions (stored as-is in self.do_autozoom)."""
    setattr(self, "do_autozoom", do_autozoom)
504
+
505
def _reset_session(self):
    """
    Discard all state that belongs to the current image and release cached device memory.

    A still-running background preprocessing task is waited for first. Otherwise it could finish later and
    overwrite the freshly reset state. Also, with the 2-worker executor, two concurrent preprocessing tasks
    that each block on their own interaction-initialization task could starve the pool. Errors of that
    discarded task are ignored (concurrent.futures.wait does not raise) because its result is thrown away.
    """
    from concurrent.futures import wait

    if self.preprocess_future is not None:
        wait([self.preprocess_future])
    self.interactions_future = None
    self.preprocess_future = None

    # Rebinding to None drops our references so the memory can be freed.
    self.preprocessed_image = None
    self.target_buffer = None
    self.interactions = None
    self.preprocessed_props = None
    self.current_interaction_intensity = 1.0
    empty_cache(self.device)
    self.original_image_shape = None
520
+
521
def _initialize_interactions(self, image_torch: torch.Tensor):
    """
    Allocate a zero-filled, blosc2-compressed float16 interaction array for the (cropped) image.

    The shape is (num_interaction_channels, *image_torch.shape[1:]); only the spatial shape of image_torch is
    used. Each chunk/block spans a single channel. Submitted to the executor by _background_set_image.
    """
    shape = (self.num_interaction_channels, *image_torch.shape[1:])
    if self.verbose:
        print("Initialize interactions with blosc2 in-memory compression")
    # os.cpu_count() may return None when the CPU count cannot be determined; min() would then raise.
    compression_threads = min(self.torch_n_threads, os.cpu_count() or 1)
    self.interactions = blosc2.zeros(
        shape,
        dtype=np.float16,
        chunks=(1, *[min(64, s) for s in shape[1:]]),
        blocks=(1, *[min(32, s) for s in shape[1:]]),
        cparams={"codec": blosc2.Codec.LZ4, "clevel": 5, "nthreads": compression_threads},
        dparams={"nthreads": 4},
    )
    self._interactions_shape = shape
534
+
535
@torch.inference_mode()
def _background_set_image(self, image: np.ndarray, image_properties: dict):
    """
    Preprocess ``image``; runs in the executor (submitted by set_image).

    Steps: crop to the bounding box of nonzero voxels (only the first channel is kept), start allocating the
    interaction array for the cropped shape in parallel, z-score normalize the crop, and store it in
    self.preprocessed_image with the crop box in self.preprocessed_props. Returns only after the interaction
    array is ready too, because callers wait for this task alone.

    :param image: 4D array [c, x, y, z].
    :param image_properties: currently unused.
    """
    # Convert and clone the image tensor.
    image = torch.from_numpy(image.copy())

    if self.verbose:
        print("Cropping input image to nonzero region")
    # torch.where over the full image eats RAM / VRAM for breakfast (index tensors per nonzero voxel). Reduce
    # along the other axes instead: a slice contains a nonzero voxel iff its max or its min is nonzero. Unlike
    # summing, this cannot be fooled by positive and negative values cancelling out (e.g. CT intensities),
    # and it never allocates a full-size temporary.
    bbox = [[0, 1]]
    for reduce_dims in ((0, 2, 3), (0, 1, 3), (0, 1, 2)):
        has_signal = (torch.amax(image, dim=reduce_dims) != 0) | (torch.amin(image, dim=reduce_dims) != 0)
        nonzero_idx = torch.where(has_signal)[0]
        if nonzero_idx.numel() == 0:
            # All-zero image: keep the full extent instead of failing on an empty index tensor.
            bbox.append([0, int(has_signal.shape[0])])
        else:
            bbox.append([int(nonzero_idx[0]), int(nonzero_idx[-1]) + 1])
        del has_signal, nonzero_idx
    empty_cache(self.device)

    slicer = bounding_box_to_slice(bbox)  # Assuming this returns a tuple of slices.
    image = image[slicer].float()
    if self.verbose:
        print(f"Cropped image shape: {image.shape}")

    # As soon as we have the target shape, start initializing the interaction tensor in its own thread.
    self.interactions_future = self.executor.submit(self._initialize_interactions, image)

    # Normalize the cropped image.
    if self.verbose:
        print("Normalizing cropped image")
    image -= image.mean()
    std = image.std()
    # A constant image has std 0 (a single voxel even NaN); skip the division instead of producing NaN/inf.
    if std > 0:
        image /= std

    self.preprocessed_image = image.to("cpu")

    self.preprocessed_props = {"bbox_used_for_cropping": bbox[1:]}

    # The interaction array must exist before this task reports completion.
    self.interactions_future.result()
    self.interactions_future = None
586
+
587
def reset_interactions(self):
    """
    Use this to reset all interactions and start from scratch for the current image. This includes the initial
    segmentation!

    Pending background preprocessing is awaited first, so the reset cannot race with the interaction array
    allocation. If no interaction shape is known yet (no image was set), only the intensity and the target
    buffer are reset. The target buffer, if set, is zeroed.
    """
    self._finish_preprocessing_and_initialize_interactions()

    # Drop the old array before allocating the new one to keep peak memory low.
    self.interactions = None
    shape = self._interactions_shape
    if shape is not None:
        self.interactions = blosc2.zeros(
            shape,
            dtype=np.float16,
            chunks=(1, *[min(64, s) for s in shape[1:]]),
            blocks=(1, *[min(32, s) for s in shape[1:]]),
            # Same thread count as _initialize_interactions; os.cpu_count() may return None.
            cparams={
                "codec": blosc2.Codec.LZ4,
                "clevel": 5,
                "nthreads": min(self.torch_n_threads, os.cpu_count() or 1),
            },
            dparams={"nthreads": 4},
        )
    self.current_interaction_intensity = 1.0

    if self.target_buffer is not None:
        if isinstance(self.target_buffer, np.ndarray):
            self.target_buffer.fill(0)
        elif isinstance(self.target_buffer, torch.Tensor):
            self.target_buffer.zero_()
    empty_cache(self.device)
610
+
611
def add_bbox_interaction(
    self,
    bbox_coords,
    include_interaction: bool,
    run_prediction: bool = True,
    override_capability_checks: bool = False,
):
    """
    Add a 2D or 3D bounding-box prompt.

    Args:
        bbox_coords: one ``[lower, upper)`` pair per spatial axis, given in the coordinate frame of the
            original (uncropped) image.
        include_interaction: True places the box in the positive channel, False in the negative one.
        run_prediction: run ``_predict`` right after the box has been placed.
        override_capability_checks: forwarded to the capability / channel resolution helpers.

    Raises:
        ValueError: if the box has zero extent along any axis, or if it is a 3D box (no axis of size 1)
            while the loaded checkpoint does not support 3D boxes.
    """
    self._finish_preprocessing_and_initialize_interactions()
    # sanity check
    raw_bbox_size = [i[1] - i[0] for i in bbox_coords]
    if any(i == 0 for i in raw_bbox_size):
        raise ValueError(f"Given bounding box size is zero in at least one dimension: {bbox_coords}")

    # capability check
    dims_with_size_one = sum(i == 1 for i in raw_bbox_size)
    # if we do not support 3D bboxes we need to reject 3D bboxes!
    if not self._is_interaction_supported("bbox3d") and dims_with_size_one == 0:
        raise ValueError(
            f"The given bounding box {bbox_coords} has size {raw_bbox_size} indicating a 3D "
            f"bounding box. This is not supported by the loaded model checkpoint."
        )
    # a 2D bounding box is in principle a 3D box as well. Since 2D bboxes work better, we prefer to use a given
    # bbox as 2d if possible (sized 1 in at least one dim and bbox2d supported)
    bbox_kind = "bbox2d" if (dims_with_size_one >= 1 and self._is_interaction_supported("bbox2d")) else "bbox3d"
    self._check_capability_or_warn(bbox_kind, override_capability_checks)
    bbox_pos_channel, bbox_neg_channel = self._resolve_channel_pair(bbox_kind, override_capability_checks)

    # Convert user-space coordinates (original image) to the cropped nnU-Net internal space.
    crop_bbox = self.preprocessed_props["bbox_used_for_cropping"]
    lbs_transformed = [round(i) for i in transform_coordinates_noresampling([i[0] for i in bbox_coords], crop_bbox)]
    ubs_transformed = [round(i) for i in transform_coordinates_noresampling([i[1] for i in bbox_coords], crop_bbox)]
    transformed_bbox_coordinates = [[i, j] for i, j in zip(lbs_transformed, ubs_transformed)]

    if self.verbose:
        print(
            f"Adding bounding box coordinates.\n"
            f"Raw: {bbox_coords}\n"
            f"Transformed: {transformed_bbox_coordinates}\n"
            f"Crop Bbox: {self.preprocessed_props['bbox_used_for_cropping']}"
        )

    # Clip bbox to valid interaction volume and guarantee at least one voxel extent per axis.
    # The preprocessed image carries a leading channel axis, hence the ``dim + 1`` below.
    image_shape = self.preprocessed_image.shape

    for dim, (transformed_start, transformed_end) in enumerate(transformed_bbox_coordinates):
        dim_size = image_shape[dim + 1]

        # Clip both ends to the image. The start must also be clipped at the upper edge: a box lying beyond
        # the cropped region (e.g. drawn in background removed by preprocessing) or an inverted box would
        # otherwise keep start > end and produce an empty slice, silently dropping the prompt.
        transformed_start = min(max(0, transformed_start), dim_size)
        transformed_end = min(dim_size, transformed_end)

        # Ensure the bounding box does not collapse to nothing: fall back to a single voxel.
        if transformed_end <= transformed_start:
            if transformed_start == 0:
                transformed_end = min(1, dim_size)
            else:
                # the voxel just below the (clipped) start, i.e. [start - 1, start)
                transformed_end = transformed_start
                transformed_start -= 1

        transformed_bbox_coordinates[dim] = [transformed_start, transformed_end]

    if self.verbose:
        print(
            f"Bbox coordinates after clip to image boundaries and preventing dim collapse:\n"
            f"Bbox: {transformed_bbox_coordinates}\n"
            f"Internal image shape: {self.preprocessed_image.shape}"
        )

    self._add_patch_for_bbox_interaction(transformed_bbox_coordinates)

    self._prepare_new_interaction_intensity()

    # place bbox
    slicer = tuple([slice(*i) for i in transformed_bbox_coordinates])
    channel = bbox_pos_channel if include_interaction else bbox_neg_channel
    self.interactions[(channel, *slicer)] = self.current_interaction_intensity

    if run_prediction:
        self._predict()
698
+
699
def add_point_interaction(
    self,
    coordinates: Tuple[int, ...],
    include_interaction: bool,
    run_prediction: bool = True,
    override_capability_checks: bool = False,
):
    """
    Add a single point prompt.

    ``coordinates`` are given in the original (uncropped) image frame. They are mapped into the cropped
    internal frame and rounded, a prediction patch is queued around them, and the point is rendered into
    the positive (``include_interaction=True``) or negative point channel. ``_predict`` runs afterwards
    when ``run_prediction`` is True.
    """
    self._check_capability_or_warn("points", override_capability_checks)
    positive_channel, negative_channel = self._resolve_channel_pair("points", override_capability_checks)
    self._finish_preprocessing_and_initialize_interactions()

    crop_bbox = self.preprocessed_props["bbox_used_for_cropping"]
    internal_point = list(map(round, transform_coordinates_noresampling(coordinates, crop_bbox)))

    self._add_patch_for_point_interaction(internal_point)
    self._prepare_new_interaction_intensity()

    target_channel = positive_channel if include_interaction else negative_channel
    self.point_interaction.place_point(
        internal_point,
        self.interactions,
        channel_idx=target_channel,
        intensity_scale=self.current_interaction_intensity,
    )
    if run_prediction:
        self._predict()
728
+
729
def _add_image_interaction(
    self,
    image: np.ndarray,
    interaction_channel: int,
    run_prediction: bool,
    interaction_bbox: Optional[List[List[int]]],
    patch_fn,
):
    """
    Merge a dense prompt image (e.g. scribble or lasso mask) into one interaction channel.

    Args:
        image: prompt values covering exactly ``interaction_bbox``; its shape must equal the bbox size.
        interaction_channel: channel of ``self.interactions`` receiving the values (combined through
            ``_interactions_inplace_maximum``).
        run_prediction: run ``_predict`` afterwards.
        interaction_bbox: ``[[lb, ub], ...]`` for the three spatial axes in the original image frame;
            ``None`` means the whole image.
        patch_fn: callable ``patch_fn(image_tensor, offset=internal_lower_bounds)`` that queues the
            prediction patch for this prompt.

    Raises:
        AssertionError: if the bbox is not 3D, has a non-positive extent, does not match ``image.shape``
            or exceeds the original image bounds.
    """
    if interaction_bbox is None:
        interaction_bbox = [[0, s] for s in self.original_image_shape[1:]]

    assert len(interaction_bbox) == 3
    bbox_size = [ub - lb for lb, ub in interaction_bbox]
    assert all(s > 0 for s in bbox_size), "each dimension of interaction_bbox must have positive size"
    assert (
        list(image.shape) == bbox_size
    ), f"image shape {list(image.shape)} must match interaction_bbox size {bbox_size}"
    assert all(
        lb >= 0 and ub <= orig_dim for (lb, ub), orig_dim in zip(interaction_bbox, self.original_image_shape[1:])
    ), f"interaction_bbox {interaction_bbox} exceeds original image bounds {list(self.original_image_shape[1:])}"

    self._finish_preprocessing_and_initialize_interactions()

    # Map the bbox corners from the original image frame into the cropped internal frame.
    crop_bbox = self.preprocessed_props["bbox_used_for_cropping"]
    lbs_internal = [round(i) for i in transform_coordinates_noresampling([ib[0] for ib in interaction_bbox], crop_bbox)]
    ubs_internal = [round(i) for i in transform_coordinates_noresampling([ib[1] for ib in interaction_bbox], crop_bbox)]

    image_t = torch.from_numpy(image)
    patch_fn(image_t, offset=lbs_internal)

    self._prepare_new_interaction_intensity()

    interaction_shape = self.interactions.shape[1:]
    # Map possibly out-of-bounds transformed bbox to overlapping source/target slices so we only
    # materialize and write the intersecting subregion.
    clipped_lb = [max(0, lb) for lb in lbs_internal]
    clipped_ub = [min(ub, s) for ub, s in zip(ubs_internal, interaction_shape)]
    if any(ub <= lb for lb, ub in zip(clipped_lb, clipped_ub)):
        # No overlap with the region kept by preprocessing cropping (e.g. the prompt was drawn in cropped-away
        # background). Writing would use inverted target slices, so skip it.
        if self.verbose:
            print("Image interaction does not overlap the cropped image region. Nothing written.")
    else:
        src_lb = [cl - lb for cl, lb in zip(clipped_lb, lbs_internal)]
        src_ub = [sl + (cu - cl) for sl, cu, cl in zip(src_lb, clipped_ub, clipped_lb)]
        int_slicer = tuple(slice(a, b) for a, b in zip(clipped_lb, clipped_ub))
        src_slicer = tuple(slice(a, b) for a, b in zip(src_lb, src_ub))
        new_values = image_t[src_slicer].cpu().numpy()
        if self.current_interaction_intensity != 1:
            new_values = new_values * self.current_interaction_intensity
        new_values = new_values.astype(np.float16)
        self._interactions_inplace_maximum(interaction_channel, int_slicer, new_values)
        del new_values
    del image_t
    empty_cache(self.device)

    if run_prediction:
        self._predict()
790
+
791
def _add_mask_interaction(
    self,
    interaction_name: str,
    mask_image: np.ndarray,
    include_interaction: bool,
    run_prediction: bool,
    override_capability_checks: bool,
    interaction_bbox: Optional[List[List[int]]],
) -> None:
    """
    Shared implementation of the mask-style prompts (scribble, lasso).

    Checks the capability called ``interaction_name``, resolves its (positive, negative) channel pair,
    picks one according to ``include_interaction`` and hands the mask to ``_add_image_interaction`` with
    ``_generic_add_patch_from_image`` as the patch-queuing callback.
    """
    if self.verbose:
        print(f"Add new {interaction_name} of shape {mask_image.shape} and bbox {interaction_bbox}")
    self._check_capability_or_warn(interaction_name, override_capability_checks)
    positive, negative = self._resolve_channel_pair(interaction_name, override_capability_checks)
    if include_interaction:
        chosen_channel = positive
    else:
        chosen_channel = negative
    queue_patch = self._generic_add_patch_from_image
    self._add_image_interaction(mask_image, chosen_channel, run_prediction, interaction_bbox, queue_patch)
811
+
812
def add_scribble_interaction(
    self,
    scribble_image: np.ndarray,
    include_interaction: bool,
    run_prediction: bool = True,
    override_capability_checks: bool = False,
    interaction_bbox: Optional[List[List[int]]] = None,
):
    """
    Add a scribble prompt.

    Args:
        scribble_image: scribble mask covering ``interaction_bbox`` (the whole image when that is None).
        include_interaction: True for a positive scribble, False for a negative one.
        run_prediction: run a prediction right after the scribble has been added.
        override_capability_checks: forwarded to the capability checks.
        interaction_bbox: optional ``[[lb, ub], ...]`` placement of the mask in the original image frame.
    """
    self._add_mask_interaction(
        interaction_name="scribble",
        mask_image=scribble_image,
        include_interaction=include_interaction,
        run_prediction=run_prediction,
        override_capability_checks=override_capability_checks,
        interaction_bbox=interaction_bbox,
    )
828
+
829
def add_lasso_interaction(
    self,
    lasso_image: np.ndarray,
    include_interaction: bool,
    run_prediction: bool = True,
    override_capability_checks: bool = False,
    interaction_bbox: Optional[List[List[int]]] = None,
):
    """
    Add a lasso prompt.

    Args:
        lasso_image: lasso mask covering ``interaction_bbox`` (the whole image when that is None).
        include_interaction: True for a positive lasso, False for a negative one.
        run_prediction: run a prediction right after the lasso has been added.
        override_capability_checks: forwarded to the capability checks.
        interaction_bbox: optional ``[[lb, ub], ...]`` placement of the mask in the original image frame.
    """
    self._add_mask_interaction(
        interaction_name="lasso",
        mask_image=lasso_image,
        include_interaction=include_interaction,
        run_prediction=run_prediction,
        override_capability_checks=override_capability_checks,
        interaction_bbox=interaction_bbox,
    )
840
+
841
def add_initial_seg_interaction(
    self, initial_seg: np.ndarray, run_prediction: bool = False, override_capability_checks: bool = False
):
    """
    Use an existing segmentation as the starting point.

    WARNING THIS WILL RESET INTERACTIONS!

    ``initial_seg`` must have exactly the spatial shape of the input image. It is copied into the target
    buffer (if one is attached), cropped like the image during preprocessing and written into the
    previous-segmentation channel. With ``run_prediction=True`` a prediction patch is queued around its
    foreground and ``_predict(force_full_refine=True)`` refines the whole structure.

    Raises:
        AssertionError: if ``initial_seg.shape`` differs from the input image's spatial shape.
    """
    self._check_capability_or_warn("initial_label", override_capability_checks)
    # Compare complete shape tuples: an element-wise zip would silently accept a different number of
    # dimensions (e.g. (X, Y, Z, 1) or (X, Y)).
    assert tuple(initial_seg.shape) == tuple(
        self.original_image_shape[1:]
    ), f"Given initial seg must match input image shape. Input image was: {self.original_image_shape[1:]}, given: {initial_seg.shape}"

    self._finish_preprocessing_and_initialize_interactions()

    self.reset_interactions()

    if isinstance(self.target_buffer, np.ndarray):
        self.target_buffer[:] = initial_seg

    initial_seg = torch.from_numpy(initial_seg)

    if isinstance(self.target_buffer, torch.Tensor):
        self.target_buffer[:] = initial_seg

    # crop (as in preprocessing)
    initial_seg = crop_and_pad_nd(initial_seg, self.preprocessed_props["bbox_used_for_cropping"])

    # initial seg is written into the previous-segmentation channel
    interaction_channel = self._get_prev_seg_channel()
    self._write_interactions_channel(interaction_channel, initial_seg)

    empty_cache(self.device)
    if run_prediction:
        self._add_patch_for_initial_seg_interaction(initial_seg)
        del initial_seg
        self._predict(force_full_refine=True)
    else:
        del initial_seg
878
+
879
@torch.inference_mode()
def _predict(self, force_full_refine: bool = False):
    """
    Predict for the most recently queued interaction and write the result back.

    force_full_refine if True we run the refinement over the whole current prediction and not just the diff map.
    More effort but sometimes needed (refine initial seg)

    Flow: only the LAST queued center / zoom-out factor is used (capped at 4). The network input is built
    around that center and predicted. If ``self.do_autozoom`` is set and ``_detect_change_at_border`` reports
    a change at the patch border, the zoom-out factor grows by 1.5x (capped at 4) and the prediction is
    repeated. With a final zoom-out factor of 1 the prediction is pasted directly into the
    previous-segmentation channel and the target buffer. Otherwise the coarse prediction is upsampled to the
    scaled patch size, refinement boxes are planned, the coarse result is pasted into the
    previous-segmentation channel and ``_refine_coarse`` refines it patch-wise. The queued centers and
    zoom-out factors are cleared at the end.

    If it feels like we are excessively transferring tensors between CPU and GPU, this is deliberate.
    Our goal is to keep this tool usable even for people with smaller GPUs (8-10GB VRAM). In an ideal world
    everyone would have 24GB+ of VRAM and all tensors would live on GPU all the time.
    The amount of hours spent optimizing this function is substantial. Almost every line was turned and twisted
    multiple times. If something appears odd, it is probably so for a reason. Don't change things all willy nilly
    without first understanding what is going on. And don't make changes without verifying that the run time or
    VRAM consumption is not adversely affected.

    Returns:
        None. Results are written into ``self.interactions`` (previous-segmentation channel) and into the
        target buffer, either directly or through ``_refine_coarse``.
    """
    # NOTE(review): diagnostic output; unlike other debug prints it is not gated by self.verbose.
    print("Current cratio", self.interactions.cratio)

    assert self.pad_mode_data == "constant", "pad modes other than constant are not implemented here"
    assert len(self.new_interaction_centers) == len(self.new_interaction_zoom_out_factors)
    prev_seg_channel = self._get_prev_seg_channel()
    if len(self.new_interaction_centers) == 0:
        print("No patch queued for prediction. Nothing to do.")
        return

    if len(self.new_interaction_centers) > 1:
        print(
            "It seems like more than one interaction was added since the last prediction. This is not "
            "recommended and may cause unexpected behavior or inefficient predictions\n"
            "!!!WE NO LONGER RUN ONE PREDICTION PER CENTER AND ONLY USE THE LAST ADDED INTERACTION AS CENTER!!!"
        )
    prediction_center, zoom_out_factor = self.new_interaction_centers[-1], self.new_interaction_zoom_out_factors[-1]
    # never look at more than 4x the network patch size
    zoom_out_factor = min(4, zoom_out_factor)

    start_predict = time()
    # mixed precision only on CUDA; other devices run without autocast
    with torch.autocast(self.device.type, enabled=True) if self.device.type == "cuda" else dummy_context():
        # make a prediction at zoom_out_factor, remember max_zoom_out_factor
        start_initial_pred = time()
        input_for_predict, scaled_patch_size, scaled_bbox, previous_prediction = self._build_network_input(
            prediction_center, zoom_out_factor
        )
        # add batch dim, take the first (only) sample and turn logits into a label map
        pred = self.network(input_for_predict[None])[0].argmax(0).detach()
        del input_for_predict

        # detect changes at border. If there are, we enter autozoom
        # (previous_prediction is the previous-segmentation channel resampled to the network patch size)
        has_change = self._detect_change_at_border(pred, previous_prediction)
        del previous_prediction
        empty_cache(self.device)

        print(
            f"Took {round(time() - start_initial_pred, 3)} s for initial prediction at zoom out factor {zoom_out_factor}"
        )

        # maybe do zoom out
        zoom_out_growth_factor = 1.5
        start_zoomout = time()
        while has_change and self.do_autozoom:
            print(f"AutoZoom zoom out factor {zoom_out_factor}")
            # we allow a max zoom out of 4
            if zoom_out_factor >= 4:
                break
            else:
                zoom_out_factor *= zoom_out_growth_factor
                zoom_out_factor = min(4, zoom_out_factor)

            input_for_predict, scaled_patch_size, scaled_bbox, previous_prediction_resized = (
                self._build_network_input(prediction_center, zoom_out_factor)
            )
            pred = self.network(input_for_predict[None])[0].argmax(0).detach()
            del input_for_predict
            empty_cache(self.device)

            # compared against the previous segmentation again, not against the prior zoom level's prediction
            has_change = self._detect_change_at_border(pred, previous_prediction_resized)

        if zoom_out_factor > 1:
            print(f"Zoom out took {round(time() - start_zoomout, 3)} s, max zoom out factor {zoom_out_factor}")
        else:
            print("No zoom out necessary")

        if zoom_out_factor == 1:
            # simply place pred in the prev_seg channel and target buffer
            paste_tensor(self.interactions, pred.half(), scaled_bbox, channel_idx=prev_seg_channel)
            self._paste_prediction_to_target_buffer(pred, scaled_bbox)
            print("No refinement necessary")
        else:
            # do refinement

            # pred lives at network patch size; bring it to the scaled patch size (trilinear + 0.5 threshold)
            if not all([i == j for i, j in zip(pred.shape, scaled_patch_size)]):
                pred = (
                    interpolate(pred[None, None].to(torch.float32), scaled_patch_size, mode="trilinear")[0, 0]
                    >= 0.5
                ).to(torch.uint8)

            # planning compares pred with the OLD previous segmentation, so it must run before the paste below
            refinement_bboxes = self._plan_refinement_bboxes(pred, scaled_bbox, force_full_refine)

            # Place the coarse segmentation into prev_seg before refinement
            paste_tensor(self.interactions, pred, scaled_bbox, channel_idx=prev_seg_channel)

            self._refine_coarse(refinement_bboxes)

        print(f"Done. Total time {round(time() - start_predict, 3)}s")

    # the queue has been consumed (only its last entry was used)
    self.new_interaction_centers = []
    self.new_interaction_zoom_out_factors = []
    empty_cache(self.device)
985
+
986
def _build_network_input(self, prediction_center, zoom_out_factor):
    """
    Assemble the network input for a patch centred on ``prediction_center`` at ``zoom_out_factor``.

    The requested region (``scaled_bbox``) spans ``round(patch_size * zoom_out_factor)`` voxels per axis
    around ``prediction_center`` in preprocessed-image coordinates and may extend past the image. The valid
    part is cropped on the CPU and the remainder is padded after the transfer to ``self.device``
    (``crop_to_valid`` is assumed to return the in-bounds crop plus per-axis pad amounts; TODO confirm).
    If the scaled size differs from the network patch size, everything is resampled to the patch size:
    image trilinearly, interactions with area interpolation (optionally dilated first), and the previous
    segmentation with nearest neighbour.

    Returns:
        input_for_predict: image concatenated with the normalized interaction channels, on ``self.device``,
            at the network patch size.
        scaled_patch_size: per-axis size of the requested region in preprocessed-image voxels.
        scaled_bbox: ``[[lb, ub], ...]`` of the requested region (may exceed the image bounds).
        previous_prediction: previous-segmentation channel at network patch size, spatial dims only.
    """
    scaled_patch_size = [round(i * zoom_out_factor) for i in self.configuration_manager.patch_size]
    # centred on prediction_center; odd sizes put the extra voxel on the upper side
    scaled_bbox = [[c - p // 2, c + p // 2 + p % 2] for c, p in zip(prediction_center, scaled_patch_size)]
    prev_seg_channel = self._get_prev_seg_channel()

    # cropping happens on CPU, padding happens on GPU (later)
    crop_img, pad_image = crop_to_valid(self.preprocessed_image, scaled_bbox)
    interactions_tensor, pad_interaction = crop_to_valid(self.interactions, scaled_bbox)
    # For blosc2, crop_to_valid returns a numpy array; convert to torch (still on CPU).
    if not isinstance(interactions_tensor, torch.Tensor):
        interactions_tensor = torch.from_numpy(np.asarray(interactions_tensor))

    # keep the channel dim for now (shape [1, *spatial]); it is dropped in both branches below
    previous_prediction = interactions_tensor[prev_seg_channel : prev_seg_channel + 1]

    # resize input_for_predict (which may be larger than patch size) to patch size
    # this implementation may not seem straightforward but it does save VRAM which is crucial here
    if not all([i == j for i, j in zip(self.configuration_manager.patch_size, scaled_patch_size)]):
        patch_size = self.configuration_manager.patch_size
        # pooling kernel grows with the zoom factor; presumably so that thin prompts (points, scribbles)
        # survive the area downsampling instead of being averaged away
        max_pool_ks = round_to_nearest_odd(zoom_out_factor * 2 - 1)
        dilation_channels = set(self._get_dilation_channels_for_resample()) if max_pool_ks > 1 else set()
        needs_pad_interaction = any(x for pair in pad_interaction for x in pair)

        previous_prediction = previous_prediction.to(self.device, non_blocking=True)
        if needs_pad_interaction:
            previous_prediction = pad_cropped(previous_prediction, pad_interaction)
        # nearest keeps labels discrete; [0, 0] drops the added batch dim and the channel dim
        previous_prediction = interpolate(previous_prediction[None], patch_size, mode="nearest")[0, 0]

        # Process interaction channels one at a time to avoid materialising the full
        # [num_ch, scaled_patch_size³] tensor on GPU. Peak VRAM ≈ one channel at scaled size.
        num_interaction_ch = interactions_tensor.shape[0]
        interactions_out = torch.empty(
            [num_interaction_ch, *patch_size], dtype=interactions_tensor.dtype, device=self.device
        )
        for i in range(num_interaction_ch):
            ch = interactions_tensor[i : i + 1].to(self.device, non_blocking=True)
            if needs_pad_interaction:
                ch = pad_cropped(ch, pad_interaction)
            if i in dilation_channels:
                ch = iterative_3x3_same_padding_pool3d(ch[None], max_pool_ks)[0]
            interactions_out[i : i + 1] = interpolate(ch[None], patch_size, mode="area")[0]
            del ch
        del interactions_tensor
        interactions_tensor = interactions_out

        # Keep image and interaction tensors in identical spatial frames before concatenation.
        # Interactions use area downsampling (with selective dilation beforehand), image uses trilinear.
        crop_img = crop_img.to(self.device, non_blocking=True)
        if any(x for pair in pad_image for x in pair):
            crop_img = pad_cropped(crop_img, pad_image)
        crop_img = interpolate(crop_img[None], patch_size, mode="trilinear")[0]

        empty_cache(self.device)
    else:
        # zoom_out_factor == 1: transfer both tensors to GPU, then pad if needed
        crop_img = crop_img.to(self.device, non_blocking=True)
        interactions_tensor = interactions_tensor.to(self.device, non_blocking=True)
        previous_prediction = previous_prediction.to(self.device, non_blocking=True)
        if any(x for pair in pad_image for x in pair):
            crop_img = pad_cropped(crop_img, pad_image)
        if any(x for pair in pad_interaction for x in pair):
            interactions_tensor = pad_cropped(interactions_tensor, pad_interaction)
            previous_prediction = pad_cropped(previous_prediction, pad_interaction)
        # drop the channel dim so both branches return spatial-only previous_prediction
        previous_prediction = previous_prediction[0]

    # in-place normalization of the interaction channels (see _normalize_interaction_channels_for_network_)
    self._normalize_interaction_channels_for_network_(interactions_tensor)
    input_for_predict = torch.cat((crop_img, interactions_tensor))
    del crop_img, interactions_tensor
    empty_cache(self.device)
    return input_for_predict, scaled_patch_size, scaled_bbox, previous_prediction
1055
+
1056
def _refine_coarse(self, bboxes_ordered: List[List[List[int]]]):
    """
    Refine the coarse segmentation box by box and report the elapsed time.

    ``bboxes_ordered`` holds one ``[[lb, ub], ...]`` box per refinement patch in interaction coordinates.
    The work itself is delegated to ``_refine_coarse_with_local_cache`` together with the
    previous-segmentation channel index.
    """
    num_boxes = len(bboxes_ordered)
    tic = time()
    seg_channel = self._get_prev_seg_channel()

    if self.verbose:
        print(f"Using {num_boxes} bounding boxes for refinement")

    self._refine_coarse_with_local_cache(bboxes_ordered, seg_channel)
    elapsed = time() - tic
    print(f"Took {round(elapsed, 3)} s for refining the segmentation with {num_boxes} bounding boxes")
1068
+
1069
def _refine_coarse_with_local_cache(self, bboxes_ordered: List[List[List[int]]], prev_seg_channel: int) -> None:
    """
    Run the network on every refinement box using one shared local cache, then write the result back once.

    ``_build_refinement_local_cache`` returns a bounding box together with image and interaction tensors
    that presumably cover exactly that box. Each refinement box is translated into the cache's local frame,
    the network predicts it, and the prediction is pasted into the cached previous-segmentation channel.
    Later boxes read the cached interactions again, so they see earlier refinements. Afterwards the cached
    channel is pasted into ``self.interactions`` and the target buffer.
    """
    cache_bbox, cache_image, cache_interactions = self._build_refinement_local_cache(bboxes_ordered)
    # Neither value changes inside the loop, so evaluate them once.
    cache_on_device = cache_image.device == self.device
    async_copy = self.device.type == "cuda"

    for box in bboxes_ordered:
        local_box = [[lo - origin[0], hi - origin[0]] for (lo, hi), origin in zip(box, cache_bbox)]
        region = tuple(slice(lo, hi) for lo, hi in local_box)
        img_crop = cache_image[region][None]
        inter_crop = cache_interactions[(slice(None),) + region]
        if cache_on_device:
            net_input = torch.cat((img_crop, inter_crop), dim=0)
        else:
            # Transfer inline so the device copies are temporaries of torch.cat.
            net_input = torch.cat(
                (
                    img_crop.to(self.device, non_blocking=async_copy),
                    inter_crop.to(self.device, non_blocking=async_copy),
                ),
                dim=0,
            )

        seg = self.network(net_input[None])[0].argmax(0).detach()
        paste_tensor(
            cache_interactions,
            seg.to(cache_interactions.device, dtype=cache_interactions.dtype),
            local_box,
            channel_idx=prev_seg_channel,
        )
        del img_crop, inter_crop, net_input
        del seg

    refined = cache_interactions[prev_seg_channel]
    paste_tensor(self.interactions, refined, cache_bbox, channel_idx=prev_seg_channel)
    self._paste_prediction_to_target_buffer(refined, cache_bbox)

    del cache_image, cache_interactions, refined
    empty_cache(self.device)
1106
+
1107
def _detect_change_at_border(
    self,
    pred: torch.Tensor,
    prev_pred: torch.Tensor,
    abs_pxl_change_threshold=1500,
    rel_pxl_change_threshold=0.2,
    min_pxl_change_threshold=100,
):
    """
    Tell whether ``pred`` differs from ``prev_pred`` on any face of the patch.

    For each spatial axis the first and the last slice of both tensors are compared. A face counts as
    changed when
      * more than ``abs_pxl_change_threshold`` voxels differ, or
      * more than ``min_pxl_change_threshold`` voxels differ AND the summed values of the two faces
        (voxel counts for binary maps) differ by a relative factor above ``rel_pxl_change_threshold``.

    Returns True at the first changed face, otherwise False. ``_predict`` uses this to decide on autozoom.
    """
    for axis in range(pred.ndim):
        for face in (0, pred.shape[axis] - 1):
            old_face = prev_pred.index_select(axis, torch.tensor(face, device=prev_pred.device))
            new_face = pred.index_select(axis, torch.tensor(face, device=self.device)).to(prev_pred.device)
            n_old = torch.sum(old_face)
            n_new = torch.sum(new_face)
            n_diff = torch.sum(old_face != new_face)
            # 1e-5 guards the division when one of the faces is empty
            relative = max(n_old, n_new) / max(min(n_old, n_new), 1e-5) - 1
            if n_diff > abs_pxl_change_threshold:
                if self.verbose:
                    print(
                        f"continue zooming because change at borders of {n_diff} > {abs_pxl_change_threshold}"
                    )
                return True
            if n_diff > min_pxl_change_threshold and relative > rel_pxl_change_threshold:
                if self.verbose:
                    print(
                        f"continue zooming because relative change of {relative} > {rel_pxl_change_threshold} and n_pixels {n_diff} > {min_pxl_change_threshold}"
                    )
                return True
            del old_face, new_face, n_old, n_new, n_diff
    return False
1142
+
1143
def _compute_local_diff_map(
    self, pred: torch.Tensor, scaled_bbox: List[List[int]], planning_bbox: List[List[int]]
) -> torch.Tensor:
    """
    Build a uint8 change map restricted to ``planning_bbox``.

    ``pred`` is the coarse prediction resized to match ``scaled_bbox``. ``planning_bbox`` is given in global
    interaction coordinates and may be larger than ``scaled_bbox`` when force_full_refine widens the
    planning ROI. Only the part of ``scaled_bbox`` that lies inside the image is compared against the
    previous-segmentation channel. The result has the clipped planning box's shape; an empty (0, 0, 0)
    tensor is returned when either box lies entirely outside the image.
    """
    channel = self._get_prev_seg_channel()
    full_shape = tuple(int(extent) for extent in self.interactions.shape[1:])
    visible = self._clip_bbox_to_shape(scaled_bbox, full_shape)
    planning_bbox = self._clip_bbox_to_shape(planning_bbox, full_shape)
    if visible is None or planning_bbox is None:
        return torch.zeros((0, 0, 0), device=self.device, dtype=torch.uint8)

    diff_local = torch.zeros(self._bbox_size(planning_bbox), device=self.device, dtype=torch.float16)

    def to_slices(box):
        return tuple(slice(lo, hi) for lo, hi in box)

    # Visible region expressed in pred's own frame, clamped to pred's extent.
    pred_box = []
    for axis, (vis, scaled) in enumerate(zip(visible, scaled_bbox)):
        lo = max(0, vis[0] - scaled[0])
        hi = min(vis[1] - scaled[0], int(pred.shape[axis]))
        pred_box.append([lo, hi])
    # Visible region expressed in the planning box's frame.
    local_box = [[vis[0] - plan[0], vis[1] - plan[0]] for vis, plan in zip(visible, planning_bbox)]
    local_region = to_slices(local_box)

    previous = torch.from_numpy(np.asarray(self.interactions[(channel, *to_slices(visible))])).to(self.device)
    diff_local[local_region] = (pred[to_slices(pred_box)] != previous).to(diff_local.dtype)
    del previous

    # Min-pool followed by max-pool (a morphological opening) removes small speckles from the change map.
    # This reduces the number of refinement patches without materializing a full-image planning tensor.
    for use_min in (True, False):
        diff_local[local_region] = iterative_3x3_same_padding_pool3d(
            diff_local[local_region][None, None], kernel_size=5, use_min_pool=use_min
        )[0, 0]

    return diff_local.to(torch.uint8)
1192
+
1193
def _mark_prev_seg_in_local_diff(self, diff_local: torch.Tensor, planning_bbox: List[List[int]]) -> None:
    """
    Set ``diff_local`` to 1 wherever the previous segmentation inside ``planning_bbox`` exceeds 0.5.

    ``diff_local`` is modified in place and must have the shape of ``planning_bbox``. ``_plan_refinement_bboxes``
    calls this for forced full refinement, so the whole existing structure gets refinement boxes.
    """
    region = (self._get_prev_seg_channel(),) + tuple(slice(lo, hi) for lo, hi in planning_bbox)
    existing_seg = torch.from_numpy(np.asarray(self.interactions[region])).to(self.device)
    foreground = existing_seg > 0.5
    diff_local[foreground] = 1
    del existing_seg
1199
+
1200
+ def _plan_refinement_bboxes(
1201
+ self, pred: torch.Tensor, scaled_bbox: List[List[int]], force_full_refine: bool
1202
+ ) -> List[List[List[int]]]:
1203
+ spatial_shape = tuple(int(i) for i in self.interactions.shape[1:])
1204
+ planning_bbox = self._clip_bbox_to_shape(scaled_bbox, spatial_shape)
1205
+
1206
+ if force_full_refine:
1207
+ print("Forcing full refinement of entire structure")
1208
+ prev_seg_bbox = self._compute_prev_seg_positive_bbox()
1209
+ planning_bbox = self._union_bboxes(planning_bbox, prev_seg_bbox)
1210
+
1211
+ if planning_bbox is None:
1212
+ center = self.new_interaction_centers[-1]
1213
+ return [
1214
+ [[ci - pi // 2, ci - pi // 2 + pi] for ci, pi in zip(center, self.configuration_manager.patch_size)]
1215
+ ]
1216
+
1217
+ diff_local = self._compute_local_diff_map(pred, scaled_bbox, planning_bbox)
1218
+ if force_full_refine:
1219
+ self._mark_prev_seg_in_local_diff(diff_local, planning_bbox)
1220
+
1221
+ local_bboxes = generate_bounding_boxes(
1222
+ diff_local, self.configuration_manager.patch_size, stride="auto", margin=(24, 24, 24), max_depth=3
1223
+ )
1224
+ del diff_local
1225
+ empty_cache(self.device)
1226
+
1227
+ # If no bounding boxes are returned we basically have almost no changes. Still we should at least perform
1228
+ # refinement in the bounding box where the interaction was as the user evidently wanted something here.
1229
+ if len(local_bboxes) == 0:
1230
+ center = self.new_interaction_centers[-1]
1231
+ return [
1232
+ [[ci - pi // 2, ci - pi // 2 + pi] for ci, pi in zip(center, self.configuration_manager.patch_size)]
1233
+ ]
1234
+
1235
+ return self._offset_bboxes(local_bboxes, planning_bbox)
1236
+
1237
def _add_patch_for_point_interaction(self, coordinates):
    """
    Queue a prediction patch centred on a point interaction.

    Points never request a zoom-out, so the queued zoom-out factor is always 1. ``coordinates`` are
    expected in the cropped internal frame.
    """
    self.new_interaction_zoom_out_factors.append(1)
    self.new_interaction_centers.append(coordinates)
    # Report the entry that was just queued: its center and its zoom-out factor ("scale").
    print(
        f"Added new point interaction: center {self.new_interaction_centers[-1]}, scale {self.new_interaction_zoom_out_factors[-1]}"
    )
1243
+
1244
def _add_patch_for_bbox_interaction(self, bbox):
    """
    Queue a prediction patch for a bounding-box interaction.

    The patch is centred on the box. The zoom-out factor is chosen so that the box plus a third of the
    patch size as context (per axis) fits into the patch, and it is never below 1. ``bbox`` is
    ``[[lb, ub], ...]`` in the cropped internal frame.
    """
    bbox_center = [round((i[0] + i[1]) / 2) for i in bbox]
    bbox_size = [i[1] - i[0] for i in bbox]
    # we want to see some context, so the crop we see for the initial prediction should be patch_size / 3 larger
    requested_size = [i + j // 3 for i, j in zip(bbox_size, self.configuration_manager.patch_size)]
    self.new_interaction_zoom_out_factors.append(
        max(1, max([i / j for i, j in zip(requested_size, self.configuration_manager.patch_size)]))
    )
    self.new_interaction_centers.append(bbox_center)
    # Report the entry that was just queued: its center and its zoom-out factor ("scale").
    print(
        f"Added new bbox interaction: center {self.new_interaction_centers[-1]}, scale {self.new_interaction_zoom_out_factors[-1]}"
    )
1256
+
1257
def _add_patch_for_initial_seg_interaction(self, initial_seg):
    """Queue a prediction patch around the foreground of an (already cropped) initial segmentation."""
    queue_patch = self._generic_add_patch_from_image
    return queue_patch(initial_seg)
1259
+
1260
def _generic_add_patch_from_image(self, image: torch.Tensor, offset: Optional[List[int]] = None):
    """
    Queue a prediction patch around the non-zero region of a dense prompt image.

    Args:
        image: prompt tensor; only the location of its non-zero voxels matters.
        offset: position of ``image``'s origin in interaction coordinates (all zeros when None).

    The queued center is the center of the non-zero bounding box shifted by ``offset``. The zoom-out
    factor is chosen so that this box plus a third of the patch size as context fits into the patch, and
    it is never below 1. An all-zero image queues nothing and only prints a notice.
    """
    if not torch.any(image):
        print("Received empty image prompt. Cannot add patches for prediction")
        return
    if offset is None:
        offset = [0] * image.ndim
    nonzero_indices = torch.nonzero(image, as_tuple=False)
    mn = torch.min(nonzero_indices, dim=0)[0]
    mx = torch.max(nonzero_indices, dim=0)[0]
    # half-open [lower, upper) box of the non-zero voxels in interaction coordinates
    roi = [[i.item() + off, x.item() + off + 1] for i, x, off in zip(mn, mx, offset)]
    roi_center = [round((i[0] + i[1]) / 2) for i in roi]
    roi_size = [i[1] - i[0] for i in roi]
    requested_size = [i + j // 3 for i, j in zip(roi_size, self.configuration_manager.patch_size)]
    self.new_interaction_zoom_out_factors.append(
        max(1, max([i / j for i, j in zip(requested_size, self.configuration_manager.patch_size)]))
    )
    self.new_interaction_centers.append(roi_center)
    # Report only the entry that was just queued (not the whole list of pending centers).
    print(
        f"Added new image interaction: scale {self.new_interaction_zoom_out_factors[-1]}, center {self.new_interaction_centers[-1]}"
    )
1279
+ )
1280
+
1281
def initialize_from_trained_model_folder(
    self,
    model_training_output_dir: str,
    use_fold: Optional[Union[int, str]] = None,
    checkpoint_name: str = "checkpoint_final.pth",
):
    """
    This is used when making predictions with a trained model.

    Loads capability/runtime defaults, ``dataset.json`` and ``plans.json`` from
    ``model_training_output_dir`` and the checkpoint ``<fold folder>/<checkpoint_name>``. It then rebuilds
    the network with the trainer class named in the checkpoint (falling back to
    ``nnInteractiveTrainer_stub`` when that class cannot be found), loads the weights, stores the plans,
    configuration and label managers on ``self``, and wraps the network with ``torch.compile`` when
    ``self.use_torch_compile`` is set.

    Args:
        model_training_output_dir: training output folder containing dataset.json, plans.json and fold_* dirs.
        use_fold: fold to load (an int, a numeric string, or ``"all"``). When None, the folder must contain
            exactly one ``fold_*`` subfolder, which is then used.
        checkpoint_name: checkpoint file name inside the fold folder.

    Raises:
        AssertionError: if ``use_fold`` is None and there is not exactly one ``fold_*`` subfolder.
    """
    point_interaction_use_etd = True
    (
        capability_content,
        point_interaction_radius,
        self.preferred_scribble_thickness,
        self.interaction_decay,
        self.pad_mode_data,
    ) = self._load_capability_and_runtime_defaults(model_training_output_dir)

    self.point_interaction = PointInteraction_stub(point_interaction_radius, point_interaction_use_etd)
    self._apply_capability(capability_content)

    dataset_json = load_json(join(model_training_output_dir, "dataset.json"))
    plans = load_json(join(model_training_output_dir, "plans.json"))
    plans_manager = PlansManager(plans)

    if use_fold is not None:
        use_fold = int(use_fold) if use_fold != "all" else use_fold
        fold_folder = f"fold_{use_fold}"
    else:
        fldrs = subdirs(model_training_output_dir, prefix="fold_", join=False)
        assert len(fldrs) == 1, f"Attempted to infer fold but there is != 1 fold_ folders: {fldrs}"
        fold_folder = fldrs[0]

    # NOTE(review): weights_only=False performs full (pickle-based) deserialization, which can execute
    # arbitrary code. Only load checkpoints from trusted sources.
    checkpoint = torch.load(
        join(model_training_output_dir, fold_folder, checkpoint_name), map_location=self.device, weights_only=False
    )
    if self._is_official_checkpoint(plans, checkpoint):
        print(
            "License reminder: The official nnInteractive checkpoint is licensed under "
            "Creative Commons Attribution Non Commercial Share Alike 4.0 (CC BY-NC-SA 4.0). "
            "See the license note in readme.md (# License)."
        )
    trainer_name = checkpoint["trainer_name"]
    configuration_name = checkpoint["init_args"]["configuration"]

    parameters = checkpoint["network_weights"]

    configuration_manager = plans_manager.get_configuration(configuration_name)
    # restore network
    # image input channels as in nnU-Net plus the interaction channels expected by this session
    num_input_channels = (
        determine_num_input_channels(plans_manager, configuration_manager, dataset_json)
        + self.num_interaction_channels
    )
    trainer_class = recursive_find_python_class(
        join(nnInteractive.__path__[0], "trainer"), trainer_name, "nnInteractive.trainer"
    )
    if trainer_class is None:
        print(
            f"Unable to locate trainer class {trainer_name} in nnInteractive.trainer. "
            f"Please place it there (in any .py file)!"
        )
        print(
            "Attempting to use default nnInteractiveTrainer_stub. If you encounter errors, this is where you need to look!"
        )
        trainer_class = nnInteractiveTrainer_stub

    network = trainer_class.build_network_architecture(
        plans_manager,
        configuration_manager,
        num_input_channels,
        plans_manager.get_label_manager(dataset_json).num_segmentation_heads,
        enable_deep_supervision=False,
    ).to(self.device)
    network.load_state_dict(parameters)

    self.plans_manager = plans_manager
    self.configuration_manager = configuration_manager
    self.network = network
    self.dataset_json = dataset_json
    self.trainer_name = trainer_name
    self.label_manager = plans_manager.get_label_manager(dataset_json)
    # avoid wrapping an already compiled module a second time
    if self.use_torch_compile and not isinstance(self.network, OptimizedModule):
        print("Using torch.compile")
        self.network = torch.compile(self.network)
1365
+
1366
def manual_initialization(
    self,
    network: nn.Module,
    plans_manager: PlansManager,
    configuration_manager: ConfigurationManager,
    dataset_json: dict,
    trainer_name: str,
):
    """
    Initialize the session from already-constructed objects instead of loading a model folder.

    Args:
        network: Network to use; it is moved to ``self.device``.
        plans_manager: Plans manager matching the network; also used to build ``self.label_manager``.
        configuration_manager: Configuration the network belongs to (stored as-is).
        dataset_json: Dataset description (stored and passed to ``get_label_manager``).
        trainer_name: Trainer name (stored as-is).

    If ``self.use_torch_compile`` is set and the network is not already compiled, it is wrapped
    with ``torch.compile``. If it is not set and an already compiled network (``OptimizedModule``)
    is passed, the compiled wrapper is dropped and its original module is used.
    """
    self.plans_manager = plans_manager
    self.configuration_manager = configuration_manager
    # Moving here is sufficient: compiling wraps the moved module, and unwrapping returns it,
    # so no second `.to(self.device)` is needed afterwards.
    self.network = network.to(self.device)
    self.dataset_json = dataset_json
    self.trainer_name = trainer_name
    self.label_manager = plans_manager.get_label_manager(dataset_json)

    if self.use_torch_compile and not isinstance(self.network, OptimizedModule):
        print("Using torch.compile")
        self.network = torch.compile(self.network)
    elif not self.use_torch_compile and isinstance(self.network, OptimizedModule):
        self.network = self.network._orig_mod
1392
+
1393
def __del__(self):
    """
    Finish pending preprocessing work and shut down the executor when the object is collected.

    ``executor.shutdown()`` runs in ``finally`` so the executor is released even when finishing
    the pending work raises. The ``executor`` attribute is looked up defensively because
    ``__del__`` can also run on a partially initialized instance (e.g. if ``__init__`` raised).
    Any exception still propagates; Python ignores exceptions raised in ``__del__`` and only
    prints a warning.
    """
    try:
        self._finish_preprocessing_and_initialize_interactions()
    finally:
        executor = getattr(self, "executor", None)
        if executor is not None:
            executor.shutdown()
1396
+
1397
+
1398
if __name__ == "__main__":
    # Scratch smoke check: select the first slice of a zero CPU volume along dim 0.
    # The selected slice is intentionally discarded; nothing is printed or returned.
    volume = torch.zeros((160, 160, 160), device="cpu")
    first_slice_index = torch.tensor([0])
    volume.index_select(0, first_slice_index)