monai-weekly 1.4.dev2430__py3-none-any.whl → 1.4.dev2434__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. monai/__init__.py +1 -1
  2. monai/_version.py +3 -3
  3. monai/apps/generation/maisi/networks/autoencoderkl_maisi.py +43 -25
  4. monai/apps/generation/maisi/networks/controlnet_maisi.py +15 -18
  5. monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py +18 -18
  6. monai/bundle/config_parser.py +2 -2
  7. monai/bundle/reference_resolver.py +18 -1
  8. monai/bundle/scripts.py +45 -22
  9. monai/bundle/utils.py +3 -1
  10. monai/data/utils.py +1 -1
  11. monai/data/wsi_datasets.py +3 -3
  12. monai/losses/__init__.py +1 -0
  13. monai/losses/dice.py +10 -1
  14. monai/losses/nacl_loss.py +139 -0
  15. monai/networks/blocks/crossattention.py +48 -26
  16. monai/networks/blocks/mlp.py +16 -4
  17. monai/networks/blocks/selfattention.py +75 -23
  18. monai/networks/blocks/spatialattention.py +16 -1
  19. monai/networks/blocks/transformerblock.py +17 -2
  20. monai/networks/nets/__init__.py +2 -1
  21. monai/networks/nets/autoencoderkl.py +55 -22
  22. monai/networks/nets/cell_sam_wrapper.py +92 -0
  23. monai/networks/nets/controlnet.py +24 -22
  24. monai/networks/nets/diffusion_model_unet.py +159 -19
  25. monai/networks/nets/segresnet_ds.py +127 -1
  26. monai/networks/nets/spade_autoencoderkl.py +24 -2
  27. monai/networks/nets/spade_diffusion_model_unet.py +39 -2
  28. monai/networks/nets/transformer.py +17 -17
  29. monai/networks/nets/vista3d.py +908 -0
  30. monai/networks/utils.py +3 -3
  31. monai/transforms/__init__.py +1 -0
  32. monai/transforms/io/array.py +1 -1
  33. monai/transforms/post/array.py +2 -1
  34. monai/transforms/spatial/functional.py +1 -1
  35. monai/transforms/transform.py +2 -2
  36. monai/transforms/utils.py +183 -0
  37. monai/{apps/generation/maisi/utils/morphological_ops.py → transforms/utils_morphological_ops.py} +2 -0
  38. monai/transforms/utils_pytorch_numpy_unification.py +2 -2
  39. monai/utils/module.py +7 -6
  40. {monai_weekly-1.4.dev2430.dist-info → monai_weekly-1.4.dev2434.dist-info}/METADATA +83 -81
  41. {monai_weekly-1.4.dev2430.dist-info → monai_weekly-1.4.dev2434.dist-info}/RECORD +44 -41
  42. {monai_weekly-1.4.dev2430.dist-info → monai_weekly-1.4.dev2434.dist-info}/WHEEL +1 -1
  43. {monai_weekly-1.4.dev2430.dist-info → monai_weekly-1.4.dev2434.dist-info}/LICENSE +0 -0
  44. {monai_weekly-1.4.dev2430.dist-info → monai_weekly-1.4.dev2434.dist-info}/top_level.txt +0 -0
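The largest addition in this release is the new VISTA3D network (item 29, monai/networks/nets/vista3d.py, whose full content is the hunk below). As a quick orientation, here is a minimal, hypothetical usage sketch of its automatic-segmentation path based on the vista3d132 factory and the forward() docstring in that file; the patch size, the class index values, and the assumption that vista3d132 is re-exported by the updated nets/__init__.py are illustrative and not part of the released diff.

    import torch
    from monai.networks.nets import vista3d132  # assumed to be re-exported by the updated nets/__init__.py

    model = vista3d132(encoder_embed_dim=48, in_channels=1).eval()
    image = torch.rand(1, 1, 128, 128, 128)   # [1, 1, H, W, D] single patch (assumed size)
    class_vector = torch.tensor([[1], [2]])   # [B, 1] global class indices (assumed values)
    with torch.no_grad():
        logits = model(image, class_vector=class_vector)  # [B, 1, H, W, D] foreground logits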
@@ -0,0 +1,908 @@
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+
12
+ from __future__ import annotations
13
+
14
+ import math
15
+ from typing import Any, Callable, Optional, Sequence, Tuple
16
+
17
+ import numpy as np
18
+ import torch
19
+ import torch.nn.functional as F
20
+ from torch import nn
21
+
22
+ import monai
23
+ from monai.networks.blocks import MLPBlock, UnetrBasicBlock
24
+ from monai.networks.nets import SegResNetDS2
25
+ from monai.transforms.utils import convert_points_to_disc
26
+ from monai.transforms.utils import get_largest_connected_component_mask_point as lcc
27
+ from monai.transforms.utils import sample_points_from_label
28
+ from monai.utils import optional_import, unsqueeze_left, unsqueeze_right
29
+
30
+ rearrange, _ = optional_import("einops", name="rearrange")
31
+
32
+ __all__ = ["VISTA3D", "vista3d132"]
33
+
34
+
35
+ def vista3d132(encoder_embed_dim: int = 48, in_channels: int = 1):
36
+ """
37
+ Exact VISTA3D network configuration used in https://arxiv.org/abs/2406.05285.
38
+ The model treats class indices larger than 132 as zero-shot.
39
+
40
+ Args:
41
+ encoder_embed_dim: hidden dimension for encoder.
42
+ in_channels: input channel number.
43
+ """
44
+ segresnet = SegResNetDS2(
45
+ in_channels=in_channels,
46
+ blocks_down=(1, 2, 2, 4, 4),
47
+ norm="instance",
48
+ out_channels=encoder_embed_dim,
49
+ init_filters=encoder_embed_dim,
50
+ dsdepth=1,
51
+ )
52
+ point_head = PointMappingSAM(feature_size=encoder_embed_dim, n_classes=512, last_supported=132)
53
+ class_head = ClassMappingClassify(n_classes=512, feature_size=encoder_embed_dim, use_mlp=True)
54
+ vista = VISTA3D(image_encoder=segresnet, class_head=class_head, point_head=point_head)
55
+ return vista
56
+
57
+
58
+ class VISTA3D(nn.Module):
59
+ """
60
+ VISTA3D based on:
61
+ `VISTA3D: Versatile Imaging SegmenTation and Annotation model for 3D Computed Tomography
62
+ <https://arxiv.org/abs/2406.05285>`_.
63
+
64
+ Args:
65
+ image_encoder: image encoder backbone for feature extraction.
66
+ class_head: class head used for class index based segmentation.
67
+ point_head: point head used for interactive segmentation.
68
+ """
69
+
70
+ def __init__(self, image_encoder: nn.Module, class_head: nn.Module, point_head: nn.Module):
71
+ super().__init__()
72
+ self.image_encoder = image_encoder
73
+ self.class_head = class_head
74
+ self.point_head = point_head
75
+ self.image_embeddings = None
76
+ self.auto_freeze = False
77
+ self.point_freeze = False
78
+ self.NINF_VALUE = -9999
79
+ self.PINF_VALUE = 9999
80
+
81
+ def get_foreground_class_count(self, class_vector: torch.Tensor | None, point_coords: torch.Tensor | None) -> int:
82
+ """Get number of foreground classes based on class and point prompt."""
83
+ if class_vector is None:
84
+ if point_coords is None:
85
+ raise ValueError("class_vector and point_coords cannot be both None.")
86
+ return point_coords.shape[0]
87
+ else:
88
+ return class_vector.shape[0]
89
+
90
+ def convert_point_label(
91
+ self,
92
+ point_label: torch.Tensor,
93
+ label_set: Sequence[int] | None = None,
94
+ special_index: Sequence[int] = (23, 24, 25, 26, 27, 57, 128),
95
+ ):
96
+ """
97
+ Convert point label based on its class prompt. For special classes defined in special index,
98
+ the positive/negative point labels will be converted from 1/0 to 3/2. The purpose is to separate these
99
+ special classes from ambiguous classes.
100
+
101
+ Args:
102
+ point_label: the point label tensor, [B, N].
103
+ label_set: the label index matching the indexes in labels. If labels are mapped to global index using RelabelID,
104
+ this label_set should be global mapped index. If labels are not mapped to global index, e.g. in zero-shot
105
+ evaluation, this label_set should be the original index.
106
+ special_index: the special class index that needs to be converted.
107
+ """
108
+ if label_set is None:
109
+ return point_label
110
+ if not point_label.shape[0] == len(label_set):
111
+ raise ValueError("point_label and label_set must have the same length.")
112
+
113
+ for i in range(len(label_set)):
114
+ if label_set[i] in special_index:
115
+ for j in range(len(point_label[i])):
116
+ point_label[i, j] = point_label[i, j] + 2 if point_label[i, j] > -1 else point_label[i, j]
117
+ return point_label
118
+
119
+ def sample_points_patch_val(
120
+ self,
121
+ labels: torch.Tensor,
122
+ patch_coords: Sequence[slice],
123
+ label_set: Sequence[int],
124
+ use_center: bool = True,
125
+ mapped_label_set: Sequence[int] | None = None,
126
+ max_ppoint: int = 1,
127
+ max_npoint: int = 0,
128
+ ):
129
+ """
130
+ Sample points for patch during sliding window validation. Only used for point only validation.
131
+
132
+ Args:
133
+ labels: shape [1, 1, H, W, D].
134
+ patch_coords: a sequence of sliding window slice objects.
135
+ label_set: local index, must match values in labels.
136
+ use_center: sample points from the center.
137
+ mapped_label_set: global index, it is used to identify special classes and is the global index
138
+ for the sampled points.
139
+ max_ppoint/max_npoint: positive points and negative points to sample.
140
+ """
141
+ point_coords, point_labels = sample_points_from_label(
142
+ labels[patch_coords],
143
+ label_set,
144
+ max_ppoint=max_ppoint,
145
+ max_npoint=max_npoint,
146
+ device=labels.device,
147
+ use_center=use_center,
148
+ )
149
+ point_labels = self.convert_point_label(point_labels, mapped_label_set)
150
+ return (point_coords, point_labels, torch.tensor(label_set).to(point_coords.device).unsqueeze(-1))
151
+
152
+ def update_point_to_patch(
153
+ self, patch_coords: Sequence[slice], point_coords: torch.Tensor, point_labels: torch.Tensor
154
+ ):
155
+ """
156
+ Update point_coords with respect to patch coords.
157
+ If a point is outside of the patch, remove its coordinates and set its label to -1.
158
+
159
+ Args:
160
+ patch_coords: a sequence of the python slice objects representing the patch coordinates during sliding window inference.
161
+ This value is passed from sliding_window_inferer.
162
+ point_coords: point coordinates, [B, N, 3].
163
+ point_labels: point labels, [B, N].
164
+ """
165
+ patch_ends = [patch_coords[-3].stop, patch_coords[-2].stop, patch_coords[-1].stop]
166
+ patch_starts = [patch_coords[-3].start, patch_coords[-2].start, patch_coords[-1].start]
167
+ # update point coords
168
+ patch_starts_tensor = unsqueeze_left(torch.tensor(patch_starts, device=point_coords.device), 2)
169
+ patch_ends_tensor = unsqueeze_left(torch.tensor(patch_ends, device=point_coords.device), 2)
170
+ # [1 N 1]
171
+ indices = torch.logical_and(
172
+ ((point_coords - patch_starts_tensor) > 0).all(2), ((patch_ends_tensor - point_coords) > 0).all(2)
173
+ )
174
+ # check if it's within patch coords
175
+ point_coords = point_coords.clone() - patch_starts_tensor
176
+ point_labels = point_labels.clone()
177
+ if indices.any():
178
+ point_labels[~indices] = -1
179
+ point_coords[~indices] = 0
180
+ # also remove padded points, mainly used for inference.
181
+ not_pad_indices = (point_labels != -1).any(0)
182
+ point_coords = point_coords[:, not_pad_indices]
183
+ point_labels = point_labels[:, not_pad_indices]
184
+ return point_coords, point_labels
185
+ return None, None
186
+
187
+ def connected_components_combine(
188
+ self,
189
+ logits: torch.Tensor,
190
+ point_logits: torch.Tensor,
191
+ point_coords: torch.Tensor,
192
+ point_labels: torch.Tensor,
193
+ mapping_index: torch.Tensor,
194
+ thred: float = 0.5,
195
+ ):
196
+ """
197
+ Combine auto results with point click response. The auto results have shape [B, 1, H, W, D] which means B foreground masks
198
+ from a single image patch.
199
+ Out of those B foreground masks, the user may add points to a subset of B1 foreground masks for editing.
200
+ mapping_index represents the correspondence between B and B1.
201
+ For masks selected by mapping_index, NaN values in logits will be replaced with point_logits. Meanwhile, the regions added/removed
202
+ by the point clicks must be updated by the lcc (largest connected component) function.
203
+ Note that if a positive point is within logits/prev_mask, the components containing the positive point will be added.
204
+
205
+ Args:
206
+ logits: automatic branch results, [B, 1, H, W, D].
207
+ point_logits: point branch results, [B1, 1, H, W, D].
208
+ point_coords: point coordinates, [B1, N, 3].
209
+ point_labels: point labels, [B1, N].
210
+ mapping_index: [B].
211
+ thred: the threshold to convert logits to binary.
212
+ """
213
+ logits = logits.as_tensor() if isinstance(logits, monai.data.MetaTensor) else logits
214
+ _logits = logits[mapping_index]
215
+ inside = []
216
+ for i in range(_logits.shape[0]):
217
+ inside.append(
218
+ np.any(
219
+ [
220
+ _logits[i, 0, p[0], p[1], p[2]].item() > 0
221
+ for p in point_coords[i].cpu().numpy().round().astype(int)
222
+ ]
223
+ )
224
+ )
225
+ inside_tensor = torch.tensor(inside).to(logits.device)
226
+ nan_mask = torch.isnan(_logits)
227
+ # _logits are converted to binary [B1, 1, H, W, D]
228
+ _logits = torch.nan_to_num(_logits, nan=self.NINF_VALUE).sigmoid()
229
+ pos_region = point_logits.sigmoid() > thred
230
+ diff_pos = torch.logical_and(torch.logical_or(_logits <= thred, unsqueeze_right(inside_tensor, 5)), pos_region)
231
+ diff_neg = torch.logical_and((_logits > thred), ~pos_region)
232
+ cc = lcc(diff_pos, diff_neg, point_coords=point_coords, point_labels=point_labels)
233
+ # cc is the region that can be updated by point_logits.
234
+ cc = cc.to(logits.device)
235
+ # Need to replace NaN with point_logits. diff_neg will never lie in nan_mask,
236
+ # only remove unconnected positive region.
237
+ uc_pos_region = torch.logical_and(pos_region, ~cc)
238
+ fill_mask = torch.logical_and(nan_mask, uc_pos_region)
239
+ if fill_mask.any():
240
+ # fill in the mean negative value
241
+ point_logits[fill_mask] = -1
242
+ # replace logits nan value and cc with point_logits
243
+ cc = torch.logical_or(nan_mask, cc).to(logits.dtype)
244
+ logits[mapping_index] *= 1 - cc
245
+ logits[mapping_index] += cc * point_logits
246
+ return logits
247
+
248
+ def gaussian_combine(
249
+ self,
250
+ logits: torch.Tensor,
251
+ point_logits: torch.Tensor,
252
+ point_coords: torch.Tensor,
253
+ point_labels: torch.Tensor,
254
+ mapping_index: torch.Tensor,
255
+ radius: int | None = None,
256
+ ):
257
+ """
258
+ Combine point results with auto results using Gaussian weighting around the clicked points.
259
+
260
+ Args:
261
+ logits: automatic branch results, [B, 1, H, W, D].
262
+ point_logits: point branch results, [B1, 1, H, W, D].
263
+ point_coords: point coordinates, [B1, N, 3].
264
+ point_labels: point labels, [B1, N].
265
+ mapping_index: [B].
266
+ radius: gaussian ball radius.
267
+ """
268
+ if radius is None:
269
+ radius = min(point_logits.shape[-3:]) // 5 # empirical value 5
270
+ weight = 1 - convert_points_to_disc(point_logits.shape[-3:], point_coords, point_labels, radius=radius).sum(
271
+ 1, keepdims=True
272
+ )
273
+ weight[weight < 0] = 0
274
+ logits = logits.as_tensor() if isinstance(logits, monai.data.MetaTensor) else logits
275
+ logits[mapping_index] *= weight
276
+ logits[mapping_index] += (1 - weight) * point_logits
277
+ return logits
278
+
279
+ def set_auto_grad(self, auto_freeze: bool = False, point_freeze: bool = False):
280
+ """
281
+ Freeze auto-branch or point-branch.
282
+
283
+ Args:
284
+ auto_freeze: whether to freeze the auto branch.
285
+ point_freeze: whether to freeze the point branch.
286
+ """
287
+ if auto_freeze != self.auto_freeze:
288
+ if hasattr(self.image_encoder, "set_auto_grad"):
289
+ self.image_encoder.set_auto_grad(auto_freeze=auto_freeze, point_freeze=point_freeze)
290
+ else:
291
+ for param in self.image_encoder.parameters():
292
+ param.requires_grad = (not auto_freeze) and (not point_freeze)
293
+ for param in self.class_head.parameters():
294
+ param.requires_grad = not auto_freeze
295
+ self.auto_freeze = auto_freeze
296
+
297
+ if point_freeze != self.point_freeze:
298
+ if hasattr(self.image_encoder, "set_auto_grad"):
299
+ self.image_encoder.set_auto_grad(auto_freeze=auto_freeze, point_freeze=point_freeze)
300
+ else:
301
+ for param in self.image_encoder.parameters():
302
+ param.requires_grad = (not auto_freeze) and (not point_freeze)
303
+ for param in self.point_head.parameters():
304
+ param.requires_grad = not point_freeze
305
+ self.point_freeze = point_freeze
306
+
307
+ def forward(
308
+ self,
309
+ input_images: torch.Tensor,
310
+ point_coords: torch.Tensor | None = None,
311
+ point_labels: torch.Tensor | None = None,
312
+ class_vector: torch.Tensor | None = None,
313
+ prompt_class: torch.Tensor | None = None,
314
+ patch_coords: Sequence[slice] | None = None,
315
+ labels: torch.Tensor | None = None,
316
+ label_set: Sequence[int] | None = None,
317
+ prev_mask: torch.Tensor | None = None,
318
+ radius: int | None = None,
319
+ val_point_sampler: Callable | None = None,
320
+ **kwargs,
321
+ ):
322
+ """
323
+ The forward function for VISTA3D. Only a single patch is supported in training and inference.
324
+ The one exception is that a sliding window batch size > 1 is allowed for the automatic segmentation only case.
325
+ B represents the number of objects, N represents the number of points for each object.
326
+
327
+ Args:
328
+ input_images: [1, 1, H, W, D]
329
+ point_coords: [B, N, 3]
330
+ point_labels: [B, N], -1 represents padding. 0/1 means negative/positive points for regular classes.
331
+ 2/3 means negative/positive points for special supported classes like tumor.
332
+ class_vector: [B, 1], the global class index
333
+ prompt_class: [B, 1], the global class index. This value is associated with point_coords to identify if
334
+ the points are for zero-shot or supported class. When class_vector and point_coords are both
335
+ provided, prompt_class is the same as class_vector. For prompt_class[b] > 512, point_coords[b]
336
+ will be considered novel class.
337
+ patch_coords: a sequence of the python slice objects representing the patch coordinates during sliding window inference.
338
+ This value is passed from sliding_window_inferer. This is an indicator for training phase or validation phase.
339
+ labels: [1, 1, H, W, D], the groundtruth label tensor, only used for point-only evaluation
340
+ label_set: the label index matching the indexes in labels. If labels are mapped to global index using RelabelID,
341
+ this label_set should be global mapped index. If labels are not mapped to global index, e.g. in zero-shot
342
+ evaluation, this label_set should be the original index.
343
+ prev_mask: [B, N, H_fullsize, W_fullsize, D_fullsize].
344
+ This is the transposed raw output from sliding_window_inferer before any postprocessing.
345
+ When the user clicks points to perform auto-results correction, this can be the auto-results.
346
+ radius: single float value controlling the gaussian blur when combining point and auto results.
347
+ The gaussian combine is not used in VISTA3D training but might be useful for finetuning purposes.
348
+ val_point_sampler: function used to sample points from labels. This is only used for point-only evaluation.
349
+
350
+ """
351
+ image_size = input_images.shape[-3:]
352
+ device = input_images.device
353
+ if point_coords is None and class_vector is None:
354
+ return self.NINF_VALUE + torch.zeros([1, 1, *image_size], device=device)
355
+
356
+ bs = self.get_foreground_class_count(class_vector, point_coords)
357
+ if patch_coords is not None:
358
+ # during validation, perform patch-based point validation when it is enabled (labels and label_set provided).
359
+ if labels is not None and label_set is not None:
360
+ # if labels is not None, sample from labels for each patch.
361
+ if val_point_sampler is None:
362
+ # TODO: think about how to refactor this part.
363
+ val_point_sampler = self.sample_points_patch_val
364
+ point_coords, point_labels, prompt_class = val_point_sampler(labels, patch_coords, label_set)
365
+ if prompt_class[0].item() == 0: # type: ignore
366
+ point_labels[0] = -1 # type: ignore
367
+ labels, prev_mask = None, None
368
+ elif point_coords is not None:
369
+ # If not performing patch-based point only validation, use user provided click points for inference.
370
+ # the point clicks are in the original image space; convert them to the current patch-coordinate space.
371
+ point_coords, point_labels = self.update_point_to_patch(patch_coords, point_coords, point_labels) # type: ignore
372
+
373
+ if point_coords is not None and point_labels is not None:
374
+ # remove points that are used for padding purposes (point_label = -1)
375
+ mapping_index = ((point_labels != -1).sum(1) > 0).to(torch.bool)
376
+ if mapping_index.any():
377
+ point_coords = point_coords[mapping_index]
378
+ point_labels = point_labels[mapping_index]
379
+ if prompt_class is not None:
380
+ prompt_class = prompt_class[mapping_index]
381
+ else:
382
+ if self.auto_freeze or (class_vector is None and patch_coords is None):
383
+ # if auto_freeze, point prompt must exist to allow loss backward
384
+ # in training, class_vector and point cannot both be None due to loss.backward()
385
+ mapping_index.fill_(True)
386
+ else:
387
+ point_coords, point_labels = None, None
388
+
389
+ if point_coords is None and class_vector is None:
390
+ return self.NINF_VALUE + torch.zeros([bs, 1, *image_size], device=device)
391
+
392
+ if self.image_embeddings is not None and kwargs.get("keep_cache", False) and class_vector is None:
393
+ out, out_auto = self.image_embeddings, None
394
+ else:
395
+ out, out_auto = self.image_encoder(
396
+ input_images, with_point=point_coords is not None, with_label=class_vector is not None
397
+ )
398
+ # release memory
399
+ input_images = None # type: ignore
400
+
401
+ # force release of memory for the tensors set to None
402
+ torch.cuda.empty_cache()
403
+ if class_vector is not None:
404
+ logits, _ = self.class_head(out_auto, class_vector)
405
+ if point_coords is not None:
406
+ point_logits = self.point_head(out, point_coords, point_labels, class_vector=prompt_class)
407
+ if patch_coords is None:
408
+ logits = self.gaussian_combine(
409
+ logits, point_logits, point_coords, point_labels, mapping_index, radius # type: ignore
410
+ )
411
+ else:
412
+ # during validation use largest component
413
+ logits = self.connected_components_combine(
414
+ logits, point_logits, point_coords, point_labels, mapping_index # type: ignore
415
+ )
416
+ else:
417
+ logits = self.NINF_VALUE + torch.zeros([bs, 1, *image_size], device=device, dtype=out.dtype)
418
+ logits[mapping_index] = self.point_head(out, point_coords, point_labels, class_vector=prompt_class)
419
+ if prev_mask is not None and patch_coords is not None:
420
+ logits = self.connected_components_combine(
421
+ prev_mask[patch_coords].transpose(1, 0).to(logits.device),
422
+ logits[mapping_index],
423
+ point_coords, # type: ignore
424
+ point_labels, # type: ignore
425
+ mapping_index,
426
+ )
427
+
428
+ if kwargs.get("keep_cache", False) and class_vector is None:
429
+ self.image_embeddings = out.detach()
430
+ return logits
431
+
432
+
433
+ class PointMappingSAM(nn.Module):
434
+ def __init__(self, feature_size: int, max_prompt: int = 32, n_classes: int = 512, last_supported: int = 132):
435
+ """Interactive point head used for VISTA3D.
436
+ Adapted from segment anything:
437
+ `https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/mask_decoder.py`.
438
+
439
+ Args:
440
+ feature_size: feature channel from encoder.
441
+ max_prompt: max prompt number in each forward iteration.
442
+ n_classes: number of classes the model can potentially support. This is the maximum number of class embeddings.
443
+ last_supported: number of classes the model supports; this value should match the trained model weights.
444
+ """
445
+ super().__init__()
446
+ transformer_dim = feature_size
447
+ self.max_prompt = max_prompt
448
+ self.feat_downsample = nn.Sequential(
449
+ nn.Conv3d(in_channels=feature_size, out_channels=feature_size, kernel_size=3, stride=2, padding=1),
450
+ nn.InstanceNorm3d(feature_size),
451
+ nn.GELU(),
452
+ nn.Conv3d(in_channels=feature_size, out_channels=transformer_dim, kernel_size=3, stride=1, padding=1),
453
+ nn.InstanceNorm3d(feature_size),
454
+ )
455
+
456
+ self.mask_downsample = nn.Conv3d(in_channels=2, out_channels=2, kernel_size=3, stride=2, padding=1)
457
+
458
+ self.transformer = TwoWayTransformer(depth=2, embedding_dim=transformer_dim, mlp_dim=512, num_heads=4)
459
+ self.pe_layer = PositionEmbeddingRandom(transformer_dim // 2)
460
+ self.point_embeddings = nn.ModuleList([nn.Embedding(1, transformer_dim), nn.Embedding(1, transformer_dim)])
461
+ self.not_a_point_embed = nn.Embedding(1, transformer_dim)
462
+ self.special_class_embed = nn.Embedding(1, transformer_dim)
463
+ self.mask_tokens = nn.Embedding(1, transformer_dim)
464
+
465
+ self.output_upscaling = nn.Sequential(
466
+ nn.ConvTranspose3d(transformer_dim, transformer_dim, kernel_size=3, stride=2, padding=1, output_padding=1),
467
+ nn.InstanceNorm3d(transformer_dim),
468
+ nn.GELU(),
469
+ nn.Conv3d(transformer_dim, transformer_dim, kernel_size=3, stride=1, padding=1),
470
+ )
471
+
472
+ self.output_hypernetworks_mlps = MLP(transformer_dim, transformer_dim, transformer_dim, 3)
473
+ # class embedding
474
+ self.n_classes = n_classes
475
+ self.last_supported = last_supported
476
+ self.class_embeddings = nn.Embedding(n_classes, feature_size)
477
+ self.zeroshot_embed = nn.Embedding(1, transformer_dim)
478
+ self.supported_embed = nn.Embedding(1, transformer_dim)
479
+
480
+ def forward(
481
+ self,
482
+ out: torch.Tensor,
483
+ point_coords: torch.Tensor,
484
+ point_labels: torch.Tensor,
485
+ class_vector: torch.Tensor | None = None,
486
+ ):
487
+ """Args:
488
+ out: feature from encoder, [1, C, H, W, D]
489
+ point_coords: point coordinates, [B, N, 3]
490
+ point_labels: point labels, [B, N]
491
+ class_vector: class prompts, [B]
492
+ """
493
+ # downsample out
494
+ out_low = self.feat_downsample(out)
495
+ out_shape = tuple(out.shape[-3:])
496
+ # release memory
497
+ out = None # type: ignore
498
+ torch.cuda.empty_cache()
499
+ # embed points
500
+ points = point_coords + 0.5 # Shift to center of pixel
501
+ point_embedding = self.pe_layer.forward_with_coords(points, out_shape) # type: ignore
502
+ point_embedding[point_labels == -1] = 0.0
503
+ point_embedding[point_labels == -1] += self.not_a_point_embed.weight
504
+ point_embedding[point_labels == 0] += self.point_embeddings[0].weight
505
+ point_embedding[point_labels == 1] += self.point_embeddings[1].weight
506
+ point_embedding[point_labels == 2] += self.point_embeddings[0].weight + self.special_class_embed.weight
507
+ point_embedding[point_labels == 3] += self.point_embeddings[1].weight + self.special_class_embed.weight
508
+ output_tokens = self.mask_tokens.weight
509
+
510
+ output_tokens = output_tokens.unsqueeze(0).expand(point_embedding.size(0), -1, -1)
511
+ if class_vector is None:
512
+ tokens_all = torch.cat(
513
+ (
514
+ output_tokens,
515
+ point_embedding,
516
+ self.supported_embed.weight.unsqueeze(0).expand(point_embedding.size(0), -1, -1),
517
+ ),
518
+ dim=1,
519
+ )
520
+ # tokens_all = torch.cat((output_tokens, point_embedding), dim=1)
521
+ else:
522
+ class_embeddings = []
523
+ for i in class_vector:
524
+ if i > self.last_supported:
525
+ class_embeddings.append(self.zeroshot_embed.weight)
526
+ else:
527
+ class_embeddings.append(self.supported_embed.weight)
528
+ tokens_all = torch.cat((output_tokens, point_embedding, torch.stack(class_embeddings)), dim=1)
529
+ # cross attention
530
+ masks = []
531
+ max_prompt = self.max_prompt
532
+ for i in range(int(np.ceil(tokens_all.shape[0] / max_prompt))):
533
+ # remove variables in previous for loops to save peak memory for self.transformer
534
+ src, upscaled_embedding, hyper_in = None, None, None
535
+ torch.cuda.empty_cache()
536
+ idx = (i * max_prompt, min((i + 1) * max_prompt, tokens_all.shape[0]))
537
+ tokens = tokens_all[idx[0] : idx[1]]
538
+ src = torch.repeat_interleave(out_low, tokens.shape[0], dim=0)
539
+ pos_src = torch.repeat_interleave(self.pe_layer(out_low.shape[-3:]).unsqueeze(0), tokens.shape[0], dim=0)
540
+ b, c, h, w, d = src.shape
541
+ hs, src = self.transformer(src, pos_src, tokens)
542
+ mask_tokens_out = hs[:, :1, :]
543
+ hyper_in = self.output_hypernetworks_mlps(mask_tokens_out)
544
+ src = src.transpose(1, 2).view(b, c, h, w, d) # type: ignore
545
+ upscaled_embedding = self.output_upscaling(src)
546
+ b, c, h, w, d = upscaled_embedding.shape
547
+ mask = hyper_in @ upscaled_embedding.view(b, c, h * w * d)
548
+ masks.append(mask.view(-1, 1, h, w, d))
549
+
550
+ return torch.vstack(masks)
551
+
552
+
553
+ class ClassMappingClassify(nn.Module):
554
+ """Class head that performs automatic segmentation based on class vector."""
555
+
556
+ def __init__(self, n_classes: int, feature_size: int, use_mlp: bool = True):
557
+ """Args:
558
+ n_classes: maximum number of class embedding.
559
+ feature_size: class embedding size.
560
+ use_mlp: use mlp to further map class embedding.
561
+ """
562
+ super().__init__()
563
+ self.use_mlp = use_mlp
564
+ if use_mlp:
565
+ self.mlp = nn.Sequential(
566
+ nn.Linear(feature_size, feature_size),
567
+ nn.InstanceNorm1d(1),
568
+ nn.GELU(),
569
+ nn.Linear(feature_size, feature_size),
570
+ )
571
+ self.class_embeddings = nn.Embedding(n_classes, feature_size)
572
+ self.image_post_mapping = nn.Sequential(
573
+ UnetrBasicBlock(
574
+ spatial_dims=3,
575
+ in_channels=feature_size,
576
+ out_channels=feature_size,
577
+ kernel_size=3,
578
+ stride=1,
579
+ norm_name="instance",
580
+ res_block=True,
581
+ ),
582
+ UnetrBasicBlock(
583
+ spatial_dims=3,
584
+ in_channels=feature_size,
585
+ out_channels=feature_size,
586
+ kernel_size=3,
587
+ stride=1,
588
+ norm_name="instance",
589
+ res_block=True,
590
+ ),
591
+ )
592
+
593
+ def forward(self, src: torch.Tensor, class_vector: torch.Tensor):
594
+ b, c, h, w, d = src.shape
595
+ src = self.image_post_mapping(src)
596
+ class_embedding = self.class_embeddings(class_vector)
597
+ if self.use_mlp:
598
+ class_embedding = self.mlp(class_embedding)
599
+ # [b, 1, feat] @ [1, feat, dim]; the batch dimension becomes the class_embedding batch dimension.
600
+ masks = []
601
+ for i in range(b):
602
+ mask = class_embedding @ src[[i]].view(1, c, h * w * d)
603
+ masks.append(mask.view(-1, 1, h, w, d))
604
+
605
+ return torch.cat(masks, 1), class_embedding
606
+
607
+
608
+ class TwoWayTransformer(nn.Module):
609
+ def __init__(
610
+ self,
611
+ depth: int,
612
+ embedding_dim: int,
613
+ num_heads: int,
614
+ mlp_dim: int,
615
+ activation: tuple | str = "relu",
616
+ attention_downsample_rate: int = 2,
617
+ ) -> None:
618
+ """
619
+ A transformer decoder that attends to an input image using
620
+ queries whose positional embedding is supplied.
621
+ Adapted from `https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/transformer.py`.
622
+
623
+ Args:
624
+ depth: number of layers in the transformer.
625
+ embedding_dim: the channel dimension for the input embeddings.
626
+ num_heads: the number of heads for multihead attention. Must divide embedding_dim.
627
+ mlp_dim: the channel dimension internal to the MLP block.
628
+ activation: the activation to use in the MLP block.
629
+ attention_downsample_rate: the factor by which the embedding dimension is reduced inside the attention layers.
630
+ """
631
+ super().__init__()
632
+ self.depth = depth
633
+ self.embedding_dim = embedding_dim
634
+ self.num_heads = num_heads
635
+ self.mlp_dim = mlp_dim
636
+ self.layers = nn.ModuleList()
637
+
638
+ for i in range(depth):
639
+ self.layers.append(
640
+ TwoWayAttentionBlock(
641
+ embedding_dim=embedding_dim,
642
+ num_heads=num_heads,
643
+ mlp_dim=mlp_dim,
644
+ activation=activation,
645
+ attention_downsample_rate=attention_downsample_rate,
646
+ skip_first_layer_pe=(i == 0),
647
+ )
648
+ )
649
+
650
+ self.final_attn_token_to_image = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate)
651
+ self.norm_final_attn = nn.LayerNorm(embedding_dim)
652
+
653
+ def forward(
654
+ self, image_embedding: torch.Tensor, image_pe: torch.Tensor, point_embedding: torch.Tensor
655
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
656
+ """
657
+ Args:
658
+ image_embedding: image to attend to. Should be shape
659
+ B x embedding_dim x h x w x d for any h, w and d.
660
+ image_pe: the positional encoding to add to the image. Must
661
+ have the same shape as image_embedding.
662
+ point_embedding: the embedding to add to the query points.
663
+ Must have shape B x N_points x embedding_dim for any N_points.
664
+
665
+ Returns:
666
+ torch.Tensor: the processed point_embedding.
667
+ torch.Tensor: the processed image_embedding.
668
+ """
669
+ # BxCxHxWxD -> BxHWDxC == B x N_image_tokens x C
670
+ image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
671
+ image_pe = image_pe.flatten(2).permute(0, 2, 1)
672
+
673
+ # Prepare queries
674
+ queries = point_embedding
675
+ keys = image_embedding
676
+
677
+ # Apply transformer blocks and final layernorm
678
+ for layer in self.layers:
679
+ queries, keys = layer(queries=queries, keys=keys, query_pe=point_embedding, key_pe=image_pe)
680
+
681
+ # Apply the final attention layer from the points to the image
682
+ q = queries + point_embedding
683
+ k = keys + image_pe
684
+ attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
685
+ queries = queries + attn_out
686
+ queries = self.norm_final_attn(queries)
687
+
688
+ return queries, keys
689
+
690
+
691
+ class TwoWayAttentionBlock(nn.Module):
692
+ def __init__(
693
+ self,
694
+ embedding_dim: int,
695
+ num_heads: int,
696
+ mlp_dim: int = 2048,
697
+ activation: tuple | str = "relu",
698
+ attention_downsample_rate: int = 2,
699
+ skip_first_layer_pe: bool = False,
700
+ ) -> None:
701
+ """
702
+ A transformer block with four layers: (1) self-attention of sparse
703
+ inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp
704
+ block on sparse inputs, and (4) cross attention of dense inputs to sparse
705
+ inputs.
706
+ Adapted from `https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/transformer.py`.
707
+
708
+ Args:
709
+ embedding_dim: the channel dimension of the embeddings.
710
+ num_heads: the number of heads in the attention layers.
711
+ mlp_dim: the hidden dimension of the mlp block.
712
+ activation: the activation of the mlp block.
713
+ skip_first_layer_pe: skip the PE on the first layer.
714
+ """
715
+ super().__init__()
716
+ self.self_attn = Attention(embedding_dim, num_heads)
717
+ self.norm1 = nn.LayerNorm(embedding_dim)
718
+
719
+ self.cross_attn_token_to_image = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate)
720
+ self.norm2 = nn.LayerNorm(embedding_dim)
721
+
722
+ self.mlp = MLPBlock(hidden_size=embedding_dim, mlp_dim=mlp_dim, act=activation, dropout_mode="vista3d")
723
+ self.norm3 = nn.LayerNorm(embedding_dim)
724
+
725
+ self.norm4 = nn.LayerNorm(embedding_dim)
726
+ self.cross_attn_image_to_token = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate)
727
+
728
+ self.skip_first_layer_pe = skip_first_layer_pe
729
+
730
+ def forward(
731
+ self, queries: torch.Tensor, keys: torch.Tensor, query_pe: torch.Tensor, key_pe: torch.Tensor
732
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
733
+ # Self attention block
734
+ if self.skip_first_layer_pe:
735
+ queries = self.self_attn(q=queries, k=queries, v=queries)
736
+ else:
737
+ q = queries + query_pe
738
+ attn_out = self.self_attn(q=q, k=q, v=queries)
739
+ queries = queries + attn_out
740
+ queries = self.norm1(queries)
741
+
742
+ # Cross attention block, tokens attending to image embedding
743
+ q = queries + query_pe
744
+ k = keys + key_pe
745
+ attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
746
+ queries = queries + attn_out
747
+ queries = self.norm2(queries)
748
+
749
+ # MLP block
750
+ mlp_out = self.mlp(queries)
751
+ queries = queries + mlp_out
752
+ queries = self.norm3(queries)
753
+
754
+ # Cross attention block, image embedding attending to tokens
755
+ q = queries + query_pe
756
+ k = keys + key_pe
757
+ attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
758
+ keys = keys + attn_out
759
+ keys = self.norm4(keys)
760
+
761
+ return queries, keys
762
+
763
+
764
+ class Attention(nn.Module):
765
+ """
766
+ An attention layer that allows for downscaling the size of the embedding
767
+ after projection to queries, keys, and values.
768
+ Adapted from `https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/transformer.py`.
769
+
770
+ Args:
771
+ embedding_dim: the channel dimension of the embeddings.
772
+ num_heads: the number of heads in the attention layers.
773
+ downsample_rate: the factor by which the embedding dimension is reduced after projection to queries, keys, and values.
774
+ """
775
+
776
+ def __init__(self, embedding_dim: int, num_heads: int, downsample_rate: int = 1) -> None:
777
+ super().__init__()
778
+ self.embedding_dim = embedding_dim
779
+ self.internal_dim = embedding_dim // downsample_rate
780
+ self.num_heads = num_heads
781
+ if not self.internal_dim % num_heads == 0:
782
+ raise ValueError("num_heads must divide embedding_dim.")
783
+
784
+ self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
785
+ self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
786
+ self.v_proj = nn.Linear(embedding_dim, self.internal_dim)
787
+ self.out_proj = nn.Linear(self.internal_dim, embedding_dim)
788
+
789
+ def _separate_heads(self, x: torch.Tensor, num_heads: int) -> torch.Tensor:
790
+ b, n, c = x.shape
791
+ x = x.reshape(b, n, num_heads, c // num_heads)
792
+ # B x N_heads x N_tokens x C_per_head
793
+ return x.transpose(1, 2)
794
+
795
+ def _recombine_heads(self, x: torch.Tensor) -> torch.Tensor:
796
+ b, n_heads, n_tokens, c_per_head = x.shape
797
+ x = x.transpose(1, 2)
798
+ # B x N_tokens x C
799
+ return x.reshape(b, n_tokens, n_heads * c_per_head)
800
+
801
+ def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
802
+ # Input projections
803
+ q = self.q_proj(q)
804
+ k = self.k_proj(k)
805
+ v = self.v_proj(v)
806
+
807
+ # Separate into heads
808
+ q = self._separate_heads(q, self.num_heads)
809
+ k = self._separate_heads(k, self.num_heads)
810
+ v = self._separate_heads(v, self.num_heads)
811
+
812
+ # Attention
813
+ _, _, _, c_per_head = q.shape
814
+ attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens
815
+ attn = attn / math.sqrt(c_per_head)
816
+ attn = torch.softmax(attn, dim=-1)
817
+
818
+ # Get output
819
+ out = attn @ v
820
+ out = self._recombine_heads(out)
821
+ out = self.out_proj(out)
822
+
823
+ return out
824
+
825
+
826
+ class PositionEmbeddingRandom(nn.Module):
827
+ """
828
+ Positional encoding using random spatial frequencies.
829
+ Adapted from `https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/prompt_encoder.py`.
830
+
831
+ Args:
832
+ num_pos_feats: the number of positional encoding features.
833
+ scale: the scale of the positional encoding.
834
+ """
835
+
836
+ def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
837
+ super().__init__()
838
+ if scale is None or scale <= 0.0:
839
+ scale = 1.0
840
+ self.register_buffer("positional_encoding_gaussian_matrix", scale * torch.randn((3, num_pos_feats)))
841
+
842
+ def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
843
+ """Positionally encode points that are normalized to [0,1]."""
844
+ # assuming coords are in the [0, 1]^3 cube and have d_1 x ... x d_n x 3 shape
845
+ coords = 2 * coords - 1
846
+ # [bs=1, N, 3] @ [3, num_pos_feats]
847
+ # -> [bs=1, N, num_pos_feats]
848
+ coords = coords @ self.positional_encoding_gaussian_matrix
849
+ coords = 2 * np.pi * coords
850
+ # outputs d_1 x ... x d_n x C shape
851
+ # [bs=1, N, 2 * num_pos_feats]
852
+ return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
853
+
854
+ def forward(self, size: Tuple[int, int, int]) -> torch.Tensor:
855
+ """Generate positional encoding for a grid of the specified size."""
856
+ h, w, d = size
857
+ device: Any = self.positional_encoding_gaussian_matrix.device
858
+ grid = torch.ones((h, w, d), device=device, dtype=torch.float32)
859
+ x_embed = grid.cumsum(dim=0) - 0.5
860
+ y_embed = grid.cumsum(dim=1) - 0.5
861
+ z_embed = grid.cumsum(dim=2) - 0.5
862
+ x_embed = x_embed / h
863
+ y_embed = y_embed / w
864
+ z_embed = z_embed / d
865
+ pe = self._pe_encoding(torch.stack([x_embed, y_embed, z_embed], dim=-1))
866
+ # C x H x W x D
867
+ return pe.permute(3, 0, 1, 2)
868
+
869
+ def forward_with_coords(
870
+ self, coords_input: torch.Tensor, image_size: Tuple[int, int, int]
871
+ ) -> torch.Tensor:
872
+ """Positionally encode points that are not normalized to [0,1]."""
873
+ coords = coords_input.clone()
874
+ coords[:, :, 0] = coords[:, :, 0] / image_size[0]
875
+ coords[:, :, 1] = coords[:, :, 1] / image_size[1]
876
+ coords[:, :, 2] = coords[:, :, 2] / image_size[2]
877
+ # B x N x C
878
+ return self._pe_encoding(coords.to(torch.float))
879
+
880
+
881
+ class MLP(nn.Module):
882
+ """
883
+ Multi-layer perceptron. This class is only used for `PointMappingSAM`.
884
+ Adapted from `https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/mask_decoder.py`.
885
+
886
+ Args:
887
+ input_dim: the input dimension.
888
+ hidden_dim: the hidden dimension.
889
+ output_dim: the output dimension.
890
+ num_layers: the number of layers.
891
+ sigmoid_output: whether to apply a sigmoid activation to the output.
892
+ """
893
+
894
+ def __init__(
895
+ self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int, sigmoid_output: bool = False
896
+ ) -> None:
897
+ super().__init__()
898
+ self.num_layers = num_layers
899
+ h = [hidden_dim] * (num_layers - 1)
900
+ self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
901
+ self.sigmoid_output = sigmoid_output
902
+
903
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
904
+ for i, layer in enumerate(self.layers):
905
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
906
+ if self.sigmoid_output:
907
+ x = F.sigmoid(x)
908
+ return x
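For completeness, a similarly hypothetical sketch of the interactive (point-prompt) path of the same network; the click coordinates, labels, and prompt_class value are illustrative assumptions, with tensor shapes taken from the forward() docstring above.

    import torch
    from monai.networks.nets import vista3d132

    model = vista3d132().eval()
    image = torch.rand(1, 1, 128, 128, 128)              # [1, 1, H, W, D] single patch (assumed size)
    point_coords = torch.tensor([[[64.0, 64.0, 64.0]]])  # [B=1, N=1, 3] click in patch voxel space (assumed)
    point_labels = torch.tensor([[1]])                    # [B=1, N=1], 1 = positive click
    prompt_class = torch.tensor([[2]])                    # [B=1, 1] global class index of the clicked object (assumed)
    with torch.no_grad():
        logits = model(
            image, point_coords=point_coords, point_labels=point_labels, prompt_class=prompt_class
        )  # [1, 1, H, W, D] logits for the clicked object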