nnInteractive 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nnInteractive/__init__.py +3 -0
- nnInteractive/inference/__init__.py +0 -0
- nnInteractive/inference/cvpr2025_challenge_baseline/__init__.py +0 -0
- nnInteractive/inference/cvpr2025_challenge_baseline/predict.py +173 -0
- nnInteractive/inference/inference_session.py +1400 -0
- nnInteractive/interaction/__init__.py +0 -0
- nnInteractive/interaction/point.py +166 -0
- nnInteractive/supervoxel/setup.py +4 -0
- nnInteractive/supervoxel/src/metadata.py +118 -0
- nnInteractive/supervoxel/src/reader.py +175 -0
- nnInteractive/supervoxel/src/run.py +136 -0
- nnInteractive/supervoxel/src/sam2/__init__.py +2 -0
- nnInteractive/supervoxel/src/sam2/sam2/__init__.py +11 -0
- nnInteractive/supervoxel/src/sam2/sam2/automatic_mask_generator.py +434 -0
- nnInteractive/supervoxel/src/sam2/sam2/benchmark.py +86 -0
- nnInteractive/supervoxel/src/sam2/sam2/build_sam.py +172 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/__init__.py +5 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/__init__.py +5 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/hieradet.py +305 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/image_encoder.py +132 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/utils.py +89 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/memory_attention.py +167 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/memory_encoder.py +179 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/position_encoding.py +217 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/__init__.py +5 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/mask_decoder.py +274 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/prompt_encoder.py +194 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/transformer.py +293 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/sam2_base.py +879 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/sam2_utils.py +315 -0
- nnInteractive/supervoxel/src/sam2/sam2/sam2_image_predictor.py +433 -0
- nnInteractive/supervoxel/src/sam2/sam2/sam2_video_predictor.py +1171 -0
- nnInteractive/supervoxel/src/sam2/sam2/sam2_video_predictor_legacy.py +1125 -0
- nnInteractive/supervoxel/src/sam2/sam2/utils/__init__.py +5 -0
- nnInteractive/supervoxel/src/sam2/sam2/utils/amg.py +332 -0
- nnInteractive/supervoxel/src/sam2/sam2/utils/misc.py +488 -0
- nnInteractive/supervoxel/src/sam2/sam2/utils/transforms.py +108 -0
- nnInteractive/supervoxel/src/sam2/setup.py +174 -0
- nnInteractive/supervoxel/src/sam2/training/__init__.py +5 -0
- nnInteractive/supervoxel/src/sam2/training/dataset/__init__.py +5 -0
- nnInteractive/supervoxel/src/sam2/training/dataset/sam2_datasets.py +176 -0
- nnInteractive/supervoxel/src/sam2/training/dataset/transforms.py +481 -0
- nnInteractive/supervoxel/src/sam2/training/dataset/utils.py +102 -0
- nnInteractive/supervoxel/src/sam2/training/dataset/vos_dataset.py +154 -0
- nnInteractive/supervoxel/src/sam2/training/dataset/vos_raw_dataset.py +290 -0
- nnInteractive/supervoxel/src/sam2/training/dataset/vos_sampler.py +103 -0
- nnInteractive/supervoxel/src/sam2/training/dataset/vos_segment_loader.py +289 -0
- nnInteractive/supervoxel/src/sam2/training/loss_fns.py +290 -0
- nnInteractive/supervoxel/src/sam2/training/model/__init__.py +5 -0
- nnInteractive/supervoxel/src/sam2/training/model/sam2.py +515 -0
- nnInteractive/supervoxel/src/sam2/training/optimizer.py +462 -0
- nnInteractive/supervoxel/src/sam2/training/scripts/sav_frame_extraction_submitit.py +157 -0
- nnInteractive/supervoxel/src/sam2/training/train.py +232 -0
- nnInteractive/supervoxel/src/sam2/training/trainer.py +1051 -0
- nnInteractive/supervoxel/src/sam2/training/utils/__init__.py +5 -0
- nnInteractive/supervoxel/src/sam2/training/utils/checkpoint_utils.py +328 -0
- nnInteractive/supervoxel/src/sam2/training/utils/data_utils.py +166 -0
- nnInteractive/supervoxel/src/sam2/training/utils/distributed.py +560 -0
- nnInteractive/supervoxel/src/sam2/training/utils/logger.py +236 -0
- nnInteractive/supervoxel/src/sam2/training/utils/train_utils.py +275 -0
- nnInteractive/supervoxel/src/supervoxel.py +198 -0
- nnInteractive/trainer/__init__.py +0 -0
- nnInteractive/trainer/nnInteractiveTrainer.py +24 -0
- nnInteractive/utils/__init__.py +0 -0
- nnInteractive/utils/bboxes.py +217 -0
- nnInteractive/utils/checkpoint_cleansing.py +9 -0
- nnInteractive/utils/crop.py +268 -0
- nnInteractive/utils/erosion_dilation.py +48 -0
- nnInteractive/utils/inference_helpers.py +45 -0
- nnInteractive/utils/os_shennanigans.py +16 -0
- nnInteractive/utils/rounding.py +13 -0
- nninteractive-2.0.0.dist-info/METADATA +511 -0
- nninteractive-2.0.0.dist-info/RECORD +76 -0
- nninteractive-2.0.0.dist-info/WHEEL +5 -0
- nninteractive-2.0.0.dist-info/licenses/LICENSE +201 -0
- nninteractive-2.0.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,1400 @@
|
|
|
1
|
+
from concurrent.futures import ThreadPoolExecutor
|
|
2
|
+
import os
|
|
3
|
+
from os import cpu_count
|
|
4
|
+
from time import time
|
|
5
|
+
from typing import Union, List, Tuple, Optional
|
|
6
|
+
import warnings
|
|
7
|
+
|
|
8
|
+
import blosc2
|
|
9
|
+
|
|
10
|
+
import numpy as np
|
|
11
|
+
import torch
|
|
12
|
+
from acvl_utils.cropping_and_padding.bounding_boxes import bounding_box_to_slice, crop_and_pad_nd
|
|
13
|
+
from batchgenerators.utilities.file_and_folder_operations import load_json, join, subdirs, isfile
|
|
14
|
+
from nnunetv2.utilities.find_class_by_name import recursive_find_python_class
|
|
15
|
+
from nnunetv2.utilities.helpers import dummy_context, empty_cache
|
|
16
|
+
from nnunetv2.utilities.label_handling.label_handling import determine_num_input_channels
|
|
17
|
+
from nnunetv2.utilities.plans_handling.plans_handler import PlansManager, ConfigurationManager
|
|
18
|
+
from torch import nn
|
|
19
|
+
from torch._dynamo import OptimizedModule
|
|
20
|
+
from torch.nn.functional import interpolate
|
|
21
|
+
|
|
22
|
+
import nnInteractive
|
|
23
|
+
from nnInteractive.interaction.point import PointInteraction_stub
|
|
24
|
+
from nnInteractive.trainer.nnInteractiveTrainer import nnInteractiveTrainer_stub
|
|
25
|
+
from nnInteractive.utils.bboxes import generate_bounding_boxes
|
|
26
|
+
from nnInteractive.utils.crop import crop_and_pad_into_buffer, paste_tensor, pad_cropped, crop_to_valid
|
|
27
|
+
from nnInteractive.utils.erosion_dilation import iterative_3x3_same_padding_pool3d
|
|
28
|
+
from nnInteractive.utils.inference_helpers import (
|
|
29
|
+
infer_num_interaction_channels_from_mapping,
|
|
30
|
+
parse_channel_pair,
|
|
31
|
+
transform_coordinates_noresampling,
|
|
32
|
+
version_to_tuple,
|
|
33
|
+
)
|
|
34
|
+
from nnInteractive.utils.rounding import round_to_nearest_odd
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class nnInteractiveInferenceSession:
|
|
38
|
+
# Version of the installed nnInteractive package; compared against a checkpoint's
# "inference_class_min_version" when capability metadata is validated.
INFERENCE_SESSION_VERSION = nnInteractive.__version__
# Free GPU memory (bytes; 4 GiB) that must still remain after the refinement cache would be
# allocated on the GPU; otherwise the cache is placed on the CPU.
REFINEMENT_CACHE_GPU_HEADROOM_BYTES = 4 * 1024**3
# Interactions implemented by this inference session.
SUPPORTED_INTERACTION_KEYS = ("scribble", "lasso", "points", "bbox2d", "bbox3d")
|
|
42
|
+
|
|
43
|
+
def __init__(
|
|
44
|
+
self,
|
|
45
|
+
device: torch.device = torch.device("cuda"),
|
|
46
|
+
use_torch_compile: bool = False,
|
|
47
|
+
verbose: bool = False,
|
|
48
|
+
torch_n_threads: int = 8,
|
|
49
|
+
do_autozoom: bool = True,
|
|
50
|
+
):
|
|
51
|
+
"""
|
|
52
|
+
Only intended to work with nnInteractiveTrainerV2 and its derivatives
|
|
53
|
+
"""
|
|
54
|
+
print("session initialized")
|
|
55
|
+
|
|
56
|
+
# set as part of initialization
|
|
57
|
+
assert use_torch_compile is False, (
|
|
58
|
+
"torch.compile is not supported. The blosc2-backed interaction tensor "
|
|
59
|
+
"requires numpy↔torch round-trips that break compile tracing."
|
|
60
|
+
)
|
|
61
|
+
self.network = None
|
|
62
|
+
self.label_manager = None
|
|
63
|
+
self.dataset_json = None
|
|
64
|
+
self.trainer_name = None
|
|
65
|
+
self.configuration_manager = None
|
|
66
|
+
self.plans_manager = None
|
|
67
|
+
self._interactions_shape = None
|
|
68
|
+
self.device = device
|
|
69
|
+
self.use_torch_compile = use_torch_compile
|
|
70
|
+
self.interaction_decay = None
|
|
71
|
+
self.current_interaction_intensity: float = 1.0
|
|
72
|
+
self._fp16_max_value = float(torch.finfo(torch.float16).max)
|
|
73
|
+
# Keep renormalized interaction magnitudes around 1/10 of fp16 max to preserve headroom.
|
|
74
|
+
self._interaction_renorm_target = self._fp16_max_value / 10
|
|
75
|
+
self.num_interaction_channels: int = None
|
|
76
|
+
self.supported_interactions: dict = {}
|
|
77
|
+
self.channel_mapping: dict = {}
|
|
78
|
+
self.supports_initial_label: bool = True
|
|
79
|
+
self.supports_zero_shot_label_refinement: bool = True
|
|
80
|
+
|
|
81
|
+
# image specific
|
|
82
|
+
self.interactions = None # blosc2.NDArray once initialized
|
|
83
|
+
self.preprocessed_image: torch.Tensor = None
|
|
84
|
+
self.preprocessed_props = None
|
|
85
|
+
self.target_buffer: Union[np.ndarray, torch.Tensor] = None
|
|
86
|
+
|
|
87
|
+
# this will be set when loading the model (initialize_from_trained_model_folder)
|
|
88
|
+
self.pad_mode_data = self.preferred_scribble_thickness = self.point_interaction = None
|
|
89
|
+
|
|
90
|
+
self.verbose = verbose
|
|
91
|
+
|
|
92
|
+
self.do_autozoom: bool = do_autozoom
|
|
93
|
+
|
|
94
|
+
torch.set_num_threads(min(torch_n_threads, cpu_count()))
|
|
95
|
+
self.torch_n_threads = torch_n_threads
|
|
96
|
+
|
|
97
|
+
self.original_image_shape = None
|
|
98
|
+
|
|
99
|
+
self.new_interaction_zoom_out_factors: List[float] = []
|
|
100
|
+
self.new_interaction_centers = []
|
|
101
|
+
# Create a thread pool executor for background tasks.
|
|
102
|
+
# this only takes care of preprocessing and interaction memory initialization so there is no need to give it
|
|
103
|
+
# more than 2 workers
|
|
104
|
+
self.executor = ThreadPoolExecutor(max_workers=2)
|
|
105
|
+
self.preprocess_future = None
|
|
106
|
+
self.interactions_future = None
|
|
107
|
+
|
|
108
|
+
@staticmethod
|
|
109
|
+
def _is_official_checkpoint(plans: dict, checkpoint: dict) -> bool:
|
|
110
|
+
return (
|
|
111
|
+
plans.get("dataset_name") == "Dataset225_nnInteractiveV2"
|
|
112
|
+
and checkpoint.get("init_args", {}).get("configuration") == "3d_fullres_ps192_bs24"
|
|
113
|
+
)
|
|
114
|
+
|
|
115
|
+
def _legacy_default_capability(self) -> dict:
|
|
116
|
+
return {
|
|
117
|
+
"supported_interactions": {
|
|
118
|
+
"scribble": True,
|
|
119
|
+
"lasso": True,
|
|
120
|
+
"points": True,
|
|
121
|
+
"bbox2d": True,
|
|
122
|
+
"bbox3d": False,
|
|
123
|
+
},
|
|
124
|
+
"supports_initial_label": True,
|
|
125
|
+
"supports_zero_shot_label_refinement": True,
|
|
126
|
+
"interaction_channels": 6,
|
|
127
|
+
"channel_mapping": {
|
|
128
|
+
"prev_seg": 0,
|
|
129
|
+
"bbox2d": (1, 2),
|
|
130
|
+
"bbox3d": (1, 2),
|
|
131
|
+
"lasso": (1, 2),
|
|
132
|
+
"points": (3, 4),
|
|
133
|
+
"scribble": (5, 6),
|
|
134
|
+
},
|
|
135
|
+
}
|
|
136
|
+
|
|
137
|
+
def _to_positive_channel_index(self, idx: int) -> int:
|
|
138
|
+
return idx if idx >= 0 else self.num_interaction_channels + idx
|
|
139
|
+
|
|
140
|
+
def _resolve_channel_pair(self, channel_name: str, override_capability_checks: bool) -> Tuple[int, int]:
|
|
141
|
+
if channel_name in self.channel_mapping:
|
|
142
|
+
return parse_channel_pair(channel_name, self.channel_mapping[channel_name])
|
|
143
|
+
if override_capability_checks:
|
|
144
|
+
warnings.warn(
|
|
145
|
+
f"Interaction '{channel_name}' was forced but no channel mapping exists in capability metadata.",
|
|
146
|
+
RuntimeWarning,
|
|
147
|
+
)
|
|
148
|
+
raise ValueError(f"Interaction '{channel_name}' cannot be executed because no channel mapping was found.")
|
|
149
|
+
|
|
150
|
+
def _is_interaction_supported(self, interaction_name: str) -> bool:
|
|
151
|
+
if interaction_name in self.SUPPORTED_INTERACTION_KEYS:
|
|
152
|
+
return bool(self.supported_interactions.get(interaction_name, False))
|
|
153
|
+
if interaction_name == "initial_label":
|
|
154
|
+
return bool(self.supports_initial_label)
|
|
155
|
+
return False
|
|
156
|
+
|
|
157
|
+
def _get_prev_seg_channel(self) -> int:
|
|
158
|
+
return int(self.channel_mapping["prev_seg"])
|
|
159
|
+
|
|
160
|
+
@staticmethod
|
|
161
|
+
def _clip_bbox_to_shape(bbox: List[List[int]], spatial_shape: Tuple[int, ...]) -> Optional[List[List[int]]]:
|
|
162
|
+
clipped = [[max(0, int(lb)), min(int(ub), int(s))] for (lb, ub), s in zip(bbox, spatial_shape)]
|
|
163
|
+
if any(ub <= lb for lb, ub in clipped):
|
|
164
|
+
return None
|
|
165
|
+
return clipped
|
|
166
|
+
|
|
167
|
+
@staticmethod
|
|
168
|
+
def _bbox_size(bbox: List[List[int]]) -> List[int]:
|
|
169
|
+
return [int(ub - lb) for lb, ub in bbox]
|
|
170
|
+
|
|
171
|
+
@staticmethod
|
|
172
|
+
def _union_bboxes(*bboxes: Optional[List[List[int]]]) -> Optional[List[List[int]]]:
|
|
173
|
+
valid_bboxes = [bbox for bbox in bboxes if bbox is not None]
|
|
174
|
+
if len(valid_bboxes) == 0:
|
|
175
|
+
return None
|
|
176
|
+
return [
|
|
177
|
+
[min(bbox[dim][0] for bbox in valid_bboxes), max(bbox[dim][1] for bbox in valid_bboxes)]
|
|
178
|
+
for dim in range(len(valid_bboxes[0]))
|
|
179
|
+
]
|
|
180
|
+
|
|
181
|
+
@staticmethod
|
|
182
|
+
def _offset_bboxes(local_bboxes: List[List[List[int]]], offset_bbox: List[List[int]]) -> List[List[List[int]]]:
|
|
183
|
+
return [
|
|
184
|
+
[[lb + offset_bbox[dim][0], ub + offset_bbox[dim][0]] for dim, (lb, ub) in enumerate(bbox)]
|
|
185
|
+
for bbox in local_bboxes
|
|
186
|
+
]
|
|
187
|
+
|
|
188
|
+
def _interaction_bbox_to_target_bbox(self, bbox: List[List[int]]) -> List[List[int]]:
|
|
189
|
+
return [
|
|
190
|
+
[i[0] + bbc[0], i[1] + bbc[0]] for i, bbc in zip(bbox, self.preprocessed_props["bbox_used_for_cropping"])
|
|
191
|
+
]
|
|
192
|
+
|
|
193
|
+
def _compute_prev_seg_positive_bbox(self) -> Optional[List[List[int]]]:
|
|
194
|
+
prev_seg_ch = self._get_prev_seg_channel()
|
|
195
|
+
spatial_shape = tuple(int(i) for i in self.interactions.shape[1:])
|
|
196
|
+
|
|
197
|
+
occupancy_x = np.zeros(spatial_shape[0], dtype=bool)
|
|
198
|
+
occupancy_y = np.zeros(spatial_shape[1], dtype=bool)
|
|
199
|
+
occupancy_z = np.zeros(spatial_shape[2], dtype=bool)
|
|
200
|
+
chunk_depth = 64
|
|
201
|
+
for d0 in range(0, spatial_shape[0], chunk_depth):
|
|
202
|
+
d1 = min(spatial_shape[0], d0 + chunk_depth)
|
|
203
|
+
slab = np.asarray(self.interactions[(prev_seg_ch, slice(d0, d1), slice(None), slice(None))]) > 0.5
|
|
204
|
+
if not slab.any():
|
|
205
|
+
continue
|
|
206
|
+
occupancy_x[d0:d1] |= np.any(slab, axis=(1, 2))
|
|
207
|
+
occupancy_y |= np.any(slab, axis=(0, 2))
|
|
208
|
+
occupancy_z |= np.any(slab, axis=(0, 1))
|
|
209
|
+
|
|
210
|
+
occupancies = (occupancy_x, occupancy_y, occupancy_z)
|
|
211
|
+
bbox = []
|
|
212
|
+
for occ in occupancies:
|
|
213
|
+
indices = np.flatnonzero(occ)
|
|
214
|
+
if len(indices) == 0:
|
|
215
|
+
return None
|
|
216
|
+
bbox.append([int(indices[0]), int(indices[-1]) + 1])
|
|
217
|
+
return bbox
|
|
218
|
+
|
|
219
|
+
def _get_dilation_channels_for_resample(self) -> List[int]:
|
|
220
|
+
dilation_channels = set()
|
|
221
|
+
# During zoom-out, point/scribble signals can disappear when area interpolation averages tiny sparse
|
|
222
|
+
# structures away. We therefore dilate only these "thin prompt" channels before resampling.
|
|
223
|
+
for key in ("points", "scribble"):
|
|
224
|
+
if not self.supported_interactions.get(key, False):
|
|
225
|
+
continue
|
|
226
|
+
if key not in self.channel_mapping:
|
|
227
|
+
continue
|
|
228
|
+
pos_ch, neg_ch = parse_channel_pair(key, self.channel_mapping[key])
|
|
229
|
+
dilation_channels.add(pos_ch)
|
|
230
|
+
dilation_channels.add(neg_ch)
|
|
231
|
+
# Use a sorted list so execution is deterministic and easier to reason about in debugging/logging.
|
|
232
|
+
return sorted(dilation_channels)
|
|
233
|
+
|
|
234
|
+
def _check_capability_or_warn(self, interaction_name: str, override_capability_checks: bool):
|
|
235
|
+
if self._is_interaction_supported(interaction_name):
|
|
236
|
+
return
|
|
237
|
+
msg = f"Interaction '{interaction_name}' is not supported by this checkpoint capability metadata."
|
|
238
|
+
if override_capability_checks:
|
|
239
|
+
warnings.warn(f"{msg} Proceeding because override_capability_checks=True.", RuntimeWarning)
|
|
240
|
+
return
|
|
241
|
+
raise ValueError(msg)
|
|
242
|
+
|
|
243
|
+
def _get_non_prev_seg_channels(self) -> List[int]:
|
|
244
|
+
if self.interactions is None:
|
|
245
|
+
return []
|
|
246
|
+
prev_seg_channel = self._get_prev_seg_channel()
|
|
247
|
+
channels = list(range(self.interactions.shape[0]))
|
|
248
|
+
if prev_seg_channel in channels:
|
|
249
|
+
channels.remove(prev_seg_channel)
|
|
250
|
+
return channels
|
|
251
|
+
|
|
252
|
+
def _renormalize_interactions_if_needed(self):
|
|
253
|
+
if self.interactions is None:
|
|
254
|
+
return
|
|
255
|
+
if self.current_interaction_intensity <= self._fp16_max_value:
|
|
256
|
+
return
|
|
257
|
+
channels_to_scale = self._get_non_prev_seg_channels()
|
|
258
|
+
if len(channels_to_scale) == 0:
|
|
259
|
+
self.current_interaction_intensity = min(
|
|
260
|
+
self.current_interaction_intensity, self._interaction_renorm_target
|
|
261
|
+
)
|
|
262
|
+
return
|
|
263
|
+
scale = self._interaction_renorm_target / self.current_interaction_intensity
|
|
264
|
+
for ch in channels_to_scale:
|
|
265
|
+
self.interactions[ch] *= scale
|
|
266
|
+
self.current_interaction_intensity = self._interaction_renorm_target
|
|
267
|
+
|
|
268
|
+
def _interactions_inplace_maximum(self, channel_idx: int, int_slicer, new_values) -> None:
|
|
269
|
+
"""In-place element-wise maximum for a subregion of a channel."""
|
|
270
|
+
if isinstance(new_values, torch.Tensor):
|
|
271
|
+
new_values = new_values.cpu().numpy().astype(np.float16)
|
|
272
|
+
full_slicer = (channel_idx, *int_slicer)
|
|
273
|
+
current_sub = np.asarray(self.interactions[full_slicer])
|
|
274
|
+
np.maximum(current_sub, new_values, out=current_sub)
|
|
275
|
+
self.interactions[full_slicer] = current_sub
|
|
276
|
+
|
|
277
|
+
def _write_interactions_channel(self, channel_idx: int, value) -> None:
|
|
278
|
+
"""Write a full channel. Handles torch→numpy for blosc2."""
|
|
279
|
+
if isinstance(value, torch.Tensor):
|
|
280
|
+
value = value.cpu().numpy().astype(np.float16)
|
|
281
|
+
self.interactions[channel_idx] = value
|
|
282
|
+
|
|
283
|
+
def _paste_prediction_to_target_buffer(self, prediction: torch.Tensor, bbox: List[List[int]]) -> None:
    """
    Paste ``prediction`` into ``self.target_buffer`` at ``bbox``.

    ``bbox`` is first shifted into target-buffer coordinates. The prediction is moved to the
    buffer's device when the buffer is a torch.Tensor, and to CPU otherwise (e.g. numpy buffer).
    """
    destination_bbox = self._interaction_bbox_to_target_bbox(bbox)
    buffer = self.target_buffer
    destination_device = buffer.device if isinstance(buffer, torch.Tensor) else "cpu"
    paste_tensor(buffer, prediction.to(destination_device), destination_bbox)
|
|
290
|
+
|
|
291
|
+
def _estimate_refinement_cache_nbytes(self, cache_bbox: List[List[int]]) -> int:
|
|
292
|
+
cache_voxels = int(np.prod(self._bbox_size(cache_bbox), dtype=np.int64))
|
|
293
|
+
image_nbytes = cache_voxels * torch.empty((), dtype=self.preprocessed_image.dtype).element_size()
|
|
294
|
+
interactions_nbytes = (
|
|
295
|
+
cache_voxels * self.num_interaction_channels * torch.empty((), dtype=torch.float16).element_size()
|
|
296
|
+
)
|
|
297
|
+
return int(image_nbytes + interactions_nbytes)
|
|
298
|
+
|
|
299
|
+
def _select_refinement_cache_device(self, cache_bbox: List[List[int]]) -> torch.device:
|
|
300
|
+
if self.device.type != "cuda":
|
|
301
|
+
return torch.device("cpu")
|
|
302
|
+
|
|
303
|
+
cache_nbytes = self._estimate_refinement_cache_nbytes(cache_bbox)
|
|
304
|
+
try:
|
|
305
|
+
free_mem, _ = torch.cuda.mem_get_info(self.device)
|
|
306
|
+
if free_mem - cache_nbytes >= self.REFINEMENT_CACHE_GPU_HEADROOM_BYTES:
|
|
307
|
+
return self.device
|
|
308
|
+
except Exception:
|
|
309
|
+
pass
|
|
310
|
+
|
|
311
|
+
return torch.device("cpu")
|
|
312
|
+
|
|
313
|
+
def _build_refinement_local_cache(self, bboxes_ordered: List[List[List[int]]]):
    """
    Copy the region covering all ``bboxes_ordered`` into dense image/interaction tensors.

    The covering region is the union bbox. The tensors are placed on the device chosen by
    ``_select_refinement_cache_device``. CPU tensors are pinned when inference runs on CUDA,
    presumably to speed up later host-to-device copies. Interaction channels are normalized by
    the current interaction intensity after copying.

    :returns: (cache_bbox, cache_image, cache_interactions)
    """
    union_bbox = self._union_bboxes(*bboxes_ordered)
    target_device = self._select_refinement_cache_device(union_bbox)
    spatial_size = self._bbox_size(union_bbox)

    tensor_kwargs = {"device": target_device}
    if target_device.type == "cpu" and self.device.type == "cuda":
        tensor_kwargs["pin_memory"] = True

    cache_image = torch.zeros(spatial_size, dtype=self.preprocessed_image.dtype, **tensor_kwargs)
    cache_interactions = torch.zeros(
        (self.num_interaction_channels, *spatial_size), dtype=torch.float16, **tensor_kwargs
    )

    crop_and_pad_into_buffer(cache_image, union_bbox, self.preprocessed_image[0])
    crop_and_pad_into_buffer(cache_interactions, union_bbox, self.interactions)
    self._normalize_interaction_channels_for_network_(cache_interactions)
    return union_bbox, cache_image, cache_interactions
|
|
332
|
+
|
|
333
|
+
def _prepare_new_interaction_intensity(self):
|
|
334
|
+
if self.interaction_decay is None:
|
|
335
|
+
return
|
|
336
|
+
if not (0 < self.interaction_decay <= 1):
|
|
337
|
+
raise ValueError(f"interaction_decay must be in (0, 1], got {self.interaction_decay}.")
|
|
338
|
+
if self.interaction_decay < 1:
|
|
339
|
+
self.current_interaction_intensity *= 1 / self.interaction_decay
|
|
340
|
+
self._renormalize_interactions_if_needed()
|
|
341
|
+
|
|
342
|
+
def _normalize_interaction_channels_for_network_(self, interaction_tensor: torch.Tensor):
|
|
343
|
+
if interaction_tensor is None or self.current_interaction_intensity == 0:
|
|
344
|
+
return
|
|
345
|
+
if self.current_interaction_intensity == 1:
|
|
346
|
+
return
|
|
347
|
+
prev_seg_channel = self._get_prev_seg_channel()
|
|
348
|
+
for ch in range(interaction_tensor.shape[0]):
|
|
349
|
+
if ch != prev_seg_channel:
|
|
350
|
+
interaction_tensor[ch] /= self.current_interaction_intensity
|
|
351
|
+
|
|
352
|
+
def _load_capability_and_runtime_defaults(self, model_training_output_dir: str):
    """
    Read inference metadata from a trained-model folder and resolve runtime settings.

    Capability metadata (``inference_info.json``) is preferred. If it is missing, the legacy
    ``inference_session_class.json`` is used. A legacy file whose content is a plain string
    only changes interaction_decay to 0.9. Settings absent from the metadata default to
    point_radius 4, preferred_scribble_thickness [2, 2, 2], interaction_decay 0.98 and
    pad_mode_image "constant". A scalar scribble thickness is expanded to three axes.

    :returns: (capability_content, point_interaction_radius, preferred_scribble_thickness,
        interaction_decay, pad_mode_data). capability_content is {} unless
        inference_info.json was read.
    :raises RuntimeError: if inference_info.json is not a JSON object or fails
        ``_validate_capability_version``.
    :raises FileNotFoundError: if neither metadata file exists.
    """
    capability_file = join(model_training_output_dir, "inference_info.json")
    legacy_file = join(model_training_output_dir, "inference_session_class.json")

    # Metadata key -> fallback value. Keys are read from the metadata in this order.
    resolved = {
        "point_radius": 4,
        "preferred_scribble_thickness": [2, 2, 2],
        "interaction_decay": 0.98,
        "pad_mode_image": "constant",
    }
    capability_content = {}
    overrides = {}

    if isfile(capability_file):
        capability_content = load_json(capability_file)
        if not isinstance(capability_content, dict):
            raise RuntimeError(f"Invalid capability metadata in {capability_file}. Expected a JSON object.")
        self._validate_capability_version(capability_content)
        overrides = capability_content
    elif isfile(legacy_file):
        legacy_content = load_json(legacy_file)
        if isinstance(legacy_content, str):
            resolved["interaction_decay"] = 0.9
        else:
            overrides = legacy_content
    else:
        raise FileNotFoundError(
            f"Neither capability metadata ({capability_file}) nor legacy metadata ({legacy_file}) was found."
        )

    for key in resolved:
        resolved[key] = overrides.get(key, resolved[key])

    thickness = resolved["preferred_scribble_thickness"]
    # Accept scalar thickness in metadata for backward compatibility.
    if not isinstance(thickness, (tuple, list)):
        thickness = [thickness] * 3

    return (
        capability_content,
        resolved["point_radius"],
        thickness,
        resolved["interaction_decay"],
        resolved["pad_mode_image"],
    )
|
|
401
|
+
|
|
402
|
+
def _apply_capability(self, capability: dict):
    """
    Configure interaction support and the interaction channel layout from capability metadata.

    Entries missing from ``capability`` fall back to ``_legacy_default_capability()``.
    Afterwards, the following are set:
      - ``supported_interactions``
      - ``supports_initial_label``
      - ``supports_zero_shot_label_refinement``
      - ``channel_mapping`` (all indices converted to non-negative)
      - ``num_interaction_channels``

    :param capability: capability metadata dict. A non-dict value (e.g. None) and null
        "supported_interactions" / "channel_mapping" entries are treated as empty, so the
        legacy defaults apply.
    :raises ValueError: if unknown interaction names or channel_mapping keys are present.

    A RuntimeWarning is emitted when prev_seg or an enabled interaction maps to a channel
    outside ``[0, num_interaction_channels)``.
    """
    if not isinstance(capability, dict):
        # Only some lookups used to be guarded with isinstance(); the others crashed on a
        # non-dict. Normalize once so every lookup below sees a dict.
        capability = {}

    default_capability = self._legacy_default_capability()
    default_supported = default_capability["supported_interactions"]
    default_mapping = default_capability["channel_mapping"]
    supported_keys = set(self.SUPPORTED_INTERACTION_KEYS)
    mapping_keys = supported_keys | {"prev_seg"}

    raw_supported = capability.get("supported_interactions") or {}
    unknown_supported = set(raw_supported.keys()) - supported_keys
    if unknown_supported:
        raise ValueError(
            f"Capability requests unsupported interactions: {sorted(unknown_supported)}. "
            f"Supported: {sorted(supported_keys)}"
        )
    filtered_supported = {k: bool(v) for k, v in raw_supported.items() if k in supported_keys}
    self.supported_interactions = {**default_supported, **filtered_supported}
    self.supports_initial_label = capability.get("supports_initial_label", True)
    self.supports_zero_shot_label_refinement = capability.get("supports_zero_shot_label_refinement", True)

    raw_mapping = capability.get("channel_mapping") or {}
    unknown_mapping = set(raw_mapping.keys()) - mapping_keys
    if unknown_mapping:
        raise ValueError(
            f"Capability channel_mapping contains unsupported keys: {sorted(unknown_mapping)}. "
            f"Supported mapping keys: {sorted(mapping_keys)}"
        )
    self.channel_mapping = dict(default_mapping)
    for k, v in raw_mapping.items():
        self.channel_mapping[k] = int(v) if k == "prev_seg" else parse_channel_pair(k, v)

    if "interaction_channels" in capability:
        # The +1 presumably accounts for the prev_seg channel. In the legacy defaults,
        # interaction_channels=6 with indices 0..6 used.
        self.num_interaction_channels = int(capability["interaction_channels"]) + 1
    else:
        self.num_interaction_channels = infer_num_interaction_channels_from_mapping(self.channel_mapping)

    # Normalize all channel indices to positive indexing once at load time so downstream code can
    # use direct indexing without handling negative-offset semantics repeatedly.
    self.channel_mapping["prev_seg"] = self._to_positive_channel_index(int(self.channel_mapping["prev_seg"]))
    for k, v in list(self.channel_mapping.items()):
        if k == "prev_seg":
            continue
        pos_ch, neg_ch = parse_channel_pair(k, v)
        self.channel_mapping[k] = (
            self._to_positive_channel_index(pos_ch),
            self._to_positive_channel_index(neg_ch),
        )

    # Flag inconsistent metadata early. This only warns, because default mappings of interactions
    # a checkpoint never uses may legitimately point past a smaller channel count.
    n_channels = self.num_interaction_channels
    channels_to_check = {"prev_seg": (self.channel_mapping["prev_seg"],)}
    for name in self.SUPPORTED_INTERACTION_KEYS:
        if self.supported_interactions.get(name, False) and name in self.channel_mapping:
            channels_to_check[name] = self.channel_mapping[name]
    for name, channels in channels_to_check.items():
        out_of_range = [ch for ch in channels if not 0 <= ch < n_channels]
        if out_of_range:
            warnings.warn(
                f"Channel mapping for '{name}' references channel(s) {out_of_range} outside the "
                f"{n_channels} available interaction channels.",
                RuntimeWarning,
            )
|
|
451
|
+
|
|
452
|
+
def _validate_capability_version(self, capability: dict):
|
|
453
|
+
current_class = self.__class__.__name__
|
|
454
|
+
required_class = capability.get("inference_class", current_class)
|
|
455
|
+
if required_class != current_class:
|
|
456
|
+
raise RuntimeError(
|
|
457
|
+
f"Checkpoint requires inference class '{required_class}', but current class is " f"'{current_class}'."
|
|
458
|
+
)
|
|
459
|
+
|
|
460
|
+
min_version = capability.get("inference_class_min_version")
|
|
461
|
+
if min_version is None:
|
|
462
|
+
return
|
|
463
|
+
if version_to_tuple(min_version) > version_to_tuple(self.INFERENCE_SESSION_VERSION):
|
|
464
|
+
raise RuntimeError(
|
|
465
|
+
f"Checkpoint requires nnInteractiveInferenceSession>={min_version}, but current version is "
|
|
466
|
+
f"{self.INFERENCE_SESSION_VERSION}. Please update nnInteractive."
|
|
467
|
+
)
|
|
468
|
+
|
|
469
|
+
def set_image(self, image: np.ndarray, image_properties: dict = None):
    """
    Start a new session for ``image``. The image must be 4D to satisfy nnU-Net needs: [c, x, y, z].

    The previous session state is reset first. Preprocessing runs on the background executor;
    this call returns immediately after submitting it.
    """
    properties = {} if image_properties is None else image_properties
    self._reset_session()
    assert image.ndim == 4, f"expected a 4d image as input, got {image.ndim}d. Shape {image.shape}"
    if self.verbose:
        print(f"Initialize with raw image shape {image.shape}")

    # All heavy preprocessing happens off the calling thread.
    self.preprocess_future = self.executor.submit(self._background_set_image, image, properties)
    self.original_image_shape = image.shape
|
|
484
|
+
|
|
485
|
+
def _finish_preprocessing_and_initialize_interactions(self):
|
|
486
|
+
"""
|
|
487
|
+
Block until both the image preprocessing and the interactions tensor initialization
|
|
488
|
+
are finished.
|
|
489
|
+
"""
|
|
490
|
+
if self.preprocess_future is not None:
|
|
491
|
+
# Wait for image preprocessing to complete.
|
|
492
|
+
self.preprocess_future.result()
|
|
493
|
+
del self.preprocess_future
|
|
494
|
+
self.preprocess_future = None
|
|
495
|
+
|
|
496
|
+
def set_target_buffer(self, target_buffer: Union[np.ndarray, torch.Tensor]):
|
|
497
|
+
"""
|
|
498
|
+
Must be 3d numpy array or torch.Tensor
|
|
499
|
+
"""
|
|
500
|
+
self.target_buffer = target_buffer
|
|
501
|
+
|
|
502
|
+
def set_do_autozoom(self, do_autozoom: bool):
|
|
503
|
+
self.do_autozoom = do_autozoom
|
|
504
|
+
|
|
505
|
+
def _reset_session(self):
    """
    Drop all image-specific state so that a new image can be set.

    Pending background futures are forgotten (not awaited or cancelled). Image, target buffer,
    interactions, the cached interaction shape and the preprocessing properties are released.
    The interaction intensity is reset to 1.0, and ``empty_cache`` is called for the device.
    """
    self.interactions_future = None
    self.preprocess_future = None

    # Rebinding to None drops the references exactly like the former `del` + reassignment, but
    # does not fail if an attribute was never set.
    self.preprocessed_image = None
    self.target_buffer = None
    self.interactions = None
    # Previously left stale while `interactions` was cleared; keep both consistent.
    self._interactions_shape = None
    self.preprocessed_props = None
    self.current_interaction_intensity = 1.0
    empty_cache(self.device)
    self.original_image_shape = None
|
|
520
|
+
|
|
521
|
+
def _initialize_interactions(self, image_torch: torch.Tensor):
    """
    Allocate the zero-filled, blosc2-compressed float16 interaction store for ``image_torch``.

    The store has shape ``(num_interaction_channels, *image_torch.shape[1:])``. Each chunk/block
    covers a single channel, with chunk edges <= 64 and block edges <= 32 voxels per spatial
    axis. Compression is LZ4 at level 5. The shape is also recorded in
    ``self._interactions_shape``.
    """
    shape = (self.num_interaction_channels, *image_torch.shape[1:])
    if self.verbose:
        print("Initialize interactions with blosc2 in-memory compression")
    # os.cpu_count() may return None when it cannot be determined; fall back to 1 so min() does
    # not raise a TypeError.
    compression_threads = min(self.torch_n_threads, os.cpu_count() or 1)
    self.interactions = blosc2.zeros(
        shape,
        dtype=np.float16,
        chunks=(1, *[min(64, s) for s in shape[1:]]),
        blocks=(1, *[min(32, s) for s in shape[1:]]),
        cparams={"codec": blosc2.Codec.LZ4, "clevel": 5, "nthreads": compression_threads},
        dparams={"nthreads": 4},
    )
    self._interactions_shape = shape
|
|
534
|
+
|
|
535
|
+
@torch.inference_mode()
def _background_set_image(self, image: np.ndarray, image_properties: dict):
    """
    Preprocess `image` and initialize the interaction array for it.

    Steps:
    1. Crop to the nonzero region.
    2. Keep only channel 0.
    3. Normalize to zero mean / unit std.

    Results:
    - `self.preprocessed_image`: float32 tensor on CPU.
    - `self.preprocessed_props["bbox_used_for_cropping"]`: [lower, upper) per spatial axis.

    The interaction array is allocated concurrently on `self.executor` (via `_initialize_interactions`)
    and awaited before returning.

    Args:
        image: 4D array; axis 0 is treated as the channel axis, the remaining three as spatial axes.
        image_properties: not used here.
    """
    # Convert and clone the image tensor.
    image = torch.from_numpy(image.copy())

    # Crop to nonzero region.
    if self.verbose:
        print("Cropping input image to nonzero region")
    # torch.where over the full volume eats RAM / VRAM for breakfast. Avoid!!!
    # Instead, reduce the volume to one 1D profile per spatial axis and only search those profiles.

    def _nonzero_extent(reduce_dims, axis):
        # [first, last + 1) of the nonzero entries of the profile along `axis`.
        profile = image.sum(axis=reduce_dims, dtype=torch.float)
        nonzero_idx = torch.where(profile != 0)[0]
        if nonzero_idx.numel() == 0:
            # No nonzero entry (e.g. an all-zero image): keep the full extent instead of failing on
            # min()/max() of an empty tensor.
            return [0, image.shape[axis]]
        return [nonzero_idx.min().item(), nonzero_idx.max().item() + 1]

    bbox_x = _nonzero_extent((0, 2, 3), 1)
    bbox_y = _nonzero_extent((0, 1, 3), 2)
    bbox_z = _nonzero_extent((0, 1, 2), 3)
    bbox = [[0, 1], bbox_x, bbox_y, bbox_z]
    empty_cache(self.device)

    slicer = bounding_box_to_slice(bbox)
    image = image[slicer].float()
    if self.verbose:
        print(f"Cropped image shape: {image.shape}")

    # As soon as we have the target shape, start initializing the interaction tensor in its own thread.
    self.interactions_future = self.executor.submit(self._initialize_interactions, image)

    # Normalize the cropped image.
    if self.verbose:
        print("Normalizing cropped image")
    image -= image.mean()
    std = image.std()
    # A constant image has zero std (and a single voxel yields NaN); dividing would feed inf/NaN to the
    # network, so only scale when the std is a usable positive number.
    if torch.isfinite(std) and std > 0:
        image /= std

    self.preprocessed_image = image.to("cpu")

    self.preprocessed_props = {"bbox_used_for_cropping": bbox[1:]}

    # The interaction array must exist before interactions can be added.
    self.interactions_future.result()
    del self.interactions_future
    self.interactions_future = None
|
|
586
|
+
|
|
587
|
+
def reset_interactions(self):
    """
    Use this to reset all interactions and start from scratch for the current image. This includes the initial
    segmentation!

    If an interaction array exists, it is recreated as zeros with the remembered shape. The interaction
    intensity is reset to 1.0. A registered target buffer (numpy array or torch tensor) is zero-filled in
    place. If no interaction array exists yet (no image set, or preprocessing still running), nothing is
    allocated.
    """
    if self.interactions is not None:
        # Drop the old array first so its memory can be released before the new one is allocated.
        del self.interactions
        shape = self._interactions_shape
        # Same thread budget as when the array is first created; os.cpu_count() may return None.
        compression_threads = min(self.torch_n_threads, os.cpu_count() or 1)
        self.interactions = blosc2.zeros(
            shape,
            dtype=np.float16,
            chunks=(1, *[min(64, s) for s in shape[1:]]),
            blocks=(1, *[min(32, s) for s in shape[1:]]),
            cparams={"codec": blosc2.Codec.LZ4, "clevel": 5, "nthreads": compression_threads},
            dparams={"nthreads": 4},
        )
    self.current_interaction_intensity = 1.0

    if self.target_buffer is not None:
        # In-place so that callers holding a reference to the buffer see the reset.
        if isinstance(self.target_buffer, np.ndarray):
            self.target_buffer.fill(0)
        elif isinstance(self.target_buffer, torch.Tensor):
            self.target_buffer.zero_()
    empty_cache(self.device)
|
|
610
|
+
|
|
611
|
+
def add_bbox_interaction(
    self,
    bbox_coords,
    include_interaction: bool,
    run_prediction: bool = True,
    override_capability_checks: bool = False,
):
    """
    Add a bounding-box interaction given in original (uncropped) image coordinates.

    Args:
        bbox_coords: one [lower, upper) pair per spatial axis in original-image coordinates.
        include_interaction: True writes into the positive bbox channel, False into the negative one.
        run_prediction: run `_predict` right after placing the box.
        override_capability_checks: forwarded to `_check_capability_or_warn` and `_resolve_channel_pair`.

    A box with extent 1 along at least one axis is used as a 2D box if the model supports "bbox2d".
    Otherwise it is used as a 3D box.

    Raises:
        ValueError: if the box has a zero or negative extent along any axis, or if it is 3D while the
            loaded checkpoint does not support 3D boxes.
    """
    self._finish_preprocessing_and_initialize_interactions()
    # sanity check: inverted boxes (upper < lower) are as meaningless as empty ones
    raw_bbox_size = [i[1] - i[0] for i in bbox_coords]
    if any(i <= 0 for i in raw_bbox_size):
        raise ValueError(f"Given bounding box size is zero or negative in at least one dimension: {bbox_coords}")

    # capability check
    dims_with_size_one = sum(i == 1 for i in raw_bbox_size)
    # if we do not support 3D bboxes we need to reject 3D bboxes!
    if not self._is_interaction_supported("bbox3d") and dims_with_size_one == 0:
        raise ValueError(
            f"The given bounding box {bbox_coords} has size {raw_bbox_size} indicating a 3D "
            f"bounding box. This is not supported by the loaded model checkpoint."
        )
    # a 2D bounding box is in principle a 3D box as well. Since 2D bboxes work better, we prefer to use a given
    # bbox as 2d if possible (sized 1 in at least one dim and bbox2d supported)
    bbox_kind = "bbox2d" if (dims_with_size_one >= 1 and self._is_interaction_supported("bbox2d")) else "bbox3d"
    self._check_capability_or_warn(bbox_kind, override_capability_checks)
    bbox_pos_channel, bbox_neg_channel = self._resolve_channel_pair(bbox_kind, override_capability_checks)

    # Convert user-space coordinates (original image) to the cropped internal space.
    crop_bbox = self.preprocessed_props["bbox_used_for_cropping"]
    lbs_transformed = [round(i) for i in transform_coordinates_noresampling([i[0] for i in bbox_coords], crop_bbox)]
    ubs_transformed = [round(i) for i in transform_coordinates_noresampling([i[1] for i in bbox_coords], crop_bbox)]
    transformed_bbox_coordinates = [[i, j] for i, j in zip(lbs_transformed, ubs_transformed)]

    if self.verbose:
        print(
            f"Adding bounding box coordinates.\n"
            f"Raw: {bbox_coords}\n"
            f"Transformed: {transformed_bbox_coordinates}\n"
            f"Crop Bbox: {self.preprocessed_props['bbox_used_for_cropping']}"
        )

    # Clip bbox to valid interaction volume and guarantee at least one voxel extent per axis.
    image_shape = self.preprocessed_image.shape  # channel axis first, then spatial axes

    for dim, (transformed_start, transformed_end) in enumerate(transformed_bbox_coordinates):
        axis_size = image_shape[dim + 1]  # +1 to skip channel dim

        # Clip to image boundaries
        transformed_start = max(0, transformed_start)
        transformed_end = min(axis_size, transformed_end)

        # Ensure the bounding box does not collapse to an empty range
        if transformed_end <= transformed_start:
            if transformed_start == 0:
                transformed_end = min(1, axis_size)
            else:
                # Use the single voxel just below the start. Clamping the start to axis_size also keeps
                # boxes that lie entirely beyond the upper border inside the image; before, they stayed
                # empty and nothing was placed.
                transformed_start = min(transformed_start, axis_size) - 1
                transformed_end = transformed_start + 1

        transformed_bbox_coordinates[dim] = [transformed_start, transformed_end]

    if self.verbose:
        print(
            f"Bbox coordinates after clip to image boundaries and preventing dim collapse:\n"
            f"Bbox: {transformed_bbox_coordinates}\n"
            f"Internal image shape: {self.preprocessed_image.shape}"
        )

    self._add_patch_for_bbox_interaction(transformed_bbox_coordinates)

    self._prepare_new_interaction_intensity()

    # place bbox
    slicer = tuple(slice(*i) for i in transformed_bbox_coordinates)
    channel = bbox_pos_channel if include_interaction else bbox_neg_channel
    self.interactions[(channel, *slicer)] = self.current_interaction_intensity

    if run_prediction:
        self._predict()
|
|
698
|
+
|
|
699
|
+
def add_point_interaction(
    self,
    coordinates: Tuple[int, ...],
    include_interaction: bool,
    run_prediction: bool = True,
    override_capability_checks: bool = False,
):
    """
    Add a point interaction at `coordinates`, given in original (uncropped) image space.

    The point is mapped into the cropped internal space (rounded to integers), a prediction patch is
    queued for it and it is drawn via `self.point_interaction.place_point` with the current interaction
    intensity. It goes into the positive or negative "points" channel depending on
    `include_interaction`. Optionally runs `_predict` afterwards.
    """
    self._check_capability_or_warn("points", override_capability_checks)
    positive_ch, negative_ch = self._resolve_channel_pair("points", override_capability_checks)
    self._finish_preprocessing_and_initialize_interactions()

    crop_bbox = self.preprocessed_props["bbox_used_for_cropping"]
    internal_point = [round(c) for c in transform_coordinates_noresampling(coordinates, crop_bbox)]

    self._add_patch_for_point_interaction(internal_point)

    self._prepare_new_interaction_intensity()

    self.point_interaction.place_point(
        internal_point,
        self.interactions,
        channel_idx=(positive_ch if include_interaction else negative_ch),
        intensity_scale=self.current_interaction_intensity,
    )
    if run_prediction:
        self._predict()
|
|
728
|
+
|
|
729
|
+
def _add_image_interaction(
    self,
    image: np.ndarray,
    interaction_channel: int,
    run_prediction: bool,
    interaction_bbox: Optional[List[List[int]]],
    patch_fn,
):
    """
    Merge a dense interaction image (e.g. a scribble or lasso mask) into one interaction channel.

    Args:
        image: 3D array whose shape equals the extent of `interaction_bbox`.
        interaction_channel: index of the interaction channel to update.
        run_prediction: run `_predict` after the update.
        interaction_bbox: one [lower, upper) pair per spatial axis in original-image coordinates that
            locates `image`. None means the whole original image.
        patch_fn: called as `patch_fn(image_tensor, offset=internal_lower_bounds)` to queue prediction
            patches.

    The values are scaled by the current interaction intensity (when it is not 1) and cast to float16.
    They are then merged with `_interactions_inplace_maximum`. Only the part overlapping the cropped
    internal volume is written. If there is no overlap at all, no voxel is written.

    Raises:
        AssertionError: if the bbox is not 3D, has a non-positive extent, does not match `image.shape`,
            or exceeds the original image bounds.
    """
    if interaction_bbox is None:
        interaction_bbox = [[0, s] for s in self.original_image_shape[1:]]

    assert len(interaction_bbox) == 3
    bbox_size = [ub - lb for lb, ub in interaction_bbox]
    assert all(s > 0 for s in bbox_size), "each dimension of interaction_bbox must have positive size"
    assert (
        list(image.shape) == bbox_size
    ), f"image shape {list(image.shape)} must match interaction_bbox size {bbox_size}"
    assert all(
        lb >= 0 and ub <= orig_dim for (lb, ub), orig_dim in zip(interaction_bbox, self.original_image_shape[1:])
    ), f"interaction_bbox {interaction_bbox} exceeds original image bounds {list(self.original_image_shape[1:])}"

    self._finish_preprocessing_and_initialize_interactions()

    crop_bbox = self.preprocessed_props["bbox_used_for_cropping"]
    lbs_internal = [
        round(i) for i in transform_coordinates_noresampling([ib[0] for ib in interaction_bbox], crop_bbox)
    ]
    ubs_internal = [
        round(i) for i in transform_coordinates_noresampling([ib[1] for ib in interaction_bbox], crop_bbox)
    ]

    image_t = torch.from_numpy(image)
    patch_fn(image_t, offset=lbs_internal)

    self._prepare_new_interaction_intensity()

    interaction_shape = self.interactions.shape[1:]
    # Map possibly out-of-bounds transformed bbox to overlapping source/target slices so we only
    # materialize and write the intersecting subregion.
    clipped_lb = [max(0, lb) for lb in lbs_internal]
    clipped_ub = [min(ub, s) for ub, s in zip(ubs_internal, interaction_shape)]
    # If the interaction lies entirely outside the cropped volume, some clipped_ub <= clipped_lb.
    # A negative upper bound would then wrap around as a Python slice index and address the wrong
    # region, so skip the write instead.
    if all(ub > lb for lb, ub in zip(clipped_lb, clipped_ub)):
        src_lb = [cl - lb for cl, lb in zip(clipped_lb, lbs_internal)]
        src_ub = [src_lb[d] + (clipped_ub[d] - clipped_lb[d]) for d in range(3)]
        int_slicer = tuple(slice(a, b) for a, b in zip(clipped_lb, clipped_ub))
        src_slicer = tuple(slice(a, b) for a, b in zip(src_lb, src_ub))
        new_values = image_t[src_slicer].cpu().numpy()
        if self.current_interaction_intensity != 1:
            new_values = new_values * self.current_interaction_intensity
        new_values = new_values.astype(np.float16)
        self._interactions_inplace_maximum(interaction_channel, int_slicer, new_values)
        del new_values
    del image_t
    empty_cache(self.device)

    if run_prediction:
        self._predict()
|
|
790
|
+
|
|
791
|
+
def _add_mask_interaction(
    self,
    interaction_name: str,
    mask_image: np.ndarray,
    include_interaction: bool,
    run_prediction: bool,
    override_capability_checks: bool,
    interaction_bbox: Optional[List[List[int]]],
) -> None:
    """
    Shared implementation of mask-like interactions (used for "scribble" and "lasso").

    Checks the capability named `interaction_name` and resolves its positive/negative channel pair. It
    then delegates to `_add_image_interaction`, writing into the positive channel when
    `include_interaction` is True and into the negative one otherwise. Patches are queued with
    `_generic_add_patch_from_image`.
    """
    if self.verbose:
        print(f"Add new {interaction_name} of shape {mask_image.shape} and bbox {interaction_bbox}")
    self._check_capability_or_warn(interaction_name, override_capability_checks)
    positive_channel, negative_channel = self._resolve_channel_pair(interaction_name, override_capability_checks)
    target_channel = positive_channel if include_interaction else negative_channel
    patch_queuer = self._generic_add_patch_from_image
    self._add_image_interaction(mask_image, target_channel, run_prediction, interaction_bbox, patch_queuer)
|
|
811
|
+
|
|
812
|
+
def add_scribble_interaction(
    self,
    scribble_image: np.ndarray,
    include_interaction: bool,
    run_prediction: bool = True,
    override_capability_checks: bool = False,
    interaction_bbox: Optional[List[List[int]]] = None,
):
    """
    Add a scribble interaction.

    `scribble_image` covers `interaction_bbox` (original-image coordinates; None means the whole image).
    Positive or negative depending on `include_interaction`. Thin wrapper around `_add_mask_interaction`
    with the interaction name "scribble".
    """
    forwarded = (scribble_image, include_interaction, run_prediction, override_capability_checks, interaction_bbox)
    self._add_mask_interaction("scribble", *forwarded)
|
|
828
|
+
|
|
829
|
+
def add_lasso_interaction(
    self,
    lasso_image: np.ndarray,
    include_interaction: bool,
    run_prediction: bool = True,
    override_capability_checks: bool = False,
    interaction_bbox: Optional[List[List[int]]] = None,
):
    """
    Add a lasso interaction.

    `lasso_image` covers `interaction_bbox` (original-image coordinates; None means the whole image).
    Positive or negative depending on `include_interaction`. Thin wrapper around `_add_mask_interaction`
    with the interaction name "lasso".
    """
    forwarded = (lasso_image, include_interaction, run_prediction, override_capability_checks, interaction_bbox)
    self._add_mask_interaction("lasso", *forwarded)
|
|
840
|
+
|
|
841
|
+
def add_initial_seg_interaction(
    self, initial_seg: np.ndarray, run_prediction: bool = False, override_capability_checks: bool = False
):
    """
    Use an existing segmentation as starting point.

    WARNING THIS WILL RESET INTERACTIONS!

    Steps:
    1. Check that `initial_seg` has exactly the spatial shape of the original image.
    2. Reset all interactions.
    3. Copy `initial_seg` into the target buffer (numpy or torch), if one is set.
    4. Crop it like the image was cropped during preprocessing.
    5. Write it into the previous-segmentation interaction channel.

    With `run_prediction=True`, a patch is queued for it and `_predict(force_full_refine=True)` is run.

    Raises:
        AssertionError: if the shape of `initial_seg` differs from the original image's spatial shape.
    """
    self._check_capability_or_warn("initial_label", override_capability_checks)
    # Compare complete shapes: zip() would silently ignore missing or extra dimensions, letting e.g. a
    # 2D array through that then broadcasts incorrectly into the target buffer.
    assert tuple(initial_seg.shape) == tuple(
        self.original_image_shape[1:]
    ), f"Given initial seg must match input image shape. Input image was: {self.original_image_shape[1:]}, given: {initial_seg.shape}"

    self._finish_preprocessing_and_initialize_interactions()

    self.reset_interactions()

    if isinstance(self.target_buffer, np.ndarray):
        self.target_buffer[:] = initial_seg

    initial_seg = torch.from_numpy(initial_seg)

    if isinstance(self.target_buffer, torch.Tensor):
        self.target_buffer[:] = initial_seg

    # crop (as in preprocessing)
    initial_seg = crop_and_pad_nd(initial_seg, self.preprocessed_props["bbox_used_for_cropping"])

    # initial seg is written into the previous-segmentation channel of the interactions
    interaction_channel = self._get_prev_seg_channel()
    self._write_interactions_channel(interaction_channel, initial_seg)

    empty_cache(self.device)
    if run_prediction:
        self._add_patch_for_initial_seg_interaction(initial_seg)
        del initial_seg
        self._predict(force_full_refine=True)
    else:
        del initial_seg
|
|
878
|
+
|
|
879
|
+
@torch.inference_mode()
def _predict(self, force_full_refine: bool = False):
    """
    Run the network for the most recently queued interaction and write the result back.

    force_full_refine if True we run the refinement over the whole current prediction and not just the diff map.
    More effort but sometimes needed (refine initial seg)

    Outline:
    1. Only the LAST queued interaction center / zoom-out factor is used (the factor is capped at 4).
    2. Predict at that zoom. If the prediction changes at the patch border and autozoom is enabled, the
       zoom-out factor grows by 1.5x (max 4) and the prediction is repeated.
    3. At zoom 1 the prediction is pasted directly. Otherwise the coarse prediction is resized to the
       zoomed field of view, pasted into the previous-segmentation channel and refined patch by patch.
    4. The queued centers / zoom factors are cleared.

    If it feels like we are excessively transferring tensors between CPU and GPU, this is deliberate.
    Our goal is to keep this tool usable even for people with smaller GPUs (8-10GB VRAM). In an ideal world
    everyone would have 24GB+ of VRAM and all tensors would live on GPU all the time.
    The amount of hours spent optimizing this function is substantial. Almost every line was turned and twisted
    multiple times. If something appears odd, it is probably so for a reason. Don't change things all willy nilly
    without first understanding what is going on. And don't make changes without verifying that the run time or
    VRAM consumption is not adversely affected.

    Returns:
        None. Results go into the previous-segmentation channel of `self.interactions` and into the target
        buffer (directly at zoom 1, via `_refine_coarse` otherwise).
    """
    # NOTE(review): unconditional debug output (not gated by self.verbose).
    print("Current cratio", self.interactions.cratio)

    assert self.pad_mode_data == "constant", "pad modes other than constant are not implemented here"
    assert len(self.new_interaction_centers) == len(self.new_interaction_zoom_out_factors)
    prev_seg_channel = self._get_prev_seg_channel()
    if len(self.new_interaction_centers) == 0:
        print("No patch queued for prediction. Nothing to do.")
        return

    if len(self.new_interaction_centers) > 1:
        print(
            "It seems like more than one interaction was added since the last prediction. This is not "
            "recommended and may cause unexpected behavior or inefficient predictions\n"
            "!!!WE NO LONGER RUN ONE PREDICTION PER CENTER AND ONLY USE THE LAST ADDED INTERACTION AS CENTER!!!"
        )
    prediction_center, zoom_out_factor = self.new_interaction_centers[-1], self.new_interaction_zoom_out_factors[-1]
    zoom_out_factor = min(4, zoom_out_factor)

    start_predict = time()
    # Mixed precision only on CUDA; other devices run in their default precision.
    with torch.autocast(self.device.type, enabled=True) if self.device.type == "cuda" else dummy_context():
        # make a prediction at zoom_out_factor, remember max_zoom_out_factor
        start_initial_pred = time()
        input_for_predict, scaled_patch_size, scaled_bbox, previous_prediction = self._build_network_input(
            prediction_center, zoom_out_factor
        )
        # [None] adds the batch axis; argmax over the class axis gives a label map with the patch's
        # spatial shape.
        pred = self.network(input_for_predict[None])[0].argmax(0).detach()
        del input_for_predict

        # detect changes at border. If there are, we enter autozoom
        has_change = self._detect_change_at_border(pred, previous_prediction)
        del previous_prediction
        empty_cache(self.device)

        print(
            f"Took {round(time() - start_initial_pred, 3)} s for initial prediction at zoom out factor {zoom_out_factor}"
        )

        # maybe do zoom out
        zoom_out_growth_factor = 1.5
        start_zoomout = time()
        while has_change and self.do_autozoom:
            print(f"AutoZoom zoom out factor {zoom_out_factor}")
            # we allow a max zoom out of 4
            if zoom_out_factor >= 4:
                break
            else:
                zoom_out_factor *= zoom_out_growth_factor
                zoom_out_factor = min(4, zoom_out_factor)

            # Each iteration replaces pred / scaled_bbox / scaled_patch_size with the larger field of view.
            input_for_predict, scaled_patch_size, scaled_bbox, previous_prediction_resized = (
                self._build_network_input(prediction_center, zoom_out_factor)
            )
            pred = self.network(input_for_predict[None])[0].argmax(0).detach()
            del input_for_predict
            empty_cache(self.device)

            has_change = self._detect_change_at_border(pred, previous_prediction_resized)

        if zoom_out_factor > 1:
            print(f"Zoom out took {round(time() - start_zoomout, 3)} s, max zoom out factor {zoom_out_factor}")
        else:
            print("No zoom out necessary")

        if zoom_out_factor == 1:
            # simply place pred in the prev_seg channel and target buffer
            # (.half() matches the float16 interaction array)
            paste_tensor(self.interactions, pred.half(), scaled_bbox, channel_idx=prev_seg_channel)
            self._paste_prediction_to_target_buffer(pred, scaled_bbox)
            print("No refinement necessary")
        else:
            # do refinement

            # The network saw a downsampled field of view; bring the coarse prediction back to the zoomed
            # extent (scaled_patch_size) and re-binarize it at 0.5.
            if not all([i == j for i, j in zip(pred.shape, scaled_patch_size)]):
                pred = (
                    interpolate(pred[None, None].to(torch.float32), scaled_patch_size, mode="trilinear")[0, 0]
                    >= 0.5
                ).to(torch.uint8)

            # Planning happens before the coarse paste because it compares pred against the stored prev seg.
            refinement_bboxes = self._plan_refinement_bboxes(pred, scaled_bbox, force_full_refine)

            # Place the coarse segmentation into prev_seg before refinement
            # NOTE(review): unlike the zoom==1 branch, pred is not cast with .half() here; presumably
            # paste_tensor handles the dtype conversion — confirm.
            paste_tensor(self.interactions, pred, scaled_bbox, channel_idx=prev_seg_channel)

            self._refine_coarse(refinement_bboxes)

    print(f"Done. Total time {round(time() - start_predict, 3)}s")

    # All queued interactions have been consumed by this prediction.
    self.new_interaction_centers = []
    self.new_interaction_zoom_out_factors = []
    empty_cache(self.device)
|
|
985
|
+
|
|
986
|
+
def _build_network_input(self, prediction_center, zoom_out_factor):
    """
    Assemble the network input for a field of view centred on `prediction_center`.

    The field of view is `patch_size * zoom_out_factor` (rounded per axis) in internal (cropped)
    coordinates and may extend beyond the volume. The out-of-volume part is padded on the target device.
    If the field of view is larger than the network patch size, the image and interaction channels are
    downsampled to `patch_size`:
    - image: trilinear;
    - interactions: area, with optional dilation of selected channels beforehand;
    - previous segmentation: nearest.

    Returns:
        input_for_predict: image channel(s) concatenated with the normalized interaction channels, on
            self.device, with spatial size equal to the configured patch size.
        scaled_patch_size: zoomed field-of-view size per axis.
        scaled_bbox: [lower, upper) per axis of the field of view in internal coordinates.
        previous_prediction: previous-segmentation channel for the field of view at patch-size
            resolution, spatial-only (channel axis removed), on self.device.
    """
    scaled_patch_size = [round(i * zoom_out_factor) for i in self.configuration_manager.patch_size]
    # p // 2 below and p // 2 + p % 2 above the center, so each axis spans exactly p voxels.
    scaled_bbox = [[c - p // 2, c + p // 2 + p % 2] for c, p in zip(prediction_center, scaled_patch_size)]
    prev_seg_channel = self._get_prev_seg_channel()

    # cropping happens on CPU, padding happens on GPU (later)
    crop_img, pad_image = crop_to_valid(self.preprocessed_image, scaled_bbox)
    interactions_tensor, pad_interaction = crop_to_valid(self.interactions, scaled_bbox)
    # For blosc2, crop_to_valid returns a numpy array; convert to torch (still on CPU).
    if not isinstance(interactions_tensor, torch.Tensor):
        interactions_tensor = torch.from_numpy(np.asarray(interactions_tensor))

    # Keep the channel axis (slice instead of index) so the same pad helpers apply.
    previous_prediction = interactions_tensor[prev_seg_channel : prev_seg_channel + 1]

    # resize input_for_predict (which may be larger than patch size) to patch size
    # this implementation may not seem straightforward but it does save VRAM which is crucial here
    if not all([i == j for i, j in zip(self.configuration_manager.patch_size, scaled_patch_size)]):
        patch_size = self.configuration_manager.patch_size
        # Kernel grows with the zoom so thin interactions survive the downsampling of selected channels.
        max_pool_ks = round_to_nearest_odd(zoom_out_factor * 2 - 1)
        dilation_channels = set(self._get_dilation_channels_for_resample()) if max_pool_ks > 1 else set()
        needs_pad_interaction = any(x for pair in pad_interaction for x in pair)

        previous_prediction = previous_prediction.to(self.device, non_blocking=True)
        if needs_pad_interaction:
            previous_prediction = pad_cropped(previous_prediction, pad_interaction)
        previous_prediction = interpolate(previous_prediction[None], patch_size, mode="nearest")[0, 0]

        # Process interaction channels one at a time to avoid materialising the full
        # [num_ch, scaled_patch_size³] tensor on GPU. Peak VRAM ≈ one channel at scaled size.
        num_interaction_ch = interactions_tensor.shape[0]
        interactions_out = torch.empty(
            [num_interaction_ch, *patch_size], dtype=interactions_tensor.dtype, device=self.device
        )
        for i in range(num_interaction_ch):
            ch = interactions_tensor[i : i + 1].to(self.device, non_blocking=True)
            if needs_pad_interaction:
                ch = pad_cropped(ch, pad_interaction)
            if i in dilation_channels:
                ch = iterative_3x3_same_padding_pool3d(ch[None], max_pool_ks)[0]
            interactions_out[i : i + 1] = interpolate(ch[None], patch_size, mode="area")[0]
            del ch
        del interactions_tensor
        interactions_tensor = interactions_out

        # Keep image and interaction tensors in identical spatial frames before concatenation.
        # Interactions use area downsampling (with selective dilation beforehand), image uses trilinear.
        crop_img = crop_img.to(self.device, non_blocking=True)
        if any(x for pair in pad_image for x in pair):
            crop_img = pad_cropped(crop_img, pad_image)
        crop_img = interpolate(crop_img[None], patch_size, mode="trilinear")[0]

        empty_cache(self.device)
    else:
        # zoom_out_factor == 1: transfer both tensors to GPU, then pad if needed
        crop_img = crop_img.to(self.device, non_blocking=True)
        interactions_tensor = interactions_tensor.to(self.device, non_blocking=True)
        previous_prediction = previous_prediction.to(self.device, non_blocking=True)
        if any(x for pair in pad_image for x in pair):
            crop_img = pad_cropped(crop_img, pad_image)
        if any(x for pair in pad_interaction for x in pair):
            interactions_tensor = pad_cropped(interactions_tensor, pad_interaction)
            previous_prediction = pad_cropped(previous_prediction, pad_interaction)
        # drop the channel axis to match the zoomed branch's spatial-only previous_prediction
        previous_prediction = previous_prediction[0]

    # In-place (trailing underscore) normalization of the interaction channels before concatenation.
    self._normalize_interaction_channels_for_network_(interactions_tensor)
    input_for_predict = torch.cat((crop_img, interactions_tensor))
    del crop_img, interactions_tensor
    empty_cache(self.device)
    return input_for_predict, scaled_patch_size, scaled_bbox, previous_prediction
|
|
1055
|
+
|
|
1056
|
+
def _refine_coarse(self, bboxes_ordered: List[List[List[int]]]):
    """
    Refine the coarse segmentation at full resolution inside the given bounding boxes.

    Delegates to `_refine_coarse_with_local_cache` for the previous-segmentation channel. Afterwards it
    prints the elapsed wall time (unconditionally) and, in verbose mode, the number of boxes beforehand.
    """
    t_begin = time()
    seg_channel = self._get_prev_seg_channel()
    n_boxes = len(bboxes_ordered)

    if self.verbose:
        print(f"Using {n_boxes} bounding boxes for refinement")

    self._refine_coarse_with_local_cache(bboxes_ordered, seg_channel)
    t_finish = time()
    elapsed = round(t_finish - t_begin, 3)
    print(f"Took {elapsed} s for refining the segmentation with {n_boxes} bounding boxes")
|
|
1068
|
+
|
|
1069
|
+
def _refine_coarse_with_local_cache(self, bboxes_ordered: List[List[List[int]]], prev_seg_channel: int) -> None:
    """
    Refine patch by patch inside a local cache, then write the result back once.

    `_build_refinement_local_cache` provides a cache bbox and the image / interaction tensors covering it.
    The boxes are processed in the given order. For each box, the network runs on the matching cache region
    and its argmax is pasted into the cached `prev_seg_channel`. Later boxes read from that same cache, so
    overlapping boxes see earlier refinements (assuming paste_tensor writes in place, which the final
    read-back of the cache relies on). Finally, the cached channel is pasted into `self.interactions` and
    the target buffer.
    """
    cache_bbox, cache_image, cache_interactions = self._build_refinement_local_cache(bboxes_ordered)
    # Both are constant for the whole loop; non_blocking transfers only make sense on CUDA.
    cache_on_device = cache_image.device == self.device
    non_blocking = self.device.type == "cuda"

    for refinement_bbox in bboxes_ordered:
        # Translate the global box into cache-local coordinates.
        local_bbox = [[lo - origin[0], hi - origin[0]] for (lo, hi), origin in zip(refinement_bbox, cache_bbox)]
        region = tuple(slice(lo, hi) for lo, hi in local_bbox)
        img_part = cache_image[region][None]
        inter_part = cache_interactions[(slice(None), *region)]
        if cache_on_device:
            net_input = torch.cat((img_part, inter_part), dim=0)
        else:
            # Transfer inline so the device copies are released right after concatenation.
            net_input = torch.cat(
                (
                    img_part.to(self.device, non_blocking=non_blocking),
                    inter_part.to(self.device, non_blocking=non_blocking),
                ),
                dim=0,
            )

        refined = self.network(net_input[None])[0].argmax(0).detach()
        paste_tensor(
            cache_interactions,
            refined.to(cache_interactions.device, dtype=cache_interactions.dtype),
            local_bbox,
            channel_idx=prev_seg_channel,
        )
        del img_part, inter_part, net_input
        del refined

    refined_prev_seg = cache_interactions[prev_seg_channel]
    paste_tensor(self.interactions, refined_prev_seg, cache_bbox, channel_idx=prev_seg_channel)
    self._paste_prediction_to_target_buffer(refined_prev_seg, cache_bbox)

    del cache_image, cache_interactions, refined_prev_seg
    empty_cache(self.device)
|
|
1106
|
+
|
|
1107
|
+
def _detect_change_at_border(
    self,
    pred: torch.Tensor,
    prev_pred: torch.Tensor,
    abs_pxl_change_threshold=1500,
    rel_pxl_change_threshold=0.2,
    min_pxl_change_threshold=100,
):
    """
    Decide whether the prediction changed at the patch border (the trigger for autozoom).

    The first and last slice along every axis of `pred` is compared with the same slice of `prev_pred`.
    A change is reported as soon as one face has either:
    - more differing voxels than `abs_pxl_change_threshold`, or
    - more than `min_pxl_change_threshold` differing voxels together with a relative change of the
      foreground sum (max / max(min, 1e-5) - 1) above `rel_pxl_change_threshold`.

    Returns:
        bool: True if such a change was found.
    """
    for axis in range(pred.ndim):
        for border_idx in (0, pred.shape[axis] - 1):
            prev_face = prev_pred.index_select(axis, torch.tensor(border_idx, device=prev_pred.device))
            curr_face = pred.index_select(axis, torch.tensor(border_idx, device=self.device)).to(prev_pred.device)
            n_prev = torch.sum(prev_face)
            n_curr = torch.sum(curr_face)
            n_diff = torch.sum(prev_face != curr_face)
            rel_change = max(n_prev, n_curr) / max(min(n_prev, n_curr), 1e-5) - 1

            if n_diff > abs_pxl_change_threshold:
                if self.verbose:
                    print(
                        f"continue zooming because change at borders of {n_diff} > {abs_pxl_change_threshold}"
                    )
                return True

            if n_diff > min_pxl_change_threshold and rel_change > rel_pxl_change_threshold:
                if self.verbose:
                    print(
                        f"continue zooming because relative change of {rel_change} > {rel_pxl_change_threshold} and n_pixels {n_diff} > {min_pxl_change_threshold}"
                    )
                return True
    return False
|
|
1142
|
+
|
|
1143
|
+
def _compute_local_diff_map(
    self, pred: torch.Tensor, scaled_bbox: List[List[int]], planning_bbox: List[List[int]]
) -> torch.Tensor:
    """
    Compute a local diff map inside planning_bbox only.

    pred is expected to be the coarse prediction resized to match scaled_bbox.
    planning_bbox is in global interaction coordinates and may be larger than scaled_bbox when
    force_full_refine expands the refinement planning ROI.

    Returns:
        uint8 tensor on self.device shaped like the (clipped) planning_bbox. Within the part of
        scaled_bbox that lies inside the volume, a voxel is 1 where `pred` differs from the stored
        previous segmentation, after a min-pool followed by a max-pool with kernel size 5. All other
        voxels are 0. If clipping either bbox to the volume yields None (presumably no overlap), an empty
        (0, 0, 0) tensor is returned.
    """
    prev_seg_ch = self._get_prev_seg_channel()
    spatial_shape = tuple(int(i) for i in self.interactions.shape[1:])
    # seen_bbox: the part of the network's field of view that lies inside the volume.
    seen_bbox = self._clip_bbox_to_shape(scaled_bbox, spatial_shape)
    planning_bbox = self._clip_bbox_to_shape(planning_bbox, spatial_shape)
    if seen_bbox is None or planning_bbox is None:
        return torch.zeros((0, 0, 0), device=self.device, dtype=torch.uint8)

    local_shape = self._bbox_size(planning_bbox)
    # float16 so the pooling below can operate on it; converted to uint8 on return.
    diff_local = torch.zeros(local_shape, device=self.device, dtype=torch.float16)

    # seen_bbox expressed relative to pred (whose origin is scaled_bbox's lower corner).
    pred_bbox = [
        [seen_dim[0] - scaled_dim[0], seen_dim[1] - scaled_dim[0]]
        for seen_dim, scaled_dim in zip(seen_bbox, scaled_bbox)
    ]
    # NOTE(review): only pred_bbox is clamped to pred.shape; seen/local slicers are not, so this assumes
    # pred really covers scaled_bbox as stated above — otherwise the assignment below mismatches.
    pred_bbox = [[max(0, lb), min(ub, int(pred.shape[dim]))] for dim, (lb, ub) in enumerate(pred_bbox)]
    # seen_bbox expressed relative to the planning ROI (origin of diff_local).
    local_seen_bbox = [
        [seen_dim[0] - planning_dim[0], seen_dim[1] - planning_dim[0]]
        for seen_dim, planning_dim in zip(seen_bbox, planning_bbox)
    ]

    seen_slicer = tuple(slice(lb, ub) for lb, ub in seen_bbox)
    pred_slicer = tuple(slice(lb, ub) for lb, ub in pred_bbox)
    local_slicer = tuple(slice(lb, ub) for lb, ub in local_seen_bbox)

    # Only the seen sub-region of the previous segmentation is decompressed and moved to the device.
    prev_sub = torch.from_numpy(np.asarray(self.interactions[(prev_seg_ch, *seen_slicer)])).to(self.device)

    diff_local[local_slicer] = (pred[pred_slicer] != prev_sub).to(diff_local.dtype)
    del prev_sub

    # Min-pool then max-pool (i.e. a morphological opening of the binary map) removes small / thin
    # difference specks and thereby reduces the number of refinement patches. It is done locally, without
    # materializing a full-image planning tensor.
    diff_local[local_slicer] = iterative_3x3_same_padding_pool3d(
        diff_local[local_slicer][None, None], kernel_size=5, use_min_pool=True
    )[0, 0]
    diff_local[local_slicer] = iterative_3x3_same_padding_pool3d(
        diff_local[local_slicer][None, None], kernel_size=5, use_min_pool=False
    )[0, 0]

    return diff_local.to(torch.uint8)
|
|
1192
|
+
|
|
1193
|
+
def _mark_prev_seg_in_local_diff(self, diff_local: torch.Tensor, planning_bbox: List[List[int]]) -> None:
    """
    Set ``diff_local`` to 1 wherever the previous segmentation is foreground inside the planning region.

    The previous segmentation is read from channel ``self._get_prev_seg_channel()`` of ``self.interactions``,
    cropped to ``planning_bbox`` (per-axis ``[lower, upper)`` bounds), and thresholded at 0.5.
    ``diff_local`` is modified in place and must have the shape of that crop.
    """
    channel = self._get_prev_seg_channel()
    region = (channel,) + tuple(slice(start, stop) for start, stop in planning_bbox)
    previous = np.asarray(self.interactions[region])
    foreground = torch.from_numpy(previous).to(self.device) > 0.5
    diff_local[foreground] = 1
    # free the (possibly device-resident) temporaries right away
    del previous, foreground
|
|
1199
|
+
|
|
1200
|
+
def _plan_refinement_bboxes(
    self, pred: torch.Tensor, scaled_bbox: List[List[int]], force_full_refine: bool
) -> List[List[List[int]]]:
    """
    Decide on which patches (per-axis ``[lower, upper]`` bounds in global coordinates) to run refinement.

    The planning region is ``scaled_bbox`` clipped to the spatial shape of ``self.interactions``. When
    ``force_full_refine`` is set, it is unioned with the previous segmentation's positive bbox, and the
    previous segmentation is additionally marked in the local difference map. The local difference map
    (``_compute_local_diff_map``) is tiled into patch-sized boxes by ``generate_bounding_boxes``. Those boxes
    are shifted back to global coordinates via ``_offset_bboxes``.

    If there is no planning region, or tiling yields no boxes, a single ``patch_size`` box centred on the most
    recent interaction centre is returned, because the user evidently wanted a change there.
    """

    def patch_around_last_interaction():
        # One patch of patch_size centred on the latest interaction (not clipped to the image here).
        last_center = self.new_interaction_centers[-1]
        return [
            [[c - p // 2, c - p // 2 + p] for c, p in zip(last_center, self.configuration_manager.patch_size)]
        ]

    spatial_shape = tuple(int(extent) for extent in self.interactions.shape[1:])
    planning_bbox = self._clip_bbox_to_shape(scaled_bbox, spatial_shape)
    if force_full_refine:
        print("Forcing full refinement of entire structure")
        planning_bbox = self._union_bboxes(planning_bbox, self._compute_prev_seg_positive_bbox())

    if planning_bbox is None:
        return patch_around_last_interaction()

    diff_local = self._compute_local_diff_map(pred, scaled_bbox, planning_bbox)
    if force_full_refine:
        self._mark_prev_seg_in_local_diff(diff_local, planning_bbox)

    local_bboxes = generate_bounding_boxes(
        diff_local,
        self.configuration_manager.patch_size,
        stride="auto",
        margin=(24, 24, 24),
        max_depth=3,
    )
    # drop our reference to the diff map before asking the device allocator to release cached memory
    del diff_local
    empty_cache(self.device)

    # len() rather than truthiness: the helper's return type is not visible here
    if len(local_bboxes) == 0:
        return patch_around_last_interaction()
    return self._offset_bboxes(local_bboxes, planning_bbox)
|
|
1236
|
+
|
|
1237
|
+
def _add_patch_for_point_interaction(self, coordinates):
    """
    Register a new prediction patch centred on a point interaction.

    Points always get a zoom-out factor of 1 (no zoom-out).

    Args:
        coordinates: position of the point; stored as-is as the patch centre.
    """
    zoom_out_factor = 1
    self.new_interaction_zoom_out_factors.append(zoom_out_factor)
    self.new_interaction_centers.append(coordinates)
    # Log the centre and zoom factor of the interaction just registered. The labels were previously
    # swapped, and the full list of centres was printed.
    print(f"Added new point interaction: center {coordinates}, scale {zoom_out_factor}")
|
|
1243
|
+
|
|
1244
|
+
def _add_patch_for_bbox_interaction(self, bbox):
    """
    Register a new prediction patch centred on a bounding-box interaction.

    The patch is centred on the bbox centre (rounded per axis). To provide context, each axis requests the bbox
    extent plus ``patch_size // 3``. The zoom-out factor is the largest ratio of requested size to patch size,
    and is never below 1.

    Args:
        bbox: per-axis ``[lower, upper]`` bounds of the box.
    """
    patch_size = self.configuration_manager.patch_size
    bbox_center = [round((b[0] + b[1]) / 2) for b in bbox]
    bbox_size = [b[1] - b[0] for b in bbox]
    # we want to see some context, so the crop we see for the initial prediction should be patch_size / 3 larger
    requested_size = [size + ps // 3 for size, ps in zip(bbox_size, patch_size)]
    zoom_out_factor = max(1, max(req / ps for req, ps in zip(requested_size, patch_size)))
    self.new_interaction_zoom_out_factors.append(zoom_out_factor)
    self.new_interaction_centers.append(bbox_center)
    # Log the centre and zoom factor of the interaction just registered. The labels were previously swapped.
    print(f"Added new bbox interaction: center {bbox_center}, scale {zoom_out_factor}")
|
|
1256
|
+
|
|
1257
|
+
def _add_patch_for_initial_seg_interaction(self, initial_seg):
    """
    Register a prediction patch for an initial-segmentation prompt.

    This is a thin delegate: all work is done by ``_generic_add_patch_from_image``, whose result is returned.
    """
    add_from_image = self._generic_add_patch_from_image
    return add_from_image(initial_seg)
|
|
1259
|
+
|
|
1260
|
+
def _generic_add_patch_from_image(self, image: torch.Tensor, offset: Optional[List[int]] = None):
    """
    Register a prediction patch covering the nonzero region of an image-style prompt.

    The half-open bounding box of all nonzero voxels of ``image`` is shifted by ``offset``. Its rounded centre
    becomes the new interaction centre. To provide context, each axis requests the ROI extent plus
    ``patch_size // 3``. The zoom-out factor is the largest ratio of requested size to patch size, and is
    never below 1.

    Args:
        image: prompt tensor; only its nonzero pattern is used.
        offset: per-axis offset added to the ROI bounds (defaults to all zeros), e.g. when ``image`` is a crop.

    Returns:
        None. If ``image`` has no nonzero voxel, a message is printed and nothing is registered.
    """
    if not torch.any(image):
        print("Received empty image prompt. Cannot add patches for prediction")
        return
    if offset is None:
        offset = [0] * image.ndim
    nonzero_indices = torch.nonzero(image, as_tuple=False)
    lower = torch.min(nonzero_indices, dim=0)[0].tolist()
    upper = torch.max(nonzero_indices, dim=0)[0].tolist()
    # half-open ROI [lower, upper + 1) per axis, shifted by offset
    roi = [[lo + off, hi + off + 1] for lo, hi, off in zip(lower, upper, offset)]
    roi_center = [round((r[0] + r[1]) / 2) for r in roi]
    roi_size = [r[1] - r[0] for r in roi]
    patch_size = self.configuration_manager.patch_size
    requested_size = [size + ps // 3 for size, ps in zip(roi_size, patch_size)]
    zoom_out_factor = max(1, max(req / ps for req, ps in zip(requested_size, patch_size)))
    self.new_interaction_zoom_out_factors.append(zoom_out_factor)
    self.new_interaction_centers.append(roi_center)
    # report only the interaction just added (previously the whole list of centres was printed)
    print(f"Added new image interaction: scale {zoom_out_factor}, center {roi_center}")
|
|
1280
|
+
|
|
1281
|
+
def initialize_from_trained_model_folder(
    self,
    model_training_output_dir: str,
    use_fold: Optional[Union[int, str]] = None,
    checkpoint_name: str = "checkpoint_final.pth",
):
    """
    This is used when making predictions with a trained model.

    Steps:
      1. Reads capability/runtime defaults and builds the point interaction.
      2. Loads ``dataset.json`` and ``plans.json`` from ``model_training_output_dir``.
      3. Loads ``checkpoint_name`` from the selected ``fold_*`` subfolder.
      4. Rebuilds the network via the trainer class recorded in the checkpoint and loads the weights.
      5. Stores the network and its plans, configuration and label manager on ``self``.
         If ``self.use_torch_compile`` is set, the network is wrapped with ``torch.compile``.

    Args:
        model_training_output_dir: folder containing ``dataset.json``, ``plans.json`` and ``fold_*`` subfolders.
        use_fold: fold to load: an int, a numeric string, or ``"all"``. If None, the folder must contain exactly
            one ``fold_*`` subfolder, which is used.
        checkpoint_name: checkpoint file name inside the fold folder.

    Raises:
        AssertionError: if ``use_fold`` is None and there is not exactly one ``fold_*`` subfolder.
    """
    point_interaction_use_etd = True
    (
        capability_content,
        point_interaction_radius,
        self.preferred_scribble_thickness,
        self.interaction_decay,
        self.pad_mode_data,
    ) = self._load_capability_and_runtime_defaults(model_training_output_dir)

    self.point_interaction = PointInteraction_stub(point_interaction_radius, point_interaction_use_etd)
    self._apply_capability(capability_content)

    dataset_json = load_json(join(model_training_output_dir, "dataset.json"))
    plans = load_json(join(model_training_output_dir, "plans.json"))
    plans_manager = PlansManager(plans)

    if use_fold is not None:
        use_fold = int(use_fold) if use_fold != "all" else use_fold
        fold_folder = f"fold_{use_fold}"
    else:
        fldrs = subdirs(model_training_output_dir, prefix="fold_", join=False)
        assert len(fldrs) == 1, f"Attempted to infer fold but there is != 1 fold_ folders: {fldrs}"
        fold_folder = fldrs[0]

    # SECURITY NOTE: weights_only=False lets torch.load unpickle arbitrary objects from the checkpoint.
    # Only load checkpoints from trusted sources.
    checkpoint = torch.load(
        join(model_training_output_dir, fold_folder, checkpoint_name), map_location=self.device, weights_only=False
    )
    if self._is_official_checkpoint(plans, checkpoint):
        print(
            "License reminder: The official nnInteractive checkpoint is licensed under "
            "Creative Commons Attribution Non Commercial Share Alike 4.0 (CC BY-NC-SA 4.0). "
            "See the license note in readme.md (# License)."
        )
    trainer_name = checkpoint["trainer_name"]
    configuration_name = checkpoint["init_args"]["configuration"]

    parameters = checkpoint["network_weights"]

    configuration_manager = plans_manager.get_configuration(configuration_name)
    # build the label manager once; it is needed for the head count and is stored on self below
    label_manager = plans_manager.get_label_manager(dataset_json)
    # restore network: image channels from the plans plus this session's interaction channels
    num_input_channels = (
        determine_num_input_channels(plans_manager, configuration_manager, dataset_json)
        + self.num_interaction_channels
    )
    trainer_class = recursive_find_python_class(
        join(nnInteractive.__path__[0], "trainer"), trainer_name, "nnInteractive.trainer"
    )
    if trainer_class is None:
        print(
            f"Unable to locate trainer class {trainer_name} in nnInteractive.trainer. "
            "Please place it there (in any .py file)!"
        )
        print(
            "Attempting to use default nnInteractiveTrainer_stub. If you encounter errors, this is where you need to look!"
        )
        trainer_class = nnInteractiveTrainer_stub

    network = trainer_class.build_network_architecture(
        plans_manager,
        configuration_manager,
        num_input_channels,
        label_manager.num_segmentation_heads,
        enable_deep_supervision=False,
    ).to(self.device)
    network.load_state_dict(parameters)

    self.plans_manager = plans_manager
    self.configuration_manager = configuration_manager
    self.network = network
    self.dataset_json = dataset_json
    self.trainer_name = trainer_name
    self.label_manager = label_manager
    if self.use_torch_compile and not isinstance(self.network, OptimizedModule):
        print("Using torch.compile")
        self.network = torch.compile(self.network)
|
|
1365
|
+
|
|
1366
|
+
def manual_initialization(
    self,
    network: nn.Module,
    plans_manager: PlansManager,
    configuration_manager: ConfigurationManager,
    dataset_json: dict,
    trainer_name: str,
):
    """
    Initialize the session from an already constructed network and its plans, without reading any files.

    Stores the given objects on ``self`` and derives ``self.label_manager`` from ``plans_manager`` and
    ``dataset_json``. The network is wrapped with ``torch.compile`` when ``self.use_torch_compile`` is set,
    and unwrapped (``_orig_mod``) when it is not. Either way it ends up on ``self.device``.
    """
    self.plans_manager = plans_manager
    self.configuration_manager = configuration_manager
    self.network = network.to(self.device)
    self.dataset_json = dataset_json
    self.trainer_name = trainer_name
    self.label_manager = plans_manager.get_label_manager(dataset_json)

    # bring the compiled/uncompiled state of the network in line with the session setting
    is_compiled = isinstance(self.network, OptimizedModule)
    if self.use_torch_compile:
        if not is_compiled:
            print("Using torch.compile")
            self.network = torch.compile(self.network)
    elif is_compiled:
        self.network = self.network._orig_mod

    self.network = self.network.to(self.device)
|
|
1392
|
+
|
|
1393
|
+
def __del__(self):
    """
    Finalizer: finish any pending preprocessing, then shut down ``self.executor``.

    ``__del__`` may run on a partially constructed instance (e.g. when ``__init__`` raised before ``executor``
    was assigned), so a missing executor is tolerated. The executor is shut down even when finishing the
    preprocessing raises. That exception still propagates, and Python reports but otherwise ignores exceptions
    escaping ``__del__``.
    """
    try:
        self._finish_preprocessing_and_initialize_interactions()
    finally:
        executor = getattr(self, "executor", None)
        if executor is not None:
            executor.shutdown()
|
|
1396
|
+
|
|
1397
|
+
|
|
1398
|
+
if __name__ == "__main__":
    # Manual scratch check: select the first slice along axis 0 of an all-zero CPU volume.
    volume = torch.zeros((160, 160, 160), device="cpu")
    first_slice = volume.index_select(0, torch.tensor([0]))
|