monai-weekly 1.4.dev2430__py3-none-any.whl → 1.4.dev2434__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- monai/__init__.py +1 -1
- monai/_version.py +3 -3
- monai/apps/generation/maisi/networks/autoencoderkl_maisi.py +43 -25
- monai/apps/generation/maisi/networks/controlnet_maisi.py +15 -18
- monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py +18 -18
- monai/bundle/config_parser.py +2 -2
- monai/bundle/reference_resolver.py +18 -1
- monai/bundle/scripts.py +45 -22
- monai/bundle/utils.py +3 -1
- monai/data/utils.py +1 -1
- monai/data/wsi_datasets.py +3 -3
- monai/losses/__init__.py +1 -0
- monai/losses/dice.py +10 -1
- monai/losses/nacl_loss.py +139 -0
- monai/networks/blocks/crossattention.py +48 -26
- monai/networks/blocks/mlp.py +16 -4
- monai/networks/blocks/selfattention.py +75 -23
- monai/networks/blocks/spatialattention.py +16 -1
- monai/networks/blocks/transformerblock.py +17 -2
- monai/networks/nets/__init__.py +2 -1
- monai/networks/nets/autoencoderkl.py +55 -22
- monai/networks/nets/cell_sam_wrapper.py +92 -0
- monai/networks/nets/controlnet.py +24 -22
- monai/networks/nets/diffusion_model_unet.py +159 -19
- monai/networks/nets/segresnet_ds.py +127 -1
- monai/networks/nets/spade_autoencoderkl.py +24 -2
- monai/networks/nets/spade_diffusion_model_unet.py +39 -2
- monai/networks/nets/transformer.py +17 -17
- monai/networks/nets/vista3d.py +908 -0
- monai/networks/utils.py +3 -3
- monai/transforms/__init__.py +1 -0
- monai/transforms/io/array.py +1 -1
- monai/transforms/post/array.py +2 -1
- monai/transforms/spatial/functional.py +1 -1
- monai/transforms/transform.py +2 -2
- monai/transforms/utils.py +183 -0
- monai/{apps/generation/maisi/utils/morphological_ops.py → transforms/utils_morphological_ops.py} +2 -0
- monai/transforms/utils_pytorch_numpy_unification.py +2 -2
- monai/utils/module.py +7 -6
- {monai_weekly-1.4.dev2430.dist-info → monai_weekly-1.4.dev2434.dist-info}/METADATA +83 -81
- {monai_weekly-1.4.dev2430.dist-info → monai_weekly-1.4.dev2434.dist-info}/RECORD +44 -41
- {monai_weekly-1.4.dev2430.dist-info → monai_weekly-1.4.dev2434.dist-info}/WHEEL +1 -1
- {monai_weekly-1.4.dev2430.dist-info → monai_weekly-1.4.dev2434.dist-info}/LICENSE +0 -0
- {monai_weekly-1.4.dev2430.dist-info → monai_weekly-1.4.dev2434.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,908 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import math
from typing import Any, Callable, Optional, Sequence, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn

import monai
from monai.networks.blocks import MLPBlock, UnetrBasicBlock
from monai.networks.nets import SegResNetDS2
from monai.transforms.utils import convert_points_to_disc
from monai.transforms.utils import get_largest_connected_component_mask_point as lcc
from monai.transforms.utils import sample_points_from_label
from monai.utils import optional_import, unsqueeze_left, unsqueeze_right

rearrange, _ = optional_import("einops", name="rearrange")

__all__ = ["VISTA3D", "vista3d132"]


def vista3d132(encoder_embed_dim: int = 48, in_channels: int = 1):
    """
    Exact VISTA3D network configuration used in `VISTA3D <https://arxiv.org/abs/2406.05285>`_.
    The model treats class indices larger than 132 as zero-shot.

    Args:
        encoder_embed_dim: hidden dimension for encoder.
        in_channels: input channel number.
    """
    segresnet = SegResNetDS2(
        in_channels=in_channels,
        blocks_down=(1, 2, 2, 4, 4),
        norm="instance",
        out_channels=encoder_embed_dim,
        init_filters=encoder_embed_dim,
        dsdepth=1,
    )
    point_head = PointMappingSAM(feature_size=encoder_embed_dim, n_classes=512, last_supported=132)
    class_head = ClassMappingClassify(n_classes=512, feature_size=encoder_embed_dim, use_mlp=True)
    vista = VISTA3D(image_encoder=segresnet, class_head=class_head, point_head=point_head)
    return vista


class VISTA3D(nn.Module):
    """
    VISTA3D based on:
    `VISTA3D: Versatile Imaging SegmenTation and Annotation model for 3D Computed Tomography
    <https://arxiv.org/abs/2406.05285>`_.

    Args:
        image_encoder: image encoder backbone for feature extraction.
        class_head: class head used for class-index based segmentation.
        point_head: point head used for interactive segmentation.
    """

    def __init__(self, image_encoder: nn.Module, class_head: nn.Module, point_head: nn.Module):
        super().__init__()
        self.image_encoder = image_encoder
        self.class_head = class_head
        self.point_head = point_head
        self.image_embeddings = None
        self.auto_freeze = False
        self.point_freeze = False
        self.NINF_VALUE = -9999
        self.PINF_VALUE = 9999

    def get_foreground_class_count(self, class_vector: torch.Tensor | None, point_coords: torch.Tensor | None) -> int:
        """Get number of foreground classes based on class and point prompt."""
        if class_vector is None:
            if point_coords is None:
                raise ValueError("class_vector and point_coords cannot be both None.")
            return point_coords.shape[0]
        else:
            return class_vector.shape[0]

    def convert_point_label(
        self,
        point_label: torch.Tensor,
        label_set: Sequence[int] | None = None,
        special_index: Sequence[int] = (23, 24, 25, 26, 27, 57, 128),
    ):
        """
        Convert point label based on its class prompt. For special classes defined in special_index,
        the positive/negative point label will be converted from 1/0 to 3/2. The purpose is to separate those
        classes from ambiguous classes.

        Args:
            point_label: the point label tensor, [B, N].
            label_set: the label index matching the indexes in labels. If labels are mapped to global index using RelabelID,
                this label_set should be global mapped index. If labels are not mapped to global index, e.g. in zero-shot
                evaluation, this label_set should be the original index.
            special_index: the special class index that needs to be converted.
        """
        if label_set is None:
            return point_label
        if not point_label.shape[0] == len(label_set):
            raise ValueError("point_label and label_set must have the same length.")

        for i in range(len(label_set)):
            if label_set[i] in special_index:
                for j in range(len(point_label[i])):
                    point_label[i, j] = point_label[i, j] + 2 if point_label[i, j] > -1 else point_label[i, j]
        return point_label

    def sample_points_patch_val(
        self,
        labels: torch.Tensor,
        patch_coords: Sequence[slice],
        label_set: Sequence[int],
        use_center: bool = True,
        mapped_label_set: Sequence[int] | None = None,
        max_ppoint: int = 1,
        max_npoint: int = 0,
    ):
        """
        Sample points for patch during sliding window validation. Only used for point-only validation.

        Args:
            labels: shape [1, 1, H, W, D].
            patch_coords: a sequence of sliding window slice objects.
            label_set: local index, must match values in labels.
            use_center: sample points from the center.
            mapped_label_set: global index, it is used to identify special classes and is the global index
                for the sampled points.
            max_ppoint/max_npoint: positive points and negative points to sample.
        """
        point_coords, point_labels = sample_points_from_label(
            labels[patch_coords],
            label_set,
            max_ppoint=max_ppoint,
            max_npoint=max_npoint,
            device=labels.device,
            use_center=use_center,
        )
        point_labels = self.convert_point_label(point_labels, mapped_label_set)
        return (point_coords, point_labels, torch.tensor(label_set).to(point_coords.device).unsqueeze(-1))

    def update_point_to_patch(
        self, patch_coords: Sequence[slice], point_coords: torch.Tensor, point_labels: torch.Tensor
    ):
        """
        Update point_coords with respect to patch coords.
        If a point is outside of the patch, remove its coordinates and set its label to -1.

        Args:
            patch_coords: a sequence of the python slice objects representing the patch coordinates during sliding window inference.
                This value is passed from sliding_window_inferer.
            point_coords: point coordinates, [B, N, 3].
            point_labels: point labels, [B, N].
        """
        patch_ends = [patch_coords[-3].stop, patch_coords[-2].stop, patch_coords[-1].stop]
        patch_starts = [patch_coords[-3].start, patch_coords[-2].start, patch_coords[-1].start]
        # update point coords
        patch_starts_tensor = unsqueeze_left(torch.tensor(patch_starts, device=point_coords.device), 2)
        patch_ends_tensor = unsqueeze_left(torch.tensor(patch_ends, device=point_coords.device), 2)
        # [1 N 1]
        indices = torch.logical_and(
            ((point_coords - patch_starts_tensor) > 0).all(2), ((patch_ends_tensor - point_coords) > 0).all(2)
        )
        # check if it's within patch coords
        point_coords = point_coords.clone() - patch_starts_tensor
        point_labels = point_labels.clone()
        if indices.any():
            point_labels[~indices] = -1
            point_coords[~indices] = 0
            # also remove padded points, mainly used for inference.
            not_pad_indices = (point_labels != -1).any(0)
            point_coords = point_coords[:, not_pad_indices]
            point_labels = point_labels[:, not_pad_indices]
            return point_coords, point_labels
        return None, None

    def connected_components_combine(
        self,
        logits: torch.Tensor,
        point_logits: torch.Tensor,
        point_coords: torch.Tensor,
        point_labels: torch.Tensor,
        mapping_index: torch.Tensor,
        thred: float = 0.5,
    ):
        """
        Combine auto results with point click response. The auto results have shape [B, 1, H, W, D] which means B foreground masks
        from a single image patch.
        Out of those B foreground masks, the user may add points to a subset of B1 foreground masks for editing.
        mapping_index represents the correspondence between B and B1.
        For mapping_index with point clicks, NaN values in logits will be replaced with point_logits. Meanwhile, the added/removed
        region in point clicks must be updated by the lcc function.
        Notice, if a positive point is within logits/prev_mask, the components containing the positive point will be added.

        Args:
            logits: automatic branch results, [B, 1, H, W, D].
            point_logits: point branch results, [B1, 1, H, W, D].
            point_coords: point coordinates, [B1, N, 3].
            point_labels: point labels, [B1, N].
            mapping_index: [B].
            thred: the threshold to convert logits to binary.
        """
        logits = logits.as_tensor() if isinstance(logits, monai.data.MetaTensor) else logits
        _logits = logits[mapping_index]
        inside = []
        for i in range(_logits.shape[0]):
            inside.append(
                np.any(
                    [
                        _logits[i, 0, p[0], p[1], p[2]].item() > 0
                        for p in point_coords[i].cpu().numpy().round().astype(int)
                    ]
                )
            )
        inside_tensor = torch.tensor(inside).to(logits.device)
        nan_mask = torch.isnan(_logits)
        # _logits are converted to binary [B1, 1, H, W, D]
        _logits = torch.nan_to_num(_logits, nan=self.NINF_VALUE).sigmoid()
        pos_region = point_logits.sigmoid() > thred
        diff_pos = torch.logical_and(torch.logical_or(_logits <= thred, unsqueeze_right(inside_tensor, 5)), pos_region)
        diff_neg = torch.logical_and((_logits > thred), ~pos_region)
        cc = lcc(diff_pos, diff_neg, point_coords=point_coords, point_labels=point_labels)
        # cc is the region that can be updated by point_logits.
        cc = cc.to(logits.device)
        # Need to replace NaN with point_logits. diff_neg will never lie in nan_mask,
        # only remove unconnected positive region.
        uc_pos_region = torch.logical_and(pos_region, ~cc)
        fill_mask = torch.logical_and(nan_mask, uc_pos_region)
        if fill_mask.any():
            # fill in the mean negative value
            point_logits[fill_mask] = -1
        # replace logits nan value and cc with point_logits
        cc = torch.logical_or(nan_mask, cc).to(logits.dtype)
        logits[mapping_index] *= 1 - cc
        logits[mapping_index] += cc * point_logits
        return logits

    def gaussian_combine(
        self,
        logits: torch.Tensor,
        point_logits: torch.Tensor,
        point_coords: torch.Tensor,
        point_labels: torch.Tensor,
        mapping_index: torch.Tensor,
        radius: int | None = None,
    ):
        """
        Combine point results with auto results using gaussian.

        Args:
            logits: automatic branch results, [B, 1, H, W, D].
            point_logits: point branch results, [B1, 1, H, W, D].
            point_coords: point coordinates, [B1, N, 3].
            point_labels: point labels, [B1, N].
            mapping_index: [B].
            radius: gaussian ball radius.
        """
        if radius is None:
            radius = min(point_logits.shape[-3:]) // 5  # empirical value 5
        weight = 1 - convert_points_to_disc(point_logits.shape[-3:], point_coords, point_labels, radius=radius).sum(
            1, keepdims=True
        )
        weight[weight < 0] = 0
        logits = logits.as_tensor() if isinstance(logits, monai.data.MetaTensor) else logits
        logits[mapping_index] *= weight
        logits[mapping_index] += (1 - weight) * point_logits
        return logits

    def set_auto_grad(self, auto_freeze: bool = False, point_freeze: bool = False):
        """
        Freeze auto-branch or point-branch.

        Args:
            auto_freeze: whether to freeze the auto branch.
            point_freeze: whether to freeze the point branch.
        """
        if auto_freeze != self.auto_freeze:
            if hasattr(self.image_encoder, "set_auto_grad"):
                self.image_encoder.set_auto_grad(auto_freeze=auto_freeze, point_freeze=point_freeze)
            else:
                for param in self.image_encoder.parameters():
                    param.requires_grad = (not auto_freeze) and (not point_freeze)
            for param in self.class_head.parameters():
                param.requires_grad = not auto_freeze
            self.auto_freeze = auto_freeze

        if point_freeze != self.point_freeze:
            if hasattr(self.image_encoder, "set_auto_grad"):
                self.image_encoder.set_auto_grad(auto_freeze=auto_freeze, point_freeze=point_freeze)
            else:
                for param in self.image_encoder.parameters():
                    param.requires_grad = (not auto_freeze) and (not point_freeze)
            for param in self.point_head.parameters():
                param.requires_grad = not point_freeze
            self.point_freeze = point_freeze

    def forward(
        self,
        input_images: torch.Tensor,
        point_coords: torch.Tensor | None = None,
        point_labels: torch.Tensor | None = None,
        class_vector: torch.Tensor | None = None,
        prompt_class: torch.Tensor | None = None,
        patch_coords: Sequence[slice] | None = None,
        labels: torch.Tensor | None = None,
        label_set: Sequence[int] | None = None,
        prev_mask: torch.Tensor | None = None,
        radius: int | None = None,
        val_point_sampler: Callable | None = None,
        **kwargs,
    ):
        """
        The forward function for VISTA3D. Only a single patch is supported in training and inference.
        One exception is allowing sliding window batch size > 1 for the automatic-segmentation-only case.
        B represents number of objects, N represents number of points for each object.

        Args:
            input_images: [1, 1, H, W, D]
            point_coords: [B, N, 3]
            point_labels: [B, N], -1 represents padding. 0/1 means negative/positive points for regular class.
                2/3 means negative/positive points for special supported classes like tumor.
            class_vector: [B, 1], the global class index
            prompt_class: [B, 1], the global class index. This value is associated with point_coords to identify if
                the points are for zero-shot or supported class. When class_vector and point_coords are both
                provided, prompt_class is the same as class_vector. For prompt_class[b] > 512, point_coords[b]
                will be considered novel class.
            patch_coords: a sequence of the python slice objects representing the patch coordinates during sliding window inference.
                This value is passed from sliding_window_inferer. This is an indicator for training phase or validation phase.
            labels: [1, 1, H, W, D], the groundtruth label tensor, only used for point-only evaluation
            label_set: the label index matching the indexes in labels. If labels are mapped to global index using RelabelID,
                this label_set should be global mapped index. If labels are not mapped to global index, e.g. in zero-shot
                evaluation, this label_set should be the original index.
            prev_mask: [B, N, H_fullsize, W_fullsize, D_fullsize].
                This is the transposed raw output from sliding_window_inferer before any postprocessing.
                When the user clicks points to correct the automatic results, this can be the automatic results.
            radius: single float value controlling the gaussian blur when combining point and auto results.
                The gaussian combine is not used in VISTA3D training but might be useful for finetuning purposes.
            val_point_sampler: function used to sample points from labels. This is only used for point-only evaluation.

        """
        image_size = input_images.shape[-3:]
        device = input_images.device
        if point_coords is None and class_vector is None:
            return self.NINF_VALUE + torch.zeros([1, 1, *image_size], device=device)

        bs = self.get_foreground_class_count(class_vector, point_coords)
        if patch_coords is not None:
            # during validation, optionally perform patch-based point validation.
            if labels is not None and label_set is not None:
                # if labels is not None, sample from labels for each patch.
                if val_point_sampler is None:
                    # TODO: think about how to refactor this part.
                    val_point_sampler = self.sample_points_patch_val
                point_coords, point_labels, prompt_class = val_point_sampler(labels, patch_coords, label_set)
                if prompt_class[0].item() == 0:  # type: ignore
                    point_labels[0] = -1  # type: ignore
                labels, prev_mask = None, None
            elif point_coords is not None:
                # If not performing patch-based point only validation, use user provided click points for inference.
                # the point clicks are in original image space; convert them to the current patch-coordinate space.
                point_coords, point_labels = self.update_point_to_patch(patch_coords, point_coords, point_labels)  # type: ignore

        if point_coords is not None and point_labels is not None:
            # remove points that are used for padding purposes (point_label = -1)
            mapping_index = ((point_labels != -1).sum(1) > 0).to(torch.bool)
            if mapping_index.any():
                point_coords = point_coords[mapping_index]
                point_labels = point_labels[mapping_index]
                if prompt_class is not None:
                    prompt_class = prompt_class[mapping_index]
            else:
                if self.auto_freeze or (class_vector is None and patch_coords is None):
                    # if auto_freeze, point prompt must exist to allow loss backward
                    # in training, class_vector and point cannot both be None due to loss.backward()
                    mapping_index.fill_(True)
                else:
                    point_coords, point_labels = None, None

        if point_coords is None and class_vector is None:
            return self.NINF_VALUE + torch.zeros([bs, 1, *image_size], device=device)

        if self.image_embeddings is not None and kwargs.get("keep_cache", False) and class_vector is None:
            out, out_auto = self.image_embeddings, None
        else:
            out, out_auto = self.image_encoder(
                input_images, with_point=point_coords is not None, with_label=class_vector is not None
            )
        # release memory
        input_images = None  # type: ignore

        # force releasing memory for the variables set to None
        torch.cuda.empty_cache()
        if class_vector is not None:
            logits, _ = self.class_head(out_auto, class_vector)
            if point_coords is not None:
                point_logits = self.point_head(out, point_coords, point_labels, class_vector=prompt_class)
                if patch_coords is None:
                    logits = self.gaussian_combine(
                        logits, point_logits, point_coords, point_labels, mapping_index, radius  # type: ignore
                    )
                else:
                    # during validation use largest component
                    logits = self.connected_components_combine(
                        logits, point_logits, point_coords, point_labels, mapping_index  # type: ignore
                    )
        else:
            logits = self.NINF_VALUE + torch.zeros([bs, 1, *image_size], device=device, dtype=out.dtype)
            logits[mapping_index] = self.point_head(out, point_coords, point_labels, class_vector=prompt_class)
            if prev_mask is not None and patch_coords is not None:
                logits = self.connected_components_combine(
                    prev_mask[patch_coords].transpose(1, 0).to(logits.device),
                    logits[mapping_index],
                    point_coords,  # type: ignore
                    point_labels,  # type: ignore
                    mapping_index,
                )

        if kwargs.get("keep_cache", False) and class_vector is None:
            self.image_embeddings = out.detach()
        return logits


class PointMappingSAM(nn.Module):
    def __init__(self, feature_size: int, max_prompt: int = 32, n_classes: int = 512, last_supported: int = 132):
        """Interactive point head used for VISTA3D.
        Adapted from segment anything:
        `https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/mask_decoder.py`.

        Args:
            feature_size: feature channel from encoder.
            max_prompt: max prompt number in each forward iteration.
            n_classes: number of classes the model can potentially support. This is the maximum number of class embeddings.
            last_supported: number of classes the model supports; this value should match the trained model weights.
        """
        super().__init__()
        transformer_dim = feature_size
        self.max_prompt = max_prompt
        self.feat_downsample = nn.Sequential(
            nn.Conv3d(in_channels=feature_size, out_channels=feature_size, kernel_size=3, stride=2, padding=1),
            nn.InstanceNorm3d(feature_size),
            nn.GELU(),
            nn.Conv3d(in_channels=feature_size, out_channels=transformer_dim, kernel_size=3, stride=1, padding=1),
            nn.InstanceNorm3d(feature_size),
        )

        self.mask_downsample = nn.Conv3d(in_channels=2, out_channels=2, kernel_size=3, stride=2, padding=1)

        self.transformer = TwoWayTransformer(depth=2, embedding_dim=transformer_dim, mlp_dim=512, num_heads=4)
        self.pe_layer = PositionEmbeddingRandom(transformer_dim // 2)
        self.point_embeddings = nn.ModuleList([nn.Embedding(1, transformer_dim), nn.Embedding(1, transformer_dim)])
        self.not_a_point_embed = nn.Embedding(1, transformer_dim)
        self.special_class_embed = nn.Embedding(1, transformer_dim)
        self.mask_tokens = nn.Embedding(1, transformer_dim)

        self.output_upscaling = nn.Sequential(
            nn.ConvTranspose3d(transformer_dim, transformer_dim, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.InstanceNorm3d(transformer_dim),
            nn.GELU(),
            nn.Conv3d(transformer_dim, transformer_dim, kernel_size=3, stride=1, padding=1),
        )

        self.output_hypernetworks_mlps = MLP(transformer_dim, transformer_dim, transformer_dim, 3)
        # class embedding
        self.n_classes = n_classes
        self.last_supported = last_supported
        self.class_embeddings = nn.Embedding(n_classes, feature_size)
        self.zeroshot_embed = nn.Embedding(1, transformer_dim)
        self.supported_embed = nn.Embedding(1, transformer_dim)

    def forward(
        self,
        out: torch.Tensor,
        point_coords: torch.Tensor,
        point_labels: torch.Tensor,
        class_vector: torch.Tensor | None = None,
    ):
        """Args:
        out: feature from encoder, [1, C, H, W, D]
        point_coords: point coordinates, [B, N, 3]
        point_labels: point labels, [B, N]
        class_vector: class prompts, [B]
        """
        # downsample out
        out_low = self.feat_downsample(out)
        out_shape = tuple(out.shape[-3:])
        # release memory
        out = None  # type: ignore
        torch.cuda.empty_cache()
        # embed points
        points = point_coords + 0.5  # Shift to center of pixel
        point_embedding = self.pe_layer.forward_with_coords(points, out_shape)  # type: ignore
        point_embedding[point_labels == -1] = 0.0
        point_embedding[point_labels == -1] += self.not_a_point_embed.weight
        point_embedding[point_labels == 0] += self.point_embeddings[0].weight
        point_embedding[point_labels == 1] += self.point_embeddings[1].weight
        point_embedding[point_labels == 2] += self.point_embeddings[0].weight + self.special_class_embed.weight
        point_embedding[point_labels == 3] += self.point_embeddings[1].weight + self.special_class_embed.weight
        output_tokens = self.mask_tokens.weight

        output_tokens = output_tokens.unsqueeze(0).expand(point_embedding.size(0), -1, -1)
        if class_vector is None:
            tokens_all = torch.cat(
                (
                    output_tokens,
                    point_embedding,
                    self.supported_embed.weight.unsqueeze(0).expand(point_embedding.size(0), -1, -1),
                ),
                dim=1,
            )
            # tokens_all = torch.cat((output_tokens, point_embedding), dim=1)
        else:
            class_embeddings = []
            for i in class_vector:
                if i > self.last_supported:
                    class_embeddings.append(self.zeroshot_embed.weight)
                else:
                    class_embeddings.append(self.supported_embed.weight)
            tokens_all = torch.cat((output_tokens, point_embedding, torch.stack(class_embeddings)), dim=1)
        # cross attention
        masks = []
        max_prompt = self.max_prompt
        for i in range(int(np.ceil(tokens_all.shape[0] / max_prompt))):
            # remove variables in previous for loops to save peak memory for self.transformer
            src, upscaled_embedding, hyper_in = None, None, None
            torch.cuda.empty_cache()
            idx = (i * max_prompt, min((i + 1) * max_prompt, tokens_all.shape[0]))
            tokens = tokens_all[idx[0] : idx[1]]
            src = torch.repeat_interleave(out_low, tokens.shape[0], dim=0)
            pos_src = torch.repeat_interleave(self.pe_layer(out_low.shape[-3:]).unsqueeze(0), tokens.shape[0], dim=0)
            b, c, h, w, d = src.shape
            hs, src = self.transformer(src, pos_src, tokens)
            mask_tokens_out = hs[:, :1, :]
            hyper_in = self.output_hypernetworks_mlps(mask_tokens_out)
            src = src.transpose(1, 2).view(b, c, h, w, d)  # type: ignore
            upscaled_embedding = self.output_upscaling(src)
            b, c, h, w, d = upscaled_embedding.shape
            mask = hyper_in @ upscaled_embedding.view(b, c, h * w * d)
            masks.append(mask.view(-1, 1, h, w, d))

        return torch.vstack(masks)


class ClassMappingClassify(nn.Module):
    """Class head that performs automatic segmentation based on class vector."""

    def __init__(self, n_classes: int, feature_size: int, use_mlp: bool = True):
        """Args:
        n_classes: maximum number of class embeddings.
        feature_size: class embedding size.
        use_mlp: use an mlp to further map the class embedding.
        """
        super().__init__()
        self.use_mlp = use_mlp
        if use_mlp:
            self.mlp = nn.Sequential(
                nn.Linear(feature_size, feature_size),
                nn.InstanceNorm1d(1),
                nn.GELU(),
                nn.Linear(feature_size, feature_size),
            )
        self.class_embeddings = nn.Embedding(n_classes, feature_size)
        self.image_post_mapping = nn.Sequential(
            UnetrBasicBlock(
                spatial_dims=3,
                in_channels=feature_size,
                out_channels=feature_size,
                kernel_size=3,
                stride=1,
                norm_name="instance",
                res_block=True,
            ),
            UnetrBasicBlock(
                spatial_dims=3,
                in_channels=feature_size,
                out_channels=feature_size,
                kernel_size=3,
                stride=1,
                norm_name="instance",
                res_block=True,
            ),
        )

    def forward(self, src: torch.Tensor, class_vector: torch.Tensor):
        b, c, h, w, d = src.shape
        src = self.image_post_mapping(src)
        class_embedding = self.class_embeddings(class_vector)
        if self.use_mlp:
            class_embedding = self.mlp(class_embedding)
        # [b,1,feat] @ [1,feat,dim], the batch dimension becomes the class_embedding batch dimension.
        masks = []
        for i in range(b):
            mask = class_embedding @ src[[i]].view(1, c, h * w * d)
            masks.append(mask.view(-1, 1, h, w, d))

        return torch.cat(masks, 1), class_embedding


class TwoWayTransformer(nn.Module):
    def __init__(
        self,
        depth: int,
        embedding_dim: int,
        num_heads: int,
        mlp_dim: int,
        activation: tuple | str = "relu",
        attention_downsample_rate: int = 2,
    ) -> None:
        """
        A transformer decoder that attends to an input image using
        queries whose positional embedding is supplied.
        Adapted from `https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/transformer.py`.

        Args:
            depth: number of layers in the transformer.
            embedding_dim: the channel dimension for the input embeddings.
            num_heads: the number of heads for multihead attention. Must divide embedding_dim.
            mlp_dim: the channel dimension internal to the MLP block.
            activation: the activation to use in the MLP block.
            attention_downsample_rate: the rate at which to downsample the image before projecting.
        """
        super().__init__()
        self.depth = depth
        self.embedding_dim = embedding_dim
        self.num_heads = num_heads
        self.mlp_dim = mlp_dim
        self.layers = nn.ModuleList()

        for i in range(depth):
            self.layers.append(
                TwoWayAttentionBlock(
                    embedding_dim=embedding_dim,
                    num_heads=num_heads,
                    mlp_dim=mlp_dim,
                    activation=activation,
                    attention_downsample_rate=attention_downsample_rate,
                    skip_first_layer_pe=(i == 0),
                )
            )

        self.final_attn_token_to_image = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate)
        self.norm_final_attn = nn.LayerNorm(embedding_dim)

    def forward(
        self, image_embedding: torch.Tensor, image_pe: torch.Tensor, point_embedding: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            image_embedding: image to attend to. Should be shape
                B x embedding_dim x h x w for any h and w.
            image_pe: the positional encoding to add to the image. Must
                have the same shape as image_embedding.
            point_embedding: the embedding to add to the query points.
                Must have shape B x N_points x embedding_dim for any N_points.

        Returns:
            torch.Tensor: the processed point_embedding.
            torch.Tensor: the processed image_embedding.
        """
        # BxCxHxW -> BxHWxC == B x N_image_tokens x C
        image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
        image_pe = image_pe.flatten(2).permute(0, 2, 1)

        # Prepare queries
        queries = point_embedding
        keys = image_embedding

        # Apply transformer blocks and final layernorm
        for layer in self.layers:
            queries, keys = layer(queries=queries, keys=keys, query_pe=point_embedding, key_pe=image_pe)

        # Apply the final attention layer from the points to the image
        q = queries + point_embedding
        k = keys + image_pe
        attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
        queries = queries + attn_out
        queries = self.norm_final_attn(queries)

        return queries, keys


class TwoWayAttentionBlock(nn.Module):
    def __init__(
        self,
        embedding_dim: int,
        num_heads: int,
        mlp_dim: int = 2048,
        activation: tuple | str = "relu",
        attention_downsample_rate: int = 2,
        skip_first_layer_pe: bool = False,
    ) -> None:
        """
        A transformer block with four layers: (1) self-attention of sparse
        inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp
        block on sparse inputs, and (4) cross attention of dense inputs to sparse
        inputs.
        Adapted from `https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/transformer.py`.

        Args:
            embedding_dim: the channel dimension of the embeddings.
            num_heads: the number of heads in the attention layers.
            mlp_dim: the hidden dimension of the mlp block.
            activation: the activation of the mlp block.
            skip_first_layer_pe: skip the PE on the first layer.
        """
        super().__init__()
        self.self_attn = Attention(embedding_dim, num_heads)
        self.norm1 = nn.LayerNorm(embedding_dim)

        self.cross_attn_token_to_image = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate)
        self.norm2 = nn.LayerNorm(embedding_dim)

        self.mlp = MLPBlock(hidden_size=embedding_dim, mlp_dim=mlp_dim, act=activation, dropout_mode="vista3d")
        self.norm3 = nn.LayerNorm(embedding_dim)

        self.norm4 = nn.LayerNorm(embedding_dim)
        self.cross_attn_image_to_token = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate)

        self.skip_first_layer_pe = skip_first_layer_pe

    def forward(
        self, queries: torch.Tensor, keys: torch.Tensor, query_pe: torch.Tensor, key_pe: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Self attention block
        if self.skip_first_layer_pe:
            queries = self.self_attn(q=queries, k=queries, v=queries)
        else:
            q = queries + query_pe
            attn_out = self.self_attn(q=q, k=q, v=queries)
            queries = queries + attn_out
        queries = self.norm1(queries)

        # Cross attention block, tokens attending to image embedding
        q = queries + query_pe
        k = keys + key_pe
        attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
        queries = queries + attn_out
        queries = self.norm2(queries)

        # MLP block
        mlp_out = self.mlp(queries)
        queries = queries + mlp_out
        queries = self.norm3(queries)

        # Cross attention block, image embedding attending to tokens
        q = queries + query_pe
        k = keys + key_pe
        attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
        keys = keys + attn_out
        keys = self.norm4(keys)

        return queries, keys


class Attention(nn.Module):
    """
    An attention layer that allows for downscaling the size of the embedding
    after projection to queries, keys, and values.
    Adapted from `https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/transformer.py`.

    Args:
        embedding_dim: the channel dimension of the embeddings.
        num_heads: the number of heads in the attention layers.
        downsample_rate: the rate at which to downsample the image before projecting.
    """

    def __init__(self, embedding_dim: int, num_heads: int, downsample_rate: int = 1) -> None:
        super().__init__()
        self.embedding_dim = embedding_dim
        self.internal_dim = embedding_dim // downsample_rate
        self.num_heads = num_heads
        if not self.internal_dim % num_heads == 0:
            raise ValueError("num_heads must divide embedding_dim.")

        self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
        self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
        self.v_proj = nn.Linear(embedding_dim, self.internal_dim)
        self.out_proj = nn.Linear(self.internal_dim, embedding_dim)

    def _separate_heads(self, x: torch.Tensor, num_heads: int) -> torch.Tensor:
        b, n, c = x.shape
        x = x.reshape(b, n, num_heads, c // num_heads)
        # B x N_heads x N_tokens x C_per_head
        return x.transpose(1, 2)

    def _recombine_heads(self, x: torch.Tensor) -> torch.Tensor:
        b, n_heads, n_tokens, c_per_head = x.shape
        x = x.transpose(1, 2)
        # B x N_tokens x C
        return x.reshape(b, n_tokens, n_heads * c_per_head)

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        # Input projections
        q = self.q_proj(q)
        k = self.k_proj(k)
        v = self.v_proj(v)

        # Separate into heads
        q = self._separate_heads(q, self.num_heads)
        k = self._separate_heads(k, self.num_heads)
        v = self._separate_heads(v, self.num_heads)

        # Attention
        _, _, _, c_per_head = q.shape
        attn = q @ k.permute(0, 1, 3, 2)  # B x N_heads x N_tokens x N_tokens
        attn = attn / math.sqrt(c_per_head)
        attn = torch.softmax(attn, dim=-1)

        # Get output
        out = attn @ v
        out = self._recombine_heads(out)
        out = self.out_proj(out)

        return out


class PositionEmbeddingRandom(nn.Module):
    """
    Positional encoding using random spatial frequencies.
    Adapted from `https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/prompt_encoder.py`.

    Args:
        num_pos_feats: the number of positional encoding features.
        scale: the scale of the positional encoding.
    """

    def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
        super().__init__()
        if scale is None or scale <= 0.0:
            scale = 1.0
        self.register_buffer("positional_encoding_gaussian_matrix", scale * torch.randn((3, num_pos_feats)))

    def _pe_encoding(self, coords: torch.torch.Tensor) -> torch.torch.Tensor:
        """Positionally encode points that are normalized to [0,1]."""
        # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
        coords = 2 * coords - 1
        # [bs=1,N=2,2] @ [2,128]
        # [bs=1, N=2, 128]
        coords = coords @ self.positional_encoding_gaussian_matrix
        coords = 2 * np.pi * coords
        # outputs d_1 x ... x d_n x C shape
        # [bs=1, N=2, 128+128=256]
        return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)

    def forward(self, size: Tuple[int, int, int]) -> torch.torch.Tensor:
        """Generate positional encoding for a grid of the specified size."""
        h, w, d = size
        device: Any = self.positional_encoding_gaussian_matrix.device
        grid = torch.ones((h, w, d), device=device, dtype=torch.float32)
        x_embed = grid.cumsum(dim=0) - 0.5
        y_embed = grid.cumsum(dim=1) - 0.5
        z_embed = grid.cumsum(dim=2) - 0.5
        x_embed = x_embed / h
        y_embed = y_embed / w
        z_embed = z_embed / d
        pe = self._pe_encoding(torch.stack([x_embed, y_embed, z_embed], dim=-1))
        # C x H x W
        return pe.permute(3, 0, 1, 2)

    def forward_with_coords(
        self, coords_input: torch.torch.Tensor, image_size: Tuple[int, int, int]
    ) -> torch.torch.Tensor:
        """Positionally encode points that are not normalized to [0,1]."""
        coords = coords_input.clone()
        coords[:, :, 0] = coords[:, :, 0] / image_size[0]
        coords[:, :, 1] = coords[:, :, 1] / image_size[1]
        coords[:, :, 2] = coords[:, :, 2] / image_size[2]
        # B x N x C
        return self._pe_encoding(coords.to(torch.float))


class MLP(nn.Module):
    """
    Multi-layer perceptron. This class is only used for `PointMappingSAM`.
    Adapted from `https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/mask_decoder.py`.

    Args:
        input_dim: the input dimension.
        hidden_dim: the hidden dimension.
        output_dim: the output dimension.
        num_layers: the number of layers.
        sigmoid_output: whether to apply a sigmoid activation to the output.
    """

    def __init__(
        self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int, sigmoid_output: bool = False
    ) -> None:
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
        self.sigmoid_output = sigmoid_output

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        if self.sigmoid_output:
            x = F.sigmoid(x)
        return x