monai-weekly 1.4.dev2431__py3-none-any.whl → 1.4.dev2435__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. monai/__init__.py +1 -1
  2. monai/_version.py +3 -3
  3. monai/apps/generation/maisi/networks/autoencoderkl_maisi.py +43 -25
  4. monai/apps/generation/maisi/networks/controlnet_maisi.py +15 -18
  5. monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py +18 -18
  6. monai/apps/vista3d/inferer.py +177 -0
  7. monai/apps/vista3d/sampler.py +179 -0
  8. monai/apps/vista3d/transforms.py +224 -0
  9. monai/bundle/scripts.py +29 -17
  10. monai/data/utils.py +1 -1
  11. monai/data/wsi_datasets.py +3 -3
  12. monai/inferers/utils.py +1 -0
  13. monai/losses/__init__.py +1 -0
  14. monai/losses/dice.py +10 -1
  15. monai/losses/nacl_loss.py +139 -0
  16. monai/networks/blocks/crossattention.py +48 -26
  17. monai/networks/blocks/mlp.py +16 -4
  18. monai/networks/blocks/selfattention.py +75 -23
  19. monai/networks/blocks/spatialattention.py +16 -1
  20. monai/networks/blocks/transformerblock.py +17 -2
  21. monai/networks/layers/filtering.py +6 -2
  22. monai/networks/nets/__init__.py +2 -1
  23. monai/networks/nets/autoencoderkl.py +55 -22
  24. monai/networks/nets/cell_sam_wrapper.py +92 -0
  25. monai/networks/nets/controlnet.py +24 -22
  26. monai/networks/nets/diffusion_model_unet.py +159 -19
  27. monai/networks/nets/segresnet_ds.py +127 -1
  28. monai/networks/nets/spade_autoencoderkl.py +22 -0
  29. monai/networks/nets/spade_diffusion_model_unet.py +39 -2
  30. monai/networks/nets/transformer.py +17 -17
  31. monai/networks/nets/vista3d.py +946 -0
  32. monai/networks/utils.py +4 -4
  33. monai/transforms/__init__.py +13 -2
  34. monai/transforms/io/array.py +59 -3
  35. monai/transforms/io/dictionary.py +29 -2
  36. monai/transforms/spatial/functional.py +1 -1
  37. monai/transforms/transform.py +2 -2
  38. monai/transforms/utility/dictionary.py +4 -0
  39. monai/transforms/utils.py +230 -1
  40. monai/{apps/generation/maisi/utils/morphological_ops.py → transforms/utils_morphological_ops.py} +2 -0
  41. monai/transforms/utils_pytorch_numpy_unification.py +2 -2
  42. monai/utils/enums.py +1 -0
  43. monai/utils/module.py +7 -6
  44. {monai_weekly-1.4.dev2431.dist-info → monai_weekly-1.4.dev2435.dist-info}/METADATA +84 -81
  45. {monai_weekly-1.4.dev2431.dist-info → monai_weekly-1.4.dev2435.dist-info}/RECORD +49 -43
  46. {monai_weekly-1.4.dev2431.dist-info → monai_weekly-1.4.dev2435.dist-info}/WHEEL +1 -1
  47. /monai/apps/{generation/maisi/utils → vista3d}/__init__.py +0 -0
  48. {monai_weekly-1.4.dev2431.dist-info → monai_weekly-1.4.dev2435.dist-info}/LICENSE +0 -0
  49. {monai_weekly-1.4.dev2431.dist-info → monai_weekly-1.4.dev2435.dist-info}/top_level.txt +0 -0
monai/networks/nets/vista3d.py
@@ -0,0 +1,946 @@
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+
12
+ from __future__ import annotations
13
+
14
+ import math
15
+ from typing import Any, Callable, Optional, Sequence, Tuple
16
+
17
+ import numpy as np
18
+ import torch
19
+ import torch.nn.functional as F
20
+ from torch import nn
21
+
22
+ import monai
23
+ from monai.networks.blocks import MLPBlock, UnetrBasicBlock
24
+ from monai.networks.nets import SegResNetDS2
25
+ from monai.transforms.utils import convert_points_to_disc
26
+ from monai.transforms.utils import keep_merge_components_with_points as lcc
27
+ from monai.transforms.utils import sample_points_from_label
28
+ from monai.utils import optional_import, unsqueeze_left, unsqueeze_right
29
+
30
+ rearrange, _ = optional_import("einops", name="rearrange")
31
+
32
+ __all__ = ["VISTA3D", "vista3d132"]
33
+
34
+
35
+ def vista3d132(encoder_embed_dim: int = 48, in_channels: int = 1):
36
+ """
37
+ Exact VISTA3D network configuration used in `VISTA3D <https://arxiv.org/abs/2406.05285>`_.
38
+ The model treats class indices larger than 132 as zero-shot.
39
+
40
+ Args:
41
+ encoder_embed_dim: hidden dimension for encoder.
42
+ in_channels: input channel number.
43
+ """
44
+ segresnet = SegResNetDS2(
45
+ in_channels=in_channels,
46
+ blocks_down=(1, 2, 2, 4, 4),
47
+ norm="instance",
48
+ out_channels=encoder_embed_dim,
49
+ init_filters=encoder_embed_dim,
50
+ dsdepth=1,
51
+ )
52
+ point_head = PointMappingSAM(feature_size=encoder_embed_dim, n_classes=512, last_supported=132)
53
+ class_head = ClassMappingClassify(n_classes=512, feature_size=encoder_embed_dim, use_mlp=True)
54
+ vista = VISTA3D(image_encoder=segresnet, class_head=class_head, point_head=point_head)
55
+ return vista
56
+
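# [Editorial sketch, not part of the package diff] A minimal smoke test of the factory above,
# assuming `vista3d132` is exported from monai.networks.nets (as the __init__.py change in this
# release suggests). The 64^3 input size and the three class indices are arbitrary illustrations.
import torch
from monai.networks.nets import vista3d132

model = vista3d132(encoder_embed_dim=48, in_channels=1).eval()
image = torch.randn(1, 1, 64, 64, 64)          # [1, 1, H, W, D]
class_vector = torch.tensor([[1], [2], [3]])   # three foreground class prompts, [B, 1]
with torch.no_grad():
    logits = model(image, class_vector=class_vector, transpose=True)
# `logits` holds one channel per prompted class, e.g. [1, 3, 64, 64, 64] after transpose.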
57
+
58
+ class VISTA3D(nn.Module):
59
+ """
60
+ VISTA3D based on:
61
+ `VISTA3D: Versatile Imaging SegmenTation and Annotation model for 3D Computed Tomography
62
+ <https://arxiv.org/abs/2406.05285>`_.
63
+
64
+ Args:
65
+ image_encoder: image encoder backbone for feature extraction.
66
+ class_head: class head used for class-index-based segmentation.
67
+ point_head: point head used for interactive segmentation.
68
+ """
69
+
70
+ def __init__(self, image_encoder: nn.Module, class_head: nn.Module, point_head: nn.Module):
71
+ super().__init__()
72
+ self.image_encoder = image_encoder
73
+ self.class_head = class_head
74
+ self.point_head = point_head
75
+ self.image_embeddings = None
76
+ self.auto_freeze = False
77
+ self.point_freeze = False
78
+ self.NINF_VALUE = -9999
79
+ self.PINF_VALUE = 9999
80
+
81
+ def update_slidingwindow_padding(
82
+ self,
83
+ pad_size: list | None,
84
+ labels: torch.Tensor | None,
85
+ prev_mask: torch.Tensor | None,
86
+ point_coords: torch.Tensor | None,
87
+ ):
88
+ """
89
+ The image has been padded by the sliding window inferer.
90
+ The related padding needs to be performed outside of the sliding window inferer.
91
+
92
+ Args:
93
+ pad_size: padding size passed from sliding window inferer.
94
+ labels: image label ground truth.
95
+ prev_mask: previous segmentation mask.
96
+ point_coords: point click coordinates.
97
+ """
98
+ if pad_size is None:
99
+ return labels, prev_mask, point_coords
100
+ if labels is not None:
101
+ labels = F.pad(labels, pad=pad_size, mode="constant", value=0)
102
+ if prev_mask is not None:
103
+ prev_mask = F.pad(prev_mask, pad=pad_size, mode="constant", value=0)
104
+ if point_coords is not None:
105
+ point_coords = point_coords + torch.tensor(
106
+ [pad_size[-2], pad_size[-4], pad_size[-6]], device=point_coords.device
107
+ )
108
+ return labels, prev_mask, point_coords
109
+
110
+ def get_foreground_class_count(self, class_vector: torch.Tensor | None, point_coords: torch.Tensor | None) -> int:
111
+ """Get number of foreground classes based on class and point prompt."""
112
+ if class_vector is None:
113
+ if point_coords is None:
114
+ raise ValueError("class_vector and point_coords cannot be both None.")
115
+ return point_coords.shape[0]
116
+ else:
117
+ return class_vector.shape[0]
118
+
119
+ def convert_point_label(
120
+ self,
121
+ point_label: torch.Tensor,
122
+ label_set: Sequence[int] | None = None,
123
+ special_index: Sequence[int] = (23, 24, 25, 26, 27, 57, 128),
124
+ ):
125
+ """
126
+ Convert point labels based on their class prompts. For special classes defined in special_index,
127
+ the positive/negative point labels will be converted from 1/0 to 3/2. The purpose is to distinguish these
128
+ special classes from ambiguous classes.
129
+
130
+ Args:
131
+ point_label: the point label tensor, [B, N].
132
+ label_set: the label index matching the indexes in labels. If labels are mapped to global index using RelabelID,
133
+ this label_set should be the globally mapped index. If labels are not mapped to the global index, e.g. in zero-shot
134
+ evaluation, this label_set should be the original index.
135
+ special_index: the special class index that needs to be converted.
136
+ """
137
+ if label_set is None:
138
+ return point_label
139
+ if not point_label.shape[0] == len(label_set):
140
+ raise ValueError("point_label and label_set must have the same length.")
141
+
142
+ for i in range(len(label_set)):
143
+ if label_set[i] in special_index:
144
+ for j in range(len(point_label[i])):
145
+ point_label[i, j] = point_label[i, j] + 2 if point_label[i, j] > -1 else point_label[i, j]
146
+ return point_label
147
+
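# [Editorial sketch, not part of the package diff] Worked example of the conversion described
# above, reusing the `model` built in the sketch after `vista3d132`. Label 25 is in the default
# `special_index`, so its 1/0 point labels become 3/2; label 3 is unchanged and -1 padding is kept.
point_label = torch.tensor([[1, 0, -1], [1, 1, 0]])
converted = model.convert_point_label(point_label, label_set=[25, 3])
# converted == tensor([[3, 2, -1], [1, 1, 0]])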
148
+ def sample_points_patch_val(
149
+ self,
150
+ labels: torch.Tensor,
151
+ patch_coords: Sequence[slice],
152
+ label_set: Sequence[int],
153
+ use_center: bool = True,
154
+ mapped_label_set: Sequence[int] | None = None,
155
+ max_ppoint: int = 1,
156
+ max_npoint: int = 0,
157
+ ):
158
+ """
159
+ Sample points for patch during sliding window validation. Only used for point only validation.
160
+
161
+ Args:
162
+ labels: shape [1, 1, H, W, D].
163
+ patch_coords: a sequence of sliding window slice objects.
164
+ label_set: local index, must match values in labels.
165
+ use_center: sample points from the center.
166
+ mapped_label_set: global index; it is used to identify special classes and is the global index
167
+ for the sampled points.
168
+ max_ppoint/max_npoint: maximum numbers of positive/negative points to sample.
169
+ """
170
+ point_coords, point_labels = sample_points_from_label(
171
+ labels[patch_coords],
172
+ label_set,
173
+ max_ppoint=max_ppoint,
174
+ max_npoint=max_npoint,
175
+ device=labels.device,
176
+ use_center=use_center,
177
+ )
178
+ point_labels = self.convert_point_label(point_labels, mapped_label_set)
179
+ return (point_coords, point_labels, torch.tensor(label_set).to(point_coords.device).unsqueeze(-1))
180
+
181
+ def update_point_to_patch(
182
+ self, patch_coords: Sequence[slice], point_coords: torch.Tensor, point_labels: torch.Tensor
183
+ ):
184
+ """
185
+ Update point_coords with respect to patch coords.
186
+ If a point is outside of the patch, remove its coordinates and set its label to -1.
187
+
188
+ Args:
189
+ patch_coords: a sequence of the python slice objects representing the patch coordinates during sliding window inference.
190
+ This value is passed from sliding_window_inferer.
191
+ point_coords: point coordinates, [B, N, 3].
192
+ point_labels: point labels, [B, N].
193
+ """
194
+ patch_ends = [patch_coords[-3].stop, patch_coords[-2].stop, patch_coords[-1].stop]
195
+ patch_starts = [patch_coords[-3].start, patch_coords[-2].start, patch_coords[-1].start]
196
+ # update point coords
197
+ patch_starts_tensor = unsqueeze_left(torch.tensor(patch_starts, device=point_coords.device), 2)
198
+ patch_ends_tensor = unsqueeze_left(torch.tensor(patch_ends, device=point_coords.device), 2)
199
+ # [1 N 1]
200
+ indices = torch.logical_and(
201
+ ((point_coords - patch_starts_tensor) > 0).all(2), ((patch_ends_tensor - point_coords) > 0).all(2)
202
+ )
203
+ # check if it's within patch coords
204
+ point_coords = point_coords.clone() - patch_starts_tensor
205
+ point_labels = point_labels.clone()
206
+ if indices.any():
207
+ point_labels[~indices] = -1
208
+ point_coords[~indices] = 0
209
+ # also remove padded points, mainly used for inference.
210
+ not_pad_indices = (point_labels != -1).any(0)
211
+ point_coords = point_coords[:, not_pad_indices]
212
+ point_labels = point_labels[:, not_pad_indices]
213
+ return point_coords, point_labels
214
+ return None, None
215
+
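# [Editorial sketch, not part of the package diff] Worked example of the patch mapping above,
# reusing `model` from the earlier sketch. The first click falls inside the patch and is shifted
# into patch coordinates; the second falls outside, is marked -1 and then dropped as padding.
patch = (slice(64, 128), slice(0, 64), slice(0, 64))
coords = torch.tensor([[[70, 40, 40], [10, 10, 10]]])   # [B=1, N=2, 3]
labels = torch.tensor([[1, 1]])                          # [B=1, N=2]
new_coords, new_labels = model.update_point_to_patch(patch, coords, labels)
# new_coords == tensor([[[6, 40, 40]]]), new_labels == tensor([[1]])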
216
+ def connected_components_combine(
217
+ self,
218
+ logits: torch.Tensor,
219
+ point_logits: torch.Tensor,
220
+ point_coords: torch.Tensor,
221
+ point_labels: torch.Tensor,
222
+ mapping_index: torch.Tensor,
223
+ thred: float = 0.5,
224
+ ):
225
+ """
226
+ Combine auto results with point click response. The auto results have shape [B, 1, H, W, D] which means B foreground masks
227
+ from a single image patch.
228
+ Out of those B foreground masks, the user may add points to a subset of B1 foreground masks for editing.
229
+ mapping_index represents the correspondence between B and B1.
230
+ For mapping_index with point clicks, NaN values in logits will be replaced with point_logits. Meanwhile, the added/removed
231
+ region in point clicks must be updated by the lcc function.
232
+ Note that if a positive point is within logits/prev_mask, the components containing the positive point will be added.
233
+
234
+ Args:
235
+ logits: automatic branch results, [B, 1, H, W, D].
236
+ point_logits: point branch results, [B1, 1, H, W, D].
237
+ point_coords: point coordinates, [B1, N, 3].
238
+ point_labels: point labels, [B1, N].
239
+ mapping_index: [B].
240
+ thred: the threshold to convert logits to binary.
241
+ """
242
+ logits = logits.as_tensor() if isinstance(logits, monai.data.MetaTensor) else logits
243
+ _logits = logits[mapping_index]
244
+ inside = []
245
+ for i in range(_logits.shape[0]):
246
+ inside.append(
247
+ np.any(
248
+ [
249
+ _logits[i, 0, p[0], p[1], p[2]].item() > 0
250
+ for p in point_coords[i].cpu().numpy().round().astype(int)
251
+ ]
252
+ )
253
+ )
254
+ inside_tensor = torch.tensor(inside).to(logits.device)
255
+ nan_mask = torch.isnan(_logits)
256
+ # _logits are converted to binary [B1, 1, H, W, D]
257
+ _logits = torch.nan_to_num(_logits, nan=self.NINF_VALUE).sigmoid()
258
+ pos_region = point_logits.sigmoid() > thred
259
+ diff_pos = torch.logical_and(torch.logical_or(_logits <= thred, unsqueeze_right(inside_tensor, 5)), pos_region)
260
+ diff_neg = torch.logical_and((_logits > thred), ~pos_region)
261
+ cc = lcc(diff_pos, diff_neg, point_coords=point_coords, point_labels=point_labels)
262
+ # cc is the region that can be updated by point_logits.
263
+ cc = cc.to(logits.device)
264
+ # Need to replace NaN with point_logits. diff_neg will never lie in nan_mask,
265
+ # only remove unconnected positive region.
266
+ uc_pos_region = torch.logical_and(pos_region, ~cc)
267
+ fill_mask = torch.logical_and(nan_mask, uc_pos_region)
268
+ if fill_mask.any():
269
+ # fill in the mean negative value
270
+ point_logits[fill_mask] = -1
271
+ # replace logits nan value and cc with point_logits
272
+ cc = torch.logical_or(nan_mask, cc).to(logits.dtype)
273
+ logits[mapping_index] *= 1 - cc
274
+ logits[mapping_index] += cc * point_logits
275
+ return logits
276
+
277
+ def gaussian_combine(
278
+ self,
279
+ logits: torch.Tensor,
280
+ point_logits: torch.Tensor,
281
+ point_coords: torch.Tensor,
282
+ point_labels: torch.Tensor,
283
+ mapping_index: torch.Tensor,
284
+ radius: int | None = None,
285
+ ):
286
+ """
287
+ Combine point results with auto results using gaussian.
288
+
289
+ Args:
290
+ logits: automatic branch results, [B, 1, H, W, D].
291
+ point_logits: point branch results, [B1, 1, H, W, D].
292
+ point_coords: point coordinates, [B1, N, 3].
293
+ point_labels: point labels, [B1, N].
294
+ mapping_index: [B].
295
+ radius: gaussian ball radius.
296
+ """
297
+ if radius is None:
298
+ radius = min(point_logits.shape[-3:]) // 5 # empirical value 5
299
+ weight = 1 - convert_points_to_disc(point_logits.shape[-3:], point_coords, point_labels, radius=radius).sum(
300
+ 1, keepdims=True
301
+ )
302
+ weight[weight < 0] = 0
303
+ logits = logits.as_tensor() if isinstance(logits, monai.data.MetaTensor) else logits
304
+ logits[mapping_index] *= weight
305
+ logits[mapping_index] += (1 - weight) * point_logits
306
+ return logits
307
+
308
+ def set_auto_grad(self, auto_freeze: bool = False, point_freeze: bool = False):
309
+ """
310
+ Freeze auto-branch or point-branch.
311
+
312
+ Args:
313
+ auto_freeze: whether to freeze the auto branch.
314
+ point_freeze: whether to freeze the point branch.
315
+ """
316
+ if auto_freeze != self.auto_freeze:
317
+ if hasattr(self.image_encoder, "set_auto_grad"):
318
+ self.image_encoder.set_auto_grad(auto_freeze=auto_freeze, point_freeze=point_freeze)
319
+ else:
320
+ for param in self.image_encoder.parameters():
321
+ param.requires_grad = (not auto_freeze) and (not point_freeze)
322
+ for param in self.class_head.parameters():
323
+ param.requires_grad = not auto_freeze
324
+ self.auto_freeze = auto_freeze
325
+
326
+ if point_freeze != self.point_freeze:
327
+ if hasattr(self.image_encoder, "set_auto_grad"):
328
+ self.image_encoder.set_auto_grad(auto_freeze=auto_freeze, point_freeze=point_freeze)
329
+ else:
330
+ for param in self.image_encoder.parameters():
331
+ param.requires_grad = (not auto_freeze) and (not point_freeze)
332
+ for param in self.point_head.parameters():
333
+ param.requires_grad = not point_freeze
334
+ self.point_freeze = point_freeze
335
+
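# [Editorial sketch, not part of the package diff] A typical finetuning setup, reusing `model`
# from the earlier sketch: freeze the automatic branch and keep the interactive branch trainable.
model.set_auto_grad(auto_freeze=True, point_freeze=False)
trainable = [name for name, p in model.named_parameters() if p.requires_grad]
# `trainable` now contains point_head parameters (and, depending on the encoder, its point path).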
336
+ def forward(
337
+ self,
338
+ input_images: torch.Tensor,
339
+ patch_coords: Sequence[slice] | None = None,
340
+ point_coords: torch.Tensor | None = None,
341
+ point_labels: torch.Tensor | None = None,
342
+ class_vector: torch.Tensor | None = None,
343
+ prompt_class: torch.Tensor | None = None,
344
+ labels: torch.Tensor | None = None,
345
+ label_set: Sequence[int] | None = None,
346
+ prev_mask: torch.Tensor | None = None,
347
+ radius: int | None = None,
348
+ val_point_sampler: Callable | None = None,
349
+ transpose: bool = False,
350
+ **kwargs,
351
+ ):
352
+ """
353
+ The forward function for VISTA3D. Only a single patch is supported in training and inference.
355
+ One exception is allowing a sliding window batch size > 1 for the automatic-segmentation-only case.
356
+ B represents the number of objects, N represents the number of points for each object.
356
+
357
+ Args:
358
+ input_images: [1, 1, H, W, D]
359
+ point_coords: [B, N, 3]
360
+ point_labels: [B, N], -1 represents padding. 0/1 means negative/positive points for regular class.
361
+ 2/3 means negative/positive points for special supported classes like tumor.
362
+ class_vector: [B, 1], the global class index.
363
+ prompt_class: [B, 1], the global class index. This value is associated with point_coords to identify if
364
+ the points are for zero-shot or supported class. When class_vector and point_coords are both
365
+ provided, prompt_class is the same as class_vector. For prompt_class[b] > 512, point_coords[b]
366
+ will be considered novel class.
367
+ patch_coords: a sequence of the python slice objects representing the patch coordinates during sliding window inference.
368
+ This value is passed from sliding_window_inferer. This is an indicator for training phase or validation phase.
369
+ labels: [1, 1, H, W, D], the groundtruth label tensor, only used for point-only evaluation
370
+ label_set: the label index matching the indexes in labels. If labels are mapped to global index using RelabelID,
371
+ this label_set should be the globally mapped index. If labels are not mapped to the global index, e.g. in zero-shot
372
+ evaluation, this label_set should be the original index.
373
+ prev_mask: [B, N, H_fullsize, W_fullsize, D_fullsize].
374
+ This is the transposed raw output from sliding_window_inferer before any postprocessing.
375
+ When user click points to perform auto-results correction, this can be the auto-results.
376
+ radius: single float value controlling the gaussian blur when combining point and auto results.
377
+ The gaussian combine is not used in VISTA3D training but might be useful for finetuning purposes.
378
+ val_point_sampler: function used to sample points from labels. This is only used for point-only evaluation.
379
+ transpose: bool. If true, the output will be transposed to be [1, B, H, W, D]. Required to be true if calling from
380
+ sliding window inferer/point inferer.
381
+ """
382
+ labels, prev_mask, point_coords = self.update_slidingwindow_padding(
383
+ kwargs.get("pad_size", None), labels, prev_mask, point_coords
384
+ )
385
+ image_size = input_images.shape[-3:]
386
+ device = input_images.device
387
+ if point_coords is None and class_vector is None:
388
+ return self.NINF_VALUE + torch.zeros([1, 1, *image_size], device=device)
389
+
390
+ bs = self.get_foreground_class_count(class_vector, point_coords)
391
+ if patch_coords is not None:
392
+ # during validation, perform patch-based point validation if labels and label_set are provided.
393
+ if labels is not None and label_set is not None:
394
+ # if labels is not None, sample from labels for each patch.
395
+ if val_point_sampler is None:
396
+ # TODO: think about how to refactor this part.
397
+ val_point_sampler = self.sample_points_patch_val
398
+ point_coords, point_labels, prompt_class = val_point_sampler(labels, patch_coords, label_set)
399
+ if prompt_class[0].item() == 0: # type: ignore
400
+ point_labels[0] = -1 # type: ignore
401
+ labels, prev_mask = None, None
402
+ elif point_coords is not None:
403
+ # If not performing patch-based point only validation, use user provided click points for inference.
404
+ # the point clicks are in the original image space; convert them to the current patch-coordinate space.
405
+ point_coords, point_labels = self.update_point_to_patch(patch_coords, point_coords, point_labels) # type: ignore
406
+
407
+ if point_coords is not None and point_labels is not None:
408
+ # remove points that are used for padding purposes (point_label = -1)
409
+ mapping_index = ((point_labels != -1).sum(1) > 0).to(torch.bool)
410
+ if mapping_index.any():
411
+ point_coords = point_coords[mapping_index]
412
+ point_labels = point_labels[mapping_index]
413
+ if prompt_class is not None:
414
+ prompt_class = prompt_class[mapping_index]
415
+ else:
416
+ if self.auto_freeze or (class_vector is None and patch_coords is None):
417
+ # if auto_freeze, point prompt must exist to allow loss backward
418
+ # in training, class_vector and point cannot both be None due to loss.backward()
419
+ mapping_index.fill_(True)
420
+ else:
421
+ point_coords, point_labels = None, None
422
+
423
+ if point_coords is None and class_vector is None:
424
+ logits = self.NINF_VALUE + torch.zeros([bs, 1, *image_size], device=device)
425
+ if transpose:
426
+ logits = logits.transpose(1, 0)
427
+ return logits
428
+
429
+ if self.image_embeddings is not None and kwargs.get("keep_cache", False) and class_vector is None:
430
+ out, out_auto = self.image_embeddings, None
431
+ else:
432
+ out, out_auto = self.image_encoder(
433
+ input_images, with_point=point_coords is not None, with_label=class_vector is not None
434
+ )
435
+ # release memory
436
+ input_images = None # type: ignore
437
+
438
+ # force releasing memory for tensors that were set to None
439
+ torch.cuda.empty_cache()
440
+ if class_vector is not None:
441
+ logits, _ = self.class_head(out_auto, class_vector)
442
+ if point_coords is not None:
443
+ point_logits = self.point_head(out, point_coords, point_labels, class_vector=prompt_class)
444
+ if patch_coords is None:
445
+ logits = self.gaussian_combine(
446
+ logits, point_logits, point_coords, point_labels, mapping_index, radius # type: ignore
447
+ )
448
+ else:
449
+ # during validation use largest component
450
+ logits = self.connected_components_combine(
451
+ logits, point_logits, point_coords, point_labels, mapping_index # type: ignore
452
+ )
453
+ else:
454
+ logits = self.NINF_VALUE + torch.zeros([bs, 1, *image_size], device=device, dtype=out.dtype)
455
+ logits[mapping_index] = self.point_head(out, point_coords, point_labels, class_vector=prompt_class)
456
+ if prev_mask is not None and patch_coords is not None:
457
+ logits = self.connected_components_combine(
458
+ prev_mask[patch_coords].transpose(1, 0).to(logits.device),
459
+ logits[mapping_index],
460
+ point_coords, # type: ignore
461
+ point_labels, # type: ignore
462
+ mapping_index,
463
+ )
464
+ if kwargs.get("keep_cache", False) and class_vector is None:
465
+ self.image_embeddings = out.detach()
466
+ if transpose:
467
+ logits = logits.transpose(1, 0)
468
+ return logits
469
+
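# [Editorial sketch, not part of the package diff] Point-prompted inference on a single patch,
# reusing `model` and `image` from the sketch after `vista3d132`. One object (B=1) with a single
# positive click (N=1); prompt_class value 1 is an arbitrary supported class index for illustration.
point_coords = torch.tensor([[[32.0, 32.0, 32.0]]])   # [B, N, 3]
point_labels = torch.tensor([[1]])                     # [B, N], 1 = positive click
prompt_class = torch.tensor([[1]])                     # [B, 1]
with torch.no_grad():
    point_logits = model(
        image,
        point_coords=point_coords,
        point_labels=point_labels,
        prompt_class=prompt_class,
        transpose=True,
    )
# `point_logits` has one channel per prompted object, e.g. [1, 1, 64, 64, 64].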
470
+
471
+ class PointMappingSAM(nn.Module):
472
+ def __init__(self, feature_size: int, max_prompt: int = 32, n_classes: int = 512, last_supported: int = 132):
473
+ """Interactive point head used for VISTA3D.
474
+ Adapted from segment anything:
475
+ `https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/mask_decoder.py`.
476
+
477
+ Args:
478
+ feature_size: feature channel from encoder.
479
+ max_prompt: max prompt number in each forward iteration.
480
+ n_classes: number of classes the model can potentially support. This is the maximum number of class embeddings.
481
+ last_supported: number of classes the model supports; this value should match the trained model weights.
482
+ """
483
+ super().__init__()
484
+ transformer_dim = feature_size
485
+ self.max_prompt = max_prompt
486
+ self.feat_downsample = nn.Sequential(
487
+ nn.Conv3d(in_channels=feature_size, out_channels=feature_size, kernel_size=3, stride=2, padding=1),
488
+ nn.InstanceNorm3d(feature_size),
489
+ nn.GELU(),
490
+ nn.Conv3d(in_channels=feature_size, out_channels=transformer_dim, kernel_size=3, stride=1, padding=1),
491
+ nn.InstanceNorm3d(feature_size),
492
+ )
493
+
494
+ self.mask_downsample = nn.Conv3d(in_channels=2, out_channels=2, kernel_size=3, stride=2, padding=1)
495
+
496
+ self.transformer = TwoWayTransformer(depth=2, embedding_dim=transformer_dim, mlp_dim=512, num_heads=4)
497
+ self.pe_layer = PositionEmbeddingRandom(transformer_dim // 2)
498
+ self.point_embeddings = nn.ModuleList([nn.Embedding(1, transformer_dim), nn.Embedding(1, transformer_dim)])
499
+ self.not_a_point_embed = nn.Embedding(1, transformer_dim)
500
+ self.special_class_embed = nn.Embedding(1, transformer_dim)
501
+ self.mask_tokens = nn.Embedding(1, transformer_dim)
502
+
503
+ self.output_upscaling = nn.Sequential(
504
+ nn.ConvTranspose3d(transformer_dim, transformer_dim, kernel_size=3, stride=2, padding=1, output_padding=1),
505
+ nn.InstanceNorm3d(transformer_dim),
506
+ nn.GELU(),
507
+ nn.Conv3d(transformer_dim, transformer_dim, kernel_size=3, stride=1, padding=1),
508
+ )
509
+
510
+ self.output_hypernetworks_mlps = MLP(transformer_dim, transformer_dim, transformer_dim, 3)
511
+ # class embedding
512
+ self.n_classes = n_classes
513
+ self.last_supported = last_supported
514
+ self.class_embeddings = nn.Embedding(n_classes, feature_size)
515
+ self.zeroshot_embed = nn.Embedding(1, transformer_dim)
516
+ self.supported_embed = nn.Embedding(1, transformer_dim)
517
+
518
+ def forward(
519
+ self,
520
+ out: torch.Tensor,
521
+ point_coords: torch.Tensor,
522
+ point_labels: torch.Tensor,
523
+ class_vector: torch.Tensor | None = None,
524
+ ):
525
+ """Args:
526
+ out: feature from encoder, [1, C, H, W, D]
527
+ point_coords: point coordinates, [B, N, 3]
528
+ point_labels: point labels, [B, N]
529
+ class_vector: class prompts, [B]
530
+ """
531
+ # downsample out
532
+ out_low = self.feat_downsample(out)
533
+ out_shape = tuple(out.shape[-3:])
534
+ # release memory
535
+ out = None # type: ignore
536
+ torch.cuda.empty_cache()
537
+ # embed points
538
+ points = point_coords + 0.5 # Shift to center of pixel
539
+ point_embedding = self.pe_layer.forward_with_coords(points, out_shape) # type: ignore
540
+ point_embedding[point_labels == -1] = 0.0
541
+ point_embedding[point_labels == -1] += self.not_a_point_embed.weight
542
+ point_embedding[point_labels == 0] += self.point_embeddings[0].weight
543
+ point_embedding[point_labels == 1] += self.point_embeddings[1].weight
544
+ point_embedding[point_labels == 2] += self.point_embeddings[0].weight + self.special_class_embed.weight
545
+ point_embedding[point_labels == 3] += self.point_embeddings[1].weight + self.special_class_embed.weight
546
+ output_tokens = self.mask_tokens.weight
547
+
548
+ output_tokens = output_tokens.unsqueeze(0).expand(point_embedding.size(0), -1, -1)
549
+ if class_vector is None:
550
+ tokens_all = torch.cat(
551
+ (
552
+ output_tokens,
553
+ point_embedding,
554
+ self.supported_embed.weight.unsqueeze(0).expand(point_embedding.size(0), -1, -1),
555
+ ),
556
+ dim=1,
557
+ )
558
+ # tokens_all = torch.cat((output_tokens, point_embedding), dim=1)
559
+ else:
560
+ class_embeddings = []
561
+ for i in class_vector:
562
+ if i > self.last_supported:
563
+ class_embeddings.append(self.zeroshot_embed.weight)
564
+ else:
565
+ class_embeddings.append(self.supported_embed.weight)
566
+ tokens_all = torch.cat((output_tokens, point_embedding, torch.stack(class_embeddings)), dim=1)
567
+ # cross attention
568
+ masks = []
569
+ max_prompt = self.max_prompt
570
+ for i in range(int(np.ceil(tokens_all.shape[0] / max_prompt))):
571
+ # remove variables in previous for loops to save peak memory for self.transformer
572
+ src, upscaled_embedding, hyper_in = None, None, None
573
+ torch.cuda.empty_cache()
574
+ idx = (i * max_prompt, min((i + 1) * max_prompt, tokens_all.shape[0]))
575
+ tokens = tokens_all[idx[0] : idx[1]]
576
+ src = torch.repeat_interleave(out_low, tokens.shape[0], dim=0)
577
+ pos_src = torch.repeat_interleave(self.pe_layer(out_low.shape[-3:]).unsqueeze(0), tokens.shape[0], dim=0)
578
+ b, c, h, w, d = src.shape
579
+ hs, src = self.transformer(src, pos_src, tokens)
580
+ mask_tokens_out = hs[:, :1, :]
581
+ hyper_in = self.output_hypernetworks_mlps(mask_tokens_out)
582
+ src = src.transpose(1, 2).view(b, c, h, w, d) # type: ignore
583
+ upscaled_embedding = self.output_upscaling(src)
584
+ b, c, h, w, d = upscaled_embedding.shape
585
+ mask = hyper_in @ upscaled_embedding.view(b, c, h * w * d)
586
+ masks.append(mask.view(-1, 1, h, w, d))
587
+
588
+ return torch.vstack(masks)
589
+
590
+
591
+ class ClassMappingClassify(nn.Module):
592
+ """Class head that performs automatic segmentation based on class vector."""
593
+
594
+ def __init__(self, n_classes: int, feature_size: int, use_mlp: bool = True):
595
+ """Args:
596
+ n_classes: maximum number of class embedding.
597
+ feature_size: class embedding size.
598
+ use_mlp: use mlp to further map class embedding.
599
+ """
600
+ super().__init__()
601
+ self.use_mlp = use_mlp
602
+ if use_mlp:
603
+ self.mlp = nn.Sequential(
604
+ nn.Linear(feature_size, feature_size),
605
+ nn.InstanceNorm1d(1),
606
+ nn.GELU(),
607
+ nn.Linear(feature_size, feature_size),
608
+ )
609
+ self.class_embeddings = nn.Embedding(n_classes, feature_size)
610
+ self.image_post_mapping = nn.Sequential(
611
+ UnetrBasicBlock(
612
+ spatial_dims=3,
613
+ in_channels=feature_size,
614
+ out_channels=feature_size,
615
+ kernel_size=3,
616
+ stride=1,
617
+ norm_name="instance",
618
+ res_block=True,
619
+ ),
620
+ UnetrBasicBlock(
621
+ spatial_dims=3,
622
+ in_channels=feature_size,
623
+ out_channels=feature_size,
624
+ kernel_size=3,
625
+ stride=1,
626
+ norm_name="instance",
627
+ res_block=True,
628
+ ),
629
+ )
630
+
631
+ def forward(self, src: torch.Tensor, class_vector: torch.Tensor):
632
+ b, c, h, w, d = src.shape
633
+ src = self.image_post_mapping(src)
634
+ class_embedding = self.class_embeddings(class_vector)
635
+ if self.use_mlp:
636
+ class_embedding = self.mlp(class_embedding)
637
+ # [b,1,feat] @ [1,feat,dim], the batch dimension becomes the class_embedding batch dimension.
638
+ masks = []
639
+ for i in range(b):
640
+ mask = class_embedding @ src[[i]].view(1, c, h * w * d)
641
+ masks.append(mask.view(-1, 1, h, w, d))
642
+
643
+ return torch.cat(masks, 1), class_embedding
644
+
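# [Editorial sketch, not part of the package diff] The broadcasting trick used in the loop above:
# per-class embeddings [B, 1, feat] matmul a flattened feature map [1, feat, H*W*D] broadcast to
# per-class masks.
emb = torch.randn(3, 1, 48)               # B=3 class embeddings
feat = torch.randn(1, 48, 4 * 4 * 4)      # one flattened [feat, H*W*D] feature map
masks = (emb @ feat).view(3, 1, 4, 4, 4)  # [B, 1, H, W, D]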
645
+
646
+ class TwoWayTransformer(nn.Module):
647
+ def __init__(
648
+ self,
649
+ depth: int,
650
+ embedding_dim: int,
651
+ num_heads: int,
652
+ mlp_dim: int,
653
+ activation: tuple | str = "relu",
654
+ attention_downsample_rate: int = 2,
655
+ ) -> None:
656
+ """
657
+ A transformer decoder that attends to an input image using
658
+ queries whose positional embedding is supplied.
659
+ Adapted from `https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/transformer.py`.
660
+
661
+ Args:
662
+ depth: number of layers in the transformer.
663
+ embedding_dim: the channel dimension for the input embeddings.
664
+ num_heads: the number of heads for multihead attention. Must divide embedding_dim.
665
+ mlp_dim: the channel dimension internal to the MLP block.
666
+ activation: the activation to use in the MLP block.
667
+ attention_downsample_rate: the rate at which to downsample the image before projecting.
668
+ """
669
+ super().__init__()
670
+ self.depth = depth
671
+ self.embedding_dim = embedding_dim
672
+ self.num_heads = num_heads
673
+ self.mlp_dim = mlp_dim
674
+ self.layers = nn.ModuleList()
675
+
676
+ for i in range(depth):
677
+ self.layers.append(
678
+ TwoWayAttentionBlock(
679
+ embedding_dim=embedding_dim,
680
+ num_heads=num_heads,
681
+ mlp_dim=mlp_dim,
682
+ activation=activation,
683
+ attention_downsample_rate=attention_downsample_rate,
684
+ skip_first_layer_pe=(i == 0),
685
+ )
686
+ )
687
+
688
+ self.final_attn_token_to_image = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate)
689
+ self.norm_final_attn = nn.LayerNorm(embedding_dim)
690
+
691
+ def forward(
692
+ self, image_embedding: torch.Tensor, image_pe: torch.Tensor, point_embedding: torch.Tensor
693
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
694
+ """
695
+ Args:
696
+ image_embedding: image to attend to. Should be shape
697
+ B x embedding_dim x h x w x d for any h, w, and d.
698
+ image_pe: the positional encoding to add to the image. Must
699
+ have the same shape as image_embedding.
700
+ point_embedding: the embedding to add to the query points.
701
+ Must have shape B x N_points x embedding_dim for any N_points.
702
+
703
+ Returns:
704
+ torch.Tensor: the processed point_embedding.
705
+ torch.Tensor: the processed image_embedding.
706
+ """
707
+ # BxCxHxWxD -> Bx(HWD)xC == B x N_image_tokens x C
708
+ image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
709
+ image_pe = image_pe.flatten(2).permute(0, 2, 1)
710
+
711
+ # Prepare queries
712
+ queries = point_embedding
713
+ keys = image_embedding
714
+
715
+ # Apply transformer blocks and final layernorm
716
+ for layer in self.layers:
717
+ queries, keys = layer(queries=queries, keys=keys, query_pe=point_embedding, key_pe=image_pe)
718
+
719
+ # Apply the final attention layer from the points to the image
720
+ q = queries + point_embedding
721
+ k = keys + image_pe
722
+ attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
723
+ queries = queries + attn_out
724
+ queries = self.norm_final_attn(queries)
725
+
726
+ return queries, keys
727
+
728
+
729
+ class TwoWayAttentionBlock(nn.Module):
730
+ def __init__(
731
+ self,
732
+ embedding_dim: int,
733
+ num_heads: int,
734
+ mlp_dim: int = 2048,
735
+ activation: tuple | str = "relu",
736
+ attention_downsample_rate: int = 2,
737
+ skip_first_layer_pe: bool = False,
738
+ ) -> None:
739
+ """
740
+ A transformer block with four layers: (1) self-attention of sparse
741
+ inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp
742
+ block on sparse inputs, and (4) cross attention of dense inputs to sparse
743
+ inputs.
744
+ Adapted from `https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/transformer.py`.
745
+
746
+ Args:
747
+ embedding_dim: the channel dimension of the embeddings.
748
+ num_heads: the number of heads in the attention layers.
749
+ mlp_dim: the hidden dimension of the mlp block.
750
+ activation: the activation of the mlp block.
751
+ skip_first_layer_pe: skip the PE on the first layer.
752
+ """
753
+ super().__init__()
754
+ self.self_attn = Attention(embedding_dim, num_heads)
755
+ self.norm1 = nn.LayerNorm(embedding_dim)
756
+
757
+ self.cross_attn_token_to_image = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate)
758
+ self.norm2 = nn.LayerNorm(embedding_dim)
759
+
760
+ self.mlp = MLPBlock(hidden_size=embedding_dim, mlp_dim=mlp_dim, act=activation, dropout_mode="vista3d")
761
+ self.norm3 = nn.LayerNorm(embedding_dim)
762
+
763
+ self.norm4 = nn.LayerNorm(embedding_dim)
764
+ self.cross_attn_image_to_token = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate)
765
+
766
+ self.skip_first_layer_pe = skip_first_layer_pe
767
+
768
+ def forward(
769
+ self, queries: torch.Tensor, keys: torch.Tensor, query_pe: torch.Tensor, key_pe: torch.Tensor
770
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
771
+ # Self attention block
772
+ if self.skip_first_layer_pe:
773
+ queries = self.self_attn(q=queries, k=queries, v=queries)
774
+ else:
775
+ q = queries + query_pe
776
+ attn_out = self.self_attn(q=q, k=q, v=queries)
777
+ queries = queries + attn_out
778
+ queries = self.norm1(queries)
779
+
780
+ # Cross attention block, tokens attending to image embedding
781
+ q = queries + query_pe
782
+ k = keys + key_pe
783
+ attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
784
+ queries = queries + attn_out
785
+ queries = self.norm2(queries)
786
+
787
+ # MLP block
788
+ mlp_out = self.mlp(queries)
789
+ queries = queries + mlp_out
790
+ queries = self.norm3(queries)
791
+
792
+ # Cross attention block, image embedding attending to tokens
793
+ q = queries + query_pe
794
+ k = keys + key_pe
795
+ attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
796
+ keys = keys + attn_out
797
+ keys = self.norm4(keys)
798
+
799
+ return queries, keys
800
+
801
+
802
+ class Attention(nn.Module):
803
+ """
804
+ An attention layer that allows for downscaling the size of the embedding
805
+ after projection to queries, keys, and values.
806
+ Adapted from `https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/transformer.py`.
807
+
808
+ Args:
809
+ embedding_dim: the channel dimension of the embeddings.
810
+ num_heads: the number of heads in the attention layers.
811
+ downsample_rate: the rate at which to downsample the image before projecting.
812
+ """
813
+
814
+ def __init__(self, embedding_dim: int, num_heads: int, downsample_rate: int = 1) -> None:
815
+ super().__init__()
816
+ self.embedding_dim = embedding_dim
817
+ self.internal_dim = embedding_dim // downsample_rate
818
+ self.num_heads = num_heads
819
+ if not self.internal_dim % num_heads == 0:
820
+ raise ValueError("num_heads must divide embedding_dim.")
821
+
822
+ self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
823
+ self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
824
+ self.v_proj = nn.Linear(embedding_dim, self.internal_dim)
825
+ self.out_proj = nn.Linear(self.internal_dim, embedding_dim)
826
+
827
+ def _separate_heads(self, x: torch.Tensor, num_heads: int) -> torch.Tensor:
828
+ b, n, c = x.shape
829
+ x = x.reshape(b, n, num_heads, c // num_heads)
830
+ # B x N_heads x N_tokens x C_per_head
831
+ return x.transpose(1, 2)
832
+
833
+ def _recombine_heads(self, x: torch.Tensor) -> torch.Tensor:
834
+ b, n_heads, n_tokens, c_per_head = x.shape
835
+ x = x.transpose(1, 2)
836
+ # B x N_tokens x C
837
+ return x.reshape(b, n_tokens, n_heads * c_per_head)
838
+
839
+ def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
840
+ # Input projections
841
+ q = self.q_proj(q)
842
+ k = self.k_proj(k)
843
+ v = self.v_proj(v)
844
+
845
+ # Separate into heads
846
+ q = self._separate_heads(q, self.num_heads)
847
+ k = self._separate_heads(k, self.num_heads)
848
+ v = self._separate_heads(v, self.num_heads)
849
+
850
+ # Attention
851
+ _, _, _, c_per_head = q.shape
852
+ attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens
853
+ attn = attn / math.sqrt(c_per_head)
854
+ attn = torch.softmax(attn, dim=-1)
855
+
856
+ # Get output
857
+ out = attn @ v
858
+ out = self._recombine_heads(out)
859
+ out = self.out_proj(out)
860
+
861
+ return out
862
+
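# [Editorial sketch, not part of the package diff] Shape check for the attention layer above with
# the downsampling used by the two-way transformer: 48-dim embeddings are projected to 24 internally.
attn = Attention(embedding_dim=48, num_heads=4, downsample_rate=2)
q = torch.randn(2, 5, 48)      # 5 query tokens
kv = torch.randn(2, 100, 48)   # 100 image tokens
out = attn(q=q, k=kv, v=kv)    # [2, 5, 48]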
863
+
864
+ class PositionEmbeddingRandom(nn.Module):
865
+ """
866
+ Positional encoding using random spatial frequencies.
867
+ Adapted from `https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/prompt_encoder.py`.
868
+
869
+ Args:
870
+ num_pos_feats: the number of positional encoding features.
871
+ scale: the scale of the positional encoding.
872
+ """
873
+
874
+ def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
875
+ super().__init__()
876
+ if scale is None or scale <= 0.0:
877
+ scale = 1.0
878
+ self.register_buffer("positional_encoding_gaussian_matrix", scale * torch.randn((3, num_pos_feats)))
879
+
880
+ def _pe_encoding(self, coords: torch.torch.Tensor) -> torch.torch.Tensor:
881
+ """Positionally encode points that are normalized to [0,1]."""
882
+ # assuming coords are in the [0, 1]^3 cube and have d_1 x ... x d_n x 3 shape
883
+ coords = 2 * coords - 1
884
+ # [bs=1, N=2, 3] @ [3, 128]
885
+ # [bs=1, N=2, 128]
886
+ coords = coords @ self.positional_encoding_gaussian_matrix
887
+ coords = 2 * np.pi * coords
888
+ # outputs d_1 x ... x d_n x C shape
889
+ # [bs=1, N=2, 128+128=256]
890
+ return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
891
+
892
+ def forward(self, size: Tuple[int, int, int]) -> torch.torch.Tensor:
893
+ """Generate positional encoding for a grid of the specified size."""
894
+ h, w, d = size
895
+ device: Any = self.positional_encoding_gaussian_matrix.device
896
+ grid = torch.ones((h, w, d), device=device, dtype=torch.float32)
897
+ x_embed = grid.cumsum(dim=0) - 0.5
898
+ y_embed = grid.cumsum(dim=1) - 0.5
899
+ z_embed = grid.cumsum(dim=2) - 0.5
900
+ x_embed = x_embed / h
901
+ y_embed = y_embed / w
902
+ z_embed = z_embed / d
903
+ pe = self._pe_encoding(torch.stack([x_embed, y_embed, z_embed], dim=-1))
904
+ # C x H x W x D
905
+ return pe.permute(3, 0, 1, 2)
906
+
907
+ def forward_with_coords(
908
+ self, coords_input: torch.torch.Tensor, image_size: Tuple[int, int, int]
909
+ ) -> torch.torch.Tensor:
910
+ """Positionally encode points that are not normalized to [0,1]."""
911
+ coords = coords_input.clone()
912
+ coords[:, :, 0] = coords[:, :, 0] / image_size[0]
913
+ coords[:, :, 1] = coords[:, :, 1] / image_size[1]
914
+ coords[:, :, 2] = coords[:, :, 2] / image_size[2]
915
+ # B x N x C
916
+ return self._pe_encoding(coords.to(torch.float))
917
+
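# [Editorial sketch, not part of the package diff] The random-frequency positional encoding above,
# for a small 8x8x8 feature map: a dense grid encoding plus the encoding of two un-normalized points.
pe_layer = PositionEmbeddingRandom(num_pos_feats=24)
dense_pe = pe_layer((8, 8, 8))                             # [48, 8, 8, 8]
pts = torch.tensor([[[2.0, 3.0, 4.0], [5.0, 1.0, 7.0]]])   # [B=1, N=2, 3]
sparse_pe = pe_layer.forward_with_coords(pts, (8, 8, 8))   # [1, 2, 48]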
918
+
919
+ class MLP(nn.Module):
920
+ """
921
+ Multi-layer perceptron. This class is only used for `PointMappingSAM`.
922
+ Adapted from `https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/mask_decoder.py`.
923
+
924
+ Args:
925
+ input_dim: the input dimension.
926
+ hidden_dim: the hidden dimension.
927
+ output_dim: the output dimension.
928
+ num_layers: the number of layers.
929
+ sigmoid_output: whether to apply a sigmoid activation to the output.
930
+ """
931
+
932
+ def __init__(
933
+ self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int, sigmoid_output: bool = False
934
+ ) -> None:
935
+ super().__init__()
936
+ self.num_layers = num_layers
937
+ h = [hidden_dim] * (num_layers - 1)
938
+ self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
939
+ self.sigmoid_output = sigmoid_output
940
+
941
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
942
+ for i, layer in enumerate(self.layers):
943
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
944
+ if self.sigmoid_output:
945
+ x = F.sigmoid(x)
946
+ return x
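# [Editorial sketch, not part of the package diff] The hypernetwork MLP above, as configured by
# PointMappingSAM (three 48-dim layers).
mlp = MLP(input_dim=48, hidden_dim=48, output_dim=48, num_layers=3)
y = mlp(torch.randn(2, 1, 48))   # [2, 1, 48]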