neuro-sam 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (93) hide show
  1. neuro_sam/__init__.py +1 -0
  2. neuro_sam/brightest_path_lib/__init__.py +5 -0
  3. neuro_sam/brightest_path_lib/algorithm/__init__.py +3 -0
  4. neuro_sam/brightest_path_lib/algorithm/astar.py +586 -0
  5. neuro_sam/brightest_path_lib/algorithm/waypointastar.py +449 -0
  6. neuro_sam/brightest_path_lib/algorithm/waypointastar_speedup.py +1007 -0
  7. neuro_sam/brightest_path_lib/connected_componen.py +329 -0
  8. neuro_sam/brightest_path_lib/cost/__init__.py +8 -0
  9. neuro_sam/brightest_path_lib/cost/cost.py +33 -0
  10. neuro_sam/brightest_path_lib/cost/reciprocal.py +90 -0
  11. neuro_sam/brightest_path_lib/cost/reciprocal_transonic.py +86 -0
  12. neuro_sam/brightest_path_lib/heuristic/__init__.py +2 -0
  13. neuro_sam/brightest_path_lib/heuristic/euclidean.py +101 -0
  14. neuro_sam/brightest_path_lib/heuristic/heuristic.py +29 -0
  15. neuro_sam/brightest_path_lib/image/__init__.py +1 -0
  16. neuro_sam/brightest_path_lib/image/stats.py +197 -0
  17. neuro_sam/brightest_path_lib/input/__init__.py +1 -0
  18. neuro_sam/brightest_path_lib/input/inputs.py +14 -0
  19. neuro_sam/brightest_path_lib/node/__init__.py +2 -0
  20. neuro_sam/brightest_path_lib/node/bidirectional_node.py +240 -0
  21. neuro_sam/brightest_path_lib/node/node.py +125 -0
  22. neuro_sam/brightest_path_lib/visualization/__init__.py +4 -0
  23. neuro_sam/brightest_path_lib/visualization/flythrough.py +133 -0
  24. neuro_sam/brightest_path_lib/visualization/flythrough_all.py +394 -0
  25. neuro_sam/brightest_path_lib/visualization/tube_data.py +385 -0
  26. neuro_sam/brightest_path_lib/visualization/tube_flythrough.py +227 -0
  27. neuro_sam/napari_utils/anisotropic_scaling.py +503 -0
  28. neuro_sam/napari_utils/color_utils.py +135 -0
  29. neuro_sam/napari_utils/contrasting_color_system.py +169 -0
  30. neuro_sam/napari_utils/main_widget.py +1016 -0
  31. neuro_sam/napari_utils/path_tracing_module.py +1016 -0
  32. neuro_sam/napari_utils/punet_widget.py +424 -0
  33. neuro_sam/napari_utils/segmentation_model.py +769 -0
  34. neuro_sam/napari_utils/segmentation_module.py +649 -0
  35. neuro_sam/napari_utils/visualization_module.py +574 -0
  36. neuro_sam/plugin.py +260 -0
  37. neuro_sam/punet/__init__.py +0 -0
  38. neuro_sam/punet/deepd3_model.py +231 -0
  39. neuro_sam/punet/prob_unet_deepd3.py +431 -0
  40. neuro_sam/punet/prob_unet_with_tversky.py +375 -0
  41. neuro_sam/punet/punet_inference.py +236 -0
  42. neuro_sam/punet/run_inference.py +145 -0
  43. neuro_sam/punet/unet_blocks.py +81 -0
  44. neuro_sam/punet/utils.py +52 -0
  45. neuro_sam-0.1.0.dist-info/METADATA +269 -0
  46. neuro_sam-0.1.0.dist-info/RECORD +93 -0
  47. neuro_sam-0.1.0.dist-info/WHEEL +5 -0
  48. neuro_sam-0.1.0.dist-info/entry_points.txt +2 -0
  49. neuro_sam-0.1.0.dist-info/licenses/LICENSE +21 -0
  50. neuro_sam-0.1.0.dist-info/top_level.txt +2 -0
  51. sam2/__init__.py +11 -0
  52. sam2/automatic_mask_generator.py +454 -0
  53. sam2/benchmark.py +92 -0
  54. sam2/build_sam.py +174 -0
  55. sam2/configs/sam2/sam2_hiera_b+.yaml +113 -0
  56. sam2/configs/sam2/sam2_hiera_l.yaml +117 -0
  57. sam2/configs/sam2/sam2_hiera_s.yaml +116 -0
  58. sam2/configs/sam2/sam2_hiera_t.yaml +118 -0
  59. sam2/configs/sam2.1/sam2.1_hiera_b+.yaml +116 -0
  60. sam2/configs/sam2.1/sam2.1_hiera_l.yaml +120 -0
  61. sam2/configs/sam2.1/sam2.1_hiera_s.yaml +119 -0
  62. sam2/configs/sam2.1/sam2.1_hiera_t.yaml +121 -0
  63. sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml +339 -0
  64. sam2/configs/train.yaml +335 -0
  65. sam2/modeling/__init__.py +5 -0
  66. sam2/modeling/backbones/__init__.py +5 -0
  67. sam2/modeling/backbones/hieradet.py +317 -0
  68. sam2/modeling/backbones/image_encoder.py +134 -0
  69. sam2/modeling/backbones/utils.py +93 -0
  70. sam2/modeling/memory_attention.py +169 -0
  71. sam2/modeling/memory_encoder.py +181 -0
  72. sam2/modeling/position_encoding.py +239 -0
  73. sam2/modeling/sam/__init__.py +5 -0
  74. sam2/modeling/sam/mask_decoder.py +295 -0
  75. sam2/modeling/sam/prompt_encoder.py +202 -0
  76. sam2/modeling/sam/transformer.py +311 -0
  77. sam2/modeling/sam2_base.py +911 -0
  78. sam2/modeling/sam2_utils.py +323 -0
  79. sam2/sam2.1_hiera_b+.yaml +116 -0
  80. sam2/sam2.1_hiera_l.yaml +120 -0
  81. sam2/sam2.1_hiera_s.yaml +119 -0
  82. sam2/sam2.1_hiera_t.yaml +121 -0
  83. sam2/sam2_hiera_b+.yaml +113 -0
  84. sam2/sam2_hiera_l.yaml +117 -0
  85. sam2/sam2_hiera_s.yaml +116 -0
  86. sam2/sam2_hiera_t.yaml +118 -0
  87. sam2/sam2_image_predictor.py +475 -0
  88. sam2/sam2_video_predictor.py +1222 -0
  89. sam2/sam2_video_predictor_legacy.py +1172 -0
  90. sam2/utils/__init__.py +5 -0
  91. sam2/utils/amg.py +348 -0
  92. sam2/utils/misc.py +349 -0
  93. sam2/utils/transforms.py +118 -0
@@ -0,0 +1,181 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+ from typing import Tuple
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+
14
+ from sam2.modeling.sam2_utils import DropPath, get_clones, LayerNorm2d
15
+
16
+
17
class MaskDownSampler(nn.Module):
    """
    Shrink a mask spatially by ``total_stride`` using repeated strided convolutions.

    The encoder is a stack of ``log(total_stride) / log(stride)`` stages, each
    made of ``Conv2d -> LayerNorm2d -> activation``.  Every stage multiplies the
    channel count by ``stride**2`` (starting from a single-channel mask), and a
    closing 1x1 convolution maps the result to ``embed_dim`` channels.

    Args:
        embed_dim: Number of output channels produced by the final 1x1 conv.
        kernel_size: Kernel size of every strided convolution.
        stride: Stride of every strided convolution.
        padding: Padding of every strided convolution.
        total_stride: Overall downsampling factor; must be an exact integer
            power of ``stride`` (checked with an assertion).
        activation: Zero-argument callable returning the activation module
            placed after each normalisation layer.
    """

    def __init__(
        self,
        embed_dim=256,
        kernel_size=4,
        stride=4,
        padding=0,
        total_stride=16,
        activation=nn.GELU,
    ):
        super().__init__()
        num_layers = int(math.log2(total_stride) // math.log2(stride))
        assert stride**num_layers == total_stride

        channel_growth = stride**2
        stages = nn.Sequential()
        in_chans = out_chans = 1  # the incoming mask has a single channel
        for _ in range(num_layers):
            out_chans = in_chans * channel_growth
            stages.append(
                nn.Conv2d(
                    in_chans,
                    out_chans,
                    kernel_size=kernel_size,
                    stride=stride,
                    padding=padding,
                )
            )
            stages.append(LayerNorm2d(out_chans))
            stages.append(activation())
            in_chans = out_chans

        # Final linear (1x1) projection to the requested embedding width.
        stages.append(nn.Conv2d(out_chans, embed_dim, kernel_size=1))
        self.encoder = stages

    def forward(self, x):
        """Run the mask ``x`` through the downsampling stack and return the result."""
        return self.encoder(x)
59
+
60
+
61
+ # Lightly adapted from ConvNext (https://github.com/facebookresearch/ConvNeXt)
62
class CXBlock(nn.Module):
    r"""Residual ConvNeXt-style block.

    Data flow, for an input in (N, C, H, W) layout:
        conv (depthwise when ``use_dwconv``) -> LayerNorm2d -> permute to
        (N, H, W, C) -> Linear (C -> 4C) -> GELU -> Linear (4C -> C) ->
        optional per-channel scale ``gamma`` -> permute back to (N, C, H, W)
        -> drop path -> add the block input.

    The two 1x1 convolutions are implemented as ``nn.Linear`` layers acting on
    the channels-last layout.

    Args:
        dim (int): Number of input (and output) channels.
        kernel_size (int): Kernel size of the first convolution. Default: 7
        padding (int): Padding of the first convolution. Default: 3
        drop_path (float): Stochastic depth rate; ``DropPath`` is only
            instantiated when this is > 0. Default: 0.0
        layer_scale_init_value (float): Init value for the learnable layer
            scale; no scale is applied when this is <= 0. Default: 1e-6.
        use_dwconv (bool): Use ``groups=dim`` (depthwise) for the first
            convolution, otherwise a regular convolution. Default: True
    """

    def __init__(
        self,
        dim,
        kernel_size=7,
        padding=3,
        drop_path=0.0,
        layer_scale_init_value=1e-6,
        use_dwconv=True,
    ):
        super().__init__()
        conv_groups = dim if use_dwconv else 1
        self.dwconv = nn.Conv2d(
            dim,
            dim,
            kernel_size=kernel_size,
            padding=padding,
            groups=conv_groups,
        )
        self.norm = LayerNorm2d(dim, eps=1e-6)

        # Pointwise (1x1) convolutions expressed as linear layers.
        hidden_dim = 4 * dim
        self.pwconv1 = nn.Linear(dim, hidden_dim)
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(hidden_dim, dim)

        if layer_scale_init_value > 0:
            self.gamma = nn.Parameter(
                layer_scale_init_value * torch.ones((dim)), requires_grad=True
            )
        else:
            self.gamma = None

        if drop_path > 0.0:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

    def forward(self, x):
        """Apply the block to ``x`` (N, C, H, W) and return a tensor of the same layout."""
        residual = x
        out = self.norm(self.dwconv(x))
        out = out.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
        out = self.pwconv2(self.act(self.pwconv1(out)))
        if self.gamma is not None:
            # gamma has length C and broadcasts over the trailing channel axis
            out = self.gamma * out
        out = out.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)
        return residual + self.drop_path(out)
118
+
119
+
120
class Fuser(nn.Module):
    """
    Apply an optional 1x1 input projection followed by a stack of cloned layers.

    Args:
        layer: Prototype module handed to ``get_clones`` to build the stack.
        num_layers: Number of clones requested from ``get_clones``.
        dim: Channel count for the input projection; required (asserted) when
            ``input_projection`` is True, unused otherwise.
        input_projection: When True, a ``Conv2d(dim, dim, kernel_size=1)`` is
            applied before the stack; otherwise the input passes through an
            identity.
    """

    def __init__(self, layer, num_layers, dim=None, input_projection=False):
        super().__init__()
        self.proj = nn.Identity()
        self.layers = get_clones(layer, num_layers)

        if not input_projection:
            return
        assert dim is not None
        self.proj = nn.Conv2d(dim, dim, kernel_size=1)

    def forward(self, x):
        """Project ``x`` (normally (N, C, H, W)) and feed it through every layer in order."""
        out = self.proj(x)
        for block in self.layers:
            out = block(out)
        return out
136
+
137
+
138
class MemoryEncoder(nn.Module):
    """
    Fuse pixel features with a downsampled mask and attach a position encoding.

    Args:
        out_dim: Channel count of the returned features.  When it differs from
            ``in_dim`` a 1x1 convolution maps ``in_dim -> out_dim``; otherwise
            the output projection is an identity.
        mask_downsampler: Module applied to the (optionally sigmoid-ed) masks;
            its output is added element-wise to the projected pixel features,
            so the two must be broadcast-compatible.
        fuser: Module applied to the sum of projected features and masks.
        position_encoding: Module called on the output features to produce the
            position encoding.
        in_dim: Channel count of ``pix_feat``.
    """

    def __init__(
        self,
        out_dim,
        mask_downsampler,
        fuser,
        position_encoding,
        in_dim=256,  # in_dim of pix_feats
    ):
        super().__init__()

        self.mask_downsampler = mask_downsampler

        self.pix_feat_proj = nn.Conv2d(in_dim, in_dim, kernel_size=1)
        self.fuser = fuser
        self.position_encoding = position_encoding
        self.out_proj = nn.Identity()
        if out_dim != in_dim:
            self.out_proj = nn.Conv2d(in_dim, out_dim, kernel_size=1)

    def forward(
        self,
        pix_feat: torch.Tensor,
        masks: torch.Tensor,
        skip_mask_sigmoid: bool = False,
    ) -> dict:
        """
        Encode ``pix_feat`` together with ``masks``.

        Args:
            pix_feat: Pixel features with ``in_dim`` channels.
            masks: Mask tensor fed to ``mask_downsampler``.
            skip_mask_sigmoid: When True the masks are used as given; otherwise
                a sigmoid is applied first.

        Returns:
            A dict with two entries:
              * ``"vision_features"``: the fused, output-projected features;
              * ``"vision_pos_enc"``: a one-element list holding the position
                encoding of those features, cast to the features' dtype.
        """
        ## Process masks
        # sigmoid, so that less domain shift from gt masks which are bool
        if not skip_mask_sigmoid:
            masks = F.sigmoid(masks)
        masks = self.mask_downsampler(masks)

        ## Fuse pix_feats and downsampled masks
        # the visual features may live on another device (e.g. CPU);
        # move them to wherever the masks are
        pix_feat = pix_feat.to(masks.device)

        x = self.pix_feat_proj(pix_feat)
        x = x + masks
        x = self.fuser(x)
        x = self.out_proj(x)

        pos = self.position_encoding(x).to(x.dtype)

        return {"vision_features": x, "vision_pos_enc": [pos]}
@@ -0,0 +1,239 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+ from typing import Any, Optional, Tuple
9
+
10
+ import numpy as np
11
+
12
+ import torch
13
+ from torch import nn
14
+
15
+
16
class PositionEmbeddingSine(nn.Module):
    """
    This is a more standard version of the position embedding, very similar to the one
    used by the Attention Is All You Need paper, generalized to work on images.

    Grid encodings are memoised per ``(H, W)`` in ``self.cache``.

    Args:
        num_pos_feats: Total number of output channels; must be even.  Half of
            them encode the y coordinate and half the x coordinate.
        temperature: Base of the geometric frequency progression.
        normalize: Scale grid coordinates to ``(0, scale]`` before encoding.
        scale: Multiplier for normalised coordinates; defaults to ``2 * pi``.
            Passing a value together with ``normalize=False`` raises
            ``ValueError``.
        warmup_cache: When True and CUDA is available, pre-populate the cache
            for ``image_size // stride`` grids for every entry of ``strides``.
        image_size: Image side length used only for the cache warm-up.
        strides: Strides used only for the cache warm-up.
    """

    def __init__(
        self,
        num_pos_feats,
        temperature: int = 10000,
        normalize: bool = True,
        scale: Optional[float] = None,
        # The following settings are only relevant
        # for warming up the cache for compilation
        warmup_cache: bool = True,
        image_size: int = 1024,
        strides: Tuple[int, ...] = (4, 8, 16, 32),
    ):
        super().__init__()
        assert num_pos_feats % 2 == 0, "Expecting even model width"
        # per-axis feature count; x and y each get half of the channels
        self.num_pos_feats = num_pos_feats // 2
        self.temperature = temperature
        self.normalize = normalize
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        if scale is None:
            scale = 2 * math.pi
        self.scale = scale

        # (H, W) -> un-batched encoding of shape (C, H, W)
        self.cache = {}
        if warmup_cache and torch.cuda.is_available():
            # Warmup cache for cuda, to help with compilation
            device = torch.device("cuda")
            for stride in strides:
                cache_key = (image_size // stride, image_size // stride)
                self._pe(1, device, *cache_key)

    def _encode_xy(self, x, y):
        """Return sine/cosine encodings for 1-D coordinate tensors ``x`` and ``y``.

        Each result has shape ``(len(x), self.num_pos_feats)``.
        """
        # The positions are expected to be normalized
        assert len(x) == len(y) and x.ndim == y.ndim == 1
        x_embed = x * self.scale
        y_embed = y * self.scale

        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)

        pos_x = x_embed[:, None] / dim_t
        pos_y = y_embed[:, None] / dim_t
        # interleave sin (even indices) and cos (odd indices)
        pos_x = torch.stack(
            (pos_x[:, 0::2].sin(), pos_x[:, 1::2].cos()), dim=2
        ).flatten(1)
        pos_y = torch.stack(
            (pos_y[:, 0::2].sin(), pos_y[:, 1::2].cos()), dim=2
        ).flatten(1)
        return pos_x, pos_y

    @torch.no_grad()
    def encode_boxes(self, x, y, w, h):
        """Encode boxes as ``[enc(y), enc(x), h, w]`` concatenated along dim 1."""
        pos_x, pos_y = self._encode_xy(x, y)
        pos = torch.cat((pos_y, pos_x, h[:, None], w[:, None]), dim=1)
        return pos

    encode = encode_boxes  # Backwards compatibility

    @torch.no_grad()
    def encode_points(self, x, y, labels):
        """Encode 2-D point tensors as ``[enc(y), enc(x), label]`` along the last dim.

        ``x``, ``y`` and ``labels`` must share the same two-dimensional shape.
        """
        (bx, nx), (by, ny), (bl, nl) = x.shape, y.shape, labels.shape
        assert bx == by and nx == ny and bx == bl and nx == nl
        pos_x, pos_y = self._encode_xy(x.flatten(), y.flatten())
        pos_x, pos_y = pos_x.reshape(bx, nx, -1), pos_y.reshape(by, ny, -1)
        pos = torch.cat((pos_y, pos_x, labels[:, :, None]), dim=2)
        return pos

    @torch.no_grad()
    def _pe(self, B, device, *cache_key):
        """Return the ``(B, C, H, W)`` grid encoding for ``cache_key == (H, W)``.

        The encoding does not depend on the batch index, so a single
        ``(C, H, W)`` tensor is computed, cached, and tiled ``B`` times.  The
        returned tensor never shares storage with the cache entry, so callers
        may modify it in place without corrupting later results.
        """
        H, W = cache_key
        if cache_key in self.cache:
            return self.cache[cache_key].to(device)[None].repeat(B, 1, 1, 1)

        # 1-based row / column indices broadcast over the (H, W) grid
        y_embed = (
            torch.arange(1, H + 1, dtype=torch.float32, device=device)
            .view(-1, 1)
            .repeat(1, W)
        )
        x_embed = (
            torch.arange(1, W + 1, dtype=torch.float32, device=device)
            .view(1, -1)
            .repeat(H, 1)
        )

        if self.normalize:
            eps = 1e-6
            # divide by the last (largest) index so values fall in (0, scale]
            y_embed = y_embed / (y_embed[-1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, -1:] + eps) * self.scale

        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=device)
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)

        pos_x = x_embed[:, :, None] / dim_t
        pos_y = y_embed[:, :, None] / dim_t
        pos_x = torch.stack(
            (pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3
        ).flatten(2)
        pos_y = torch.stack(
            (pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3
        ).flatten(2)
        # (H, W, C) -> (C, H, W); y channels first, then x channels
        pos = torch.cat((pos_y, pos_x), dim=2).permute(2, 0, 1)
        self.cache[cache_key] = pos
        # repeat() copies, which keeps the cache entry private
        return pos[None].repeat(B, 1, 1, 1)

    @torch.no_grad()
    def forward(self, x: torch.Tensor):
        """Return the position encoding for a tensor whose last two dims are (H, W).

        The batch size is taken from ``x.shape[0]`` and the result is created
        on ``x.device``.
        """
        B = x.shape[0]
        cache_key = (x.shape[-2], x.shape[-1])
        return self._pe(B, x.device, *cache_key)
131
+
132
+
133
class PositionEmbeddingRandom(nn.Module):
    """
    Positional encoding using random spatial frequencies.

    Args:
        num_pos_feats: Number of random frequencies; the encoding has
            ``2 * num_pos_feats`` channels (sine and cosine halves).
        scale: Multiplier for the random Gaussian frequency matrix.  ``None``
            or a non-positive value falls back to 1.0.
    """

    def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
        super().__init__()
        if scale is None or scale <= 0.0:
            scale = 1.0
        # Stored as a buffer so it follows the module across devices and is
        # saved in the state dict, but is not trained.
        self.register_buffer(
            "positional_encoding_gaussian_matrix",
            scale * torch.randn((2, num_pos_feats)),
        )

    def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
        """Positionally encode points that are normalized to [0,1]."""
        # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
        coords = 2 * coords - 1
        coords = coords @ self.positional_encoding_gaussian_matrix
        coords = 2 * np.pi * coords
        # outputs d_1 x ... x d_n x C shape
        return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)

    def forward(self, size: Tuple[int, int]) -> torch.Tensor:
        """Generate positional encoding for a grid of the specified size."""
        h, w = size
        device: Any = self.positional_encoding_gaussian_matrix.device
        grid = torch.ones((h, w), device=device, dtype=torch.float32)
        # cumsum - 0.5 yields cell-centre coordinates 0.5, 1.5, ...
        y_embed = grid.cumsum(dim=0) - 0.5
        x_embed = grid.cumsum(dim=1) - 0.5
        y_embed = y_embed / h
        x_embed = x_embed / w

        pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
        return pe.permute(2, 0, 1)  # C x H x W

    def forward_with_coords(
        self, coords_input: torch.Tensor, image_size: Tuple[int, int]
    ) -> torch.Tensor:
        """Positionally encode points that are not normalized to [0,1].

        Args:
            coords_input: ``B x N x 2`` tensor of ``(x, y)`` coordinates in
                pixel units.  It is not modified.
            image_size: ``(height, width)`` used to normalise the coordinates.

        Returns:
            ``B x N x C`` encoding.
        """
        # Cast to float *before* normalising: writing the quotient back into an
        # integer tensor would truncate every normalised coordinate to 0.
        coords = coords_input.clone().to(torch.float)
        coords[:, :, 0] = coords[:, :, 0] / image_size[1]
        coords[:, :, 1] = coords[:, :, 1] / image_size[0]
        return self._pe_encoding(coords)  # B x N x C
177
+
178
+
179
+ # Rotary Positional Encoding, adapted from:
180
+ # 1. https://github.com/meta-llama/codellama/blob/main/llama/model.py
181
+ # 2. https://github.com/naver-ai/rope-vit
182
+ # 3. https://github.com/lucidrains/rotary-embedding-torch
183
+
184
+
185
def init_t_xy(end_x: int, end_y: int):
    """Return per-position (x, y) indices for a row-major ``end_y x end_x`` grid.

    Both results are float32 tensors of length ``end_x * end_y``: the first
    holds the column index of each flattened position, the second the row
    index.
    """
    flat_index = torch.arange(end_x * end_y, dtype=torch.float32)
    col_index = torch.remainder(flat_index, end_x).float()
    row_index = torch.div(flat_index, end_x, rounding_mode="floor").float()
    return col_index, row_index
190
+
191
+
192
def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0):
    """Build complex axial rotary factors for an ``end_y x end_x`` grid.

    ``dim // 4`` frequencies are generated and applied separately to the x and
    y index of every grid position.  The result is a complex tensor of shape
    ``(end_x * end_y, 2 * (dim // 4))`` with unit-magnitude entries: x-axis
    rotations first, then y-axis rotations.
    """
    # The x and y axes use the same frequency schedule.
    exponents = torch.arange(0, dim, 4)[: (dim // 4)].float() / dim
    axis_freqs = 1.0 / (theta**exponents)

    t_x, t_y = init_t_xy(end_x, end_y)
    angles_x = torch.outer(t_x, axis_freqs)
    angles_y = torch.outer(t_y, axis_freqs)
    # polar(1, angle) == exp(i * angle)
    cis_x = torch.polar(torch.ones_like(angles_x), angles_x)
    cis_y = torch.polar(torch.ones_like(angles_y), angles_y)
    return torch.cat([cis_x, cis_y], dim=-1)
202
+
203
+
204
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    """View ``freqs_cis`` so that it broadcasts against ``x``.

    ``freqs_cis`` must match the last two dimensions of ``x`` (asserted); the
    returned view has size 1 on every leading dimension of ``x``.
    """
    ndim = x.ndim
    assert ndim > 1
    assert freqs_cis.shape == (x.shape[-2], x.shape[-1])
    leading_ones = [1] * (ndim - 2)
    target_shape = leading_ones + list(x.shape[-2:])
    return freqs_cis.view(*target_shape)
210
+
211
+
212
def apply_rotary_enc(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
    repeat_freqs_k: bool = False,
):
    """Rotate queries and keys by the complex factors in ``freqs_cis``.

    The last dimension of ``xq`` / ``xk`` is read as consecutive (real, imag)
    pairs, multiplied by ``freqs_cis`` in the complex domain, and flattened
    back.  Results are returned with the dtype and device of the inputs.

    Args:
        xq: Query tensor; the ``flatten(3)`` calls expect 4 dimensions.
        xk: Key tensor.  When its second-to-last dimension is empty it is
            returned unchanged.
        freqs_cis: Complex factors matching the last two dims of the complex
            view of ``xq``.
        repeat_freqs_k: Tile ``freqs_cis`` along the sequence dimension so it
            covers a key sequence that is an integer multiple of the query
            sequence length.

    Returns:
        Tuple ``(rotated_xq, rotated_xk)``.
    """

    def _as_complex(t):
        # pair up the trailing dimension as (real, imag)
        return torch.view_as_complex(t.float().reshape(*t.shape[:-1], -1, 2))

    q_complex = _as_complex(xq)
    k_complex = _as_complex(xk) if xk.shape[-2] != 0 else None

    freqs_cis = reshape_for_broadcast(freqs_cis, q_complex)
    q_rotated = torch.view_as_real(q_complex * freqs_cis).flatten(3)
    q_rotated = q_rotated.type_as(xq).to(xq.device)

    if k_complex is None:
        # no keys to rotate, due to dropout
        return q_rotated, xk

    # repeat freqs along seq_len dim to match k seq_len
    if repeat_freqs_k:
        repeats = k_complex.shape[-2] // q_complex.shape[-2]
        if freqs_cis.is_cuda:
            leading = [1] * (freqs_cis.ndim - 2)
            freqs_cis = freqs_cis.repeat(*leading, repeats, 1)
        else:
            # torch.repeat on complex numbers may not be supported on non-CUDA devices
            # (freqs_cis has 4 dims and we repeat on dim 2) so we use expand + flatten
            expanded = freqs_cis.unsqueeze(2).expand(-1, -1, repeats, -1, -1)
            freqs_cis = expanded.flatten(2, 3)

    k_rotated = torch.view_as_real(k_complex * freqs_cis).flatten(3)
    return q_rotated, k_rotated.type_as(xk).to(xk.device)
@@ -0,0 +1,5 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.