frontveg-0.1.dev1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- frontveg/__init__.py +11 -0
 - frontveg/_tests/__init__.py +0 -0
 - frontveg/_tests/test_widget.py +66 -0
 - frontveg/_version.py +21 -0
 - frontveg/_widget.py +132 -0
 - frontveg/napari.yaml +14 -0
 - frontveg/utils.py +95 -0
 - frontveg-0.1.dev1.dist-info/METADATA +143 -0
 - frontveg-0.1.dev1.dist-info/RECORD +44 -0
 - frontveg-0.1.dev1.dist-info/WHEEL +5 -0
 - frontveg-0.1.dev1.dist-info/entry_points.txt +2 -0
 - frontveg-0.1.dev1.dist-info/licenses/LICENSE +28 -0
 - frontveg-0.1.dev1.dist-info/top_level.txt +2 -0
 - sam2/__init__.py +11 -0
 - sam2/automatic_mask_generator.py +454 -0
 - sam2/build_sam.py +167 -0
 - sam2/configs/sam2/sam2_hiera_b+.yaml +113 -0
 - sam2/configs/sam2/sam2_hiera_l.yaml +117 -0
 - sam2/configs/sam2/sam2_hiera_s.yaml +116 -0
 - sam2/configs/sam2/sam2_hiera_t.yaml +118 -0
 - sam2/modeling/__init__.py +5 -0
 - sam2/modeling/backbones/__init__.py +5 -0
 - sam2/modeling/backbones/hieradet.py +317 -0
 - sam2/modeling/backbones/image_encoder.py +134 -0
 - sam2/modeling/backbones/utils.py +95 -0
 - sam2/modeling/memory_attention.py +169 -0
 - sam2/modeling/memory_encoder.py +181 -0
 - sam2/modeling/position_encoding.py +221 -0
 - sam2/modeling/sam/__init__.py +5 -0
 - sam2/modeling/sam/mask_decoder.py +295 -0
 - sam2/modeling/sam/prompt_encoder.py +182 -0
 - sam2/modeling/sam/transformer.py +360 -0
 - sam2/modeling/sam2_base.py +907 -0
 - sam2/modeling/sam2_utils.py +323 -0
 - sam2/sam2_hiera_b+.yaml +1 -0
 - sam2/sam2_hiera_l.yaml +1 -0
 - sam2/sam2_hiera_s.yaml +1 -0
 - sam2/sam2_hiera_t.yaml +1 -0
 - sam2/sam2_image_predictor.py +466 -0
 - sam2/sam2_video_predictor.py +1172 -0
 - sam2/utils/__init__.py +5 -0
 - sam2/utils/amg.py +348 -0
 - sam2/utils/misc.py +349 -0
 - sam2/utils/transforms.py +118 -0
 
sam2/modeling/sam/transformer.py

@@ -0,0 +1,360 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import contextlib
+import math
+import warnings
+from functools import partial
+from typing import Tuple, Type
+
+import torch
+import torch.nn.functional as F
+from torch import nn, Tensor
+
+from sam2.modeling.position_encoding import apply_rotary_enc, compute_axial_cis
+from sam2.modeling.sam2_utils import MLP
+from sam2.utils.misc import get_sdpa_settings
+
+warnings.simplefilter(action="ignore", category=FutureWarning)
+# Check whether Flash Attention is available (and use it by default)
+OLD_GPU, USE_FLASH_ATTN, MATH_KERNEL_ON = get_sdpa_settings()
+# A fallback setting to allow all available kernels if Flash Attention fails
+ALLOW_ALL_KERNELS = False
+
+
+def sdp_kernel_context(dropout_p):
+    """
+    Get the context for the attention scaled dot-product kernel. We use Flash Attention
+    by default, but fall back to all available kernels if Flash Attention fails.
+    """
+    if ALLOW_ALL_KERNELS:
+        return contextlib.nullcontext()
+
+    return torch.backends.cuda.sdp_kernel(
+        enable_flash=USE_FLASH_ATTN,
+        # if Flash attention kernel is off, then math kernel needs to be enabled
+        enable_math=(OLD_GPU and dropout_p > 0.0) or MATH_KERNEL_ON,
+        enable_mem_efficient=OLD_GPU,
+    )
+
+
+class TwoWayTransformer(nn.Module):
+    def __init__(
+        self,
+        depth: int,
+        embedding_dim: int,
+        num_heads: int,
+        mlp_dim: int,
+        activation: Type[nn.Module] = nn.ReLU,
+        attention_downsample_rate: int = 2,
+    ) -> None:
+        """
+        A transformer decoder that attends to an input image using
+        queries whose positional embedding is supplied.
+
+        Args:
+          depth (int): number of layers in the transformer
+          embedding_dim (int): the channel dimension for the input embeddings
+          num_heads (int): the number of heads for multihead attention. Must
+            divide embedding_dim
+          mlp_dim (int): the channel dimension internal to the MLP block
+          activation (nn.Module): the activation to use in the MLP block
+        """
+        super().__init__()
+        self.depth = depth
+        self.embedding_dim = embedding_dim
+        self.num_heads = num_heads
+        self.mlp_dim = mlp_dim
+        self.layers = nn.ModuleList()
+
+        for i in range(depth):
+            self.layers.append(
+                TwoWayAttentionBlock(
+                    embedding_dim=embedding_dim,
+                    num_heads=num_heads,
+                    mlp_dim=mlp_dim,
+                    activation=activation,
+                    attention_downsample_rate=attention_downsample_rate,
+                    skip_first_layer_pe=(i == 0),
+                )
+            )
+
+        self.final_attn_token_to_image = Attention(
+            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
+        )
+        self.norm_final_attn = nn.LayerNorm(embedding_dim)
+
+    def forward(
+        self,
+        image_embedding: Tensor,
+        image_pe: Tensor,
+        point_embedding: Tensor,
+    ) -> Tuple[Tensor, Tensor]:
+        """
+        Args:
+          image_embedding (torch.Tensor): image to attend to. Should be shape
+            B x embedding_dim x h x w for any h and w.
+          image_pe (torch.Tensor): the positional encoding to add to the image. Must
+            have the same shape as image_embedding.
+          point_embedding (torch.Tensor): the embedding to add to the query points.
+            Must have shape B x N_points x embedding_dim for any N_points.
+
+        Returns:
+          torch.Tensor: the processed point_embedding
+          torch.Tensor: the processed image_embedding
+        """
+        # BxCxHxW -> BxHWxC == B x N_image_tokens x C
+        bs, c, h, w = image_embedding.shape
+        image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
+        image_pe = image_pe.flatten(2).permute(0, 2, 1)
+
+        # Prepare queries
+        queries = point_embedding
+        keys = image_embedding
+
+        # Apply transformer blocks and final layernorm
+        for layer in self.layers:
+            queries, keys = layer(
+                queries=queries,
+                keys=keys,
+                query_pe=point_embedding,
+                key_pe=image_pe,
+            )
+
+        # Apply the final attention layer from the points to the image
+        q = queries + point_embedding
+        k = keys + image_pe
+        attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
+        queries = queries + attn_out
+        queries = self.norm_final_attn(queries)
+
+        return queries, keys
+
+
+class TwoWayAttentionBlock(nn.Module):
+    def __init__(
+        self,
+        embedding_dim: int,
+        num_heads: int,
+        mlp_dim: int = 2048,
+        activation: Type[nn.Module] = nn.ReLU,
+        attention_downsample_rate: int = 2,
+        skip_first_layer_pe: bool = False,
+    ) -> None:
+        """
+        A transformer block with four layers: (1) self-attention of sparse
+        inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp
+        block on sparse inputs, and (4) cross attention of dense inputs to sparse
+        inputs.
+
+        Arguments:
+          embedding_dim (int): the channel dimension of the embeddings
+          num_heads (int): the number of heads in the attention layers
+          mlp_dim (int): the hidden dimension of the mlp block
+          activation (nn.Module): the activation of the mlp block
+          skip_first_layer_pe (bool): skip the PE on the first layer
+        """
+        super().__init__()
+        self.self_attn = Attention(embedding_dim, num_heads)
+        self.norm1 = nn.LayerNorm(embedding_dim)
+
+        self.cross_attn_token_to_image = Attention(
+            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
+        )
+        self.norm2 = nn.LayerNorm(embedding_dim)
+
+        self.mlp = MLP(
+            embedding_dim, mlp_dim, embedding_dim, num_layers=2, activation=activation
+        )
+        self.norm3 = nn.LayerNorm(embedding_dim)
+
+        self.norm4 = nn.LayerNorm(embedding_dim)
+        self.cross_attn_image_to_token = Attention(
+            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
+        )
+
+        self.skip_first_layer_pe = skip_first_layer_pe
+
+    def forward(
+        self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor
+    ) -> Tuple[Tensor, Tensor]:
+        # Self attention block
+        if self.skip_first_layer_pe:
+            queries = self.self_attn(q=queries, k=queries, v=queries)
+        else:
+            q = queries + query_pe
+            attn_out = self.self_attn(q=q, k=q, v=queries)
+            queries = queries + attn_out
+        queries = self.norm1(queries)
+
+        # Cross attention block, tokens attending to image embedding
+        q = queries + query_pe
+        k = keys + key_pe
+        attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
+        queries = queries + attn_out
+        queries = self.norm2(queries)
+
+        # MLP block
+        mlp_out = self.mlp(queries)
+        queries = queries + mlp_out
+        queries = self.norm3(queries)
+
+        # Cross attention block, image embedding attending to tokens
+        q = queries + query_pe
+        k = keys + key_pe
+        attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
+        keys = keys + attn_out
+        keys = self.norm4(keys)
+
+        return queries, keys
+
+
+class Attention(nn.Module):
+    """
+    An attention layer that allows for downscaling the size of the embedding
+    after projection to queries, keys, and values.
+    """
+
+    def __init__(
+        self,
+        embedding_dim: int,
+        num_heads: int,
+        downsample_rate: int = 1,
+        dropout: float = 0.0,
+        kv_in_dim: int = None,
+    ) -> None:
+        super().__init__()
+        self.embedding_dim = embedding_dim
+        self.kv_in_dim = kv_in_dim if kv_in_dim is not None else embedding_dim
+        self.internal_dim = embedding_dim // downsample_rate
+        self.num_heads = num_heads
+        assert (
+            self.internal_dim % num_heads == 0
+        ), "num_heads must divide embedding_dim."
+
+        self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
+        self.k_proj = nn.Linear(self.kv_in_dim, self.internal_dim)
+        self.v_proj = nn.Linear(self.kv_in_dim, self.internal_dim)
+        self.out_proj = nn.Linear(self.internal_dim, embedding_dim)
+
+        self.dropout_p = dropout
+
+    def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
+        b, n, c = x.shape
+        x = x.reshape(b, n, num_heads, c // num_heads)
+        return x.transpose(1, 2)  # B x N_heads x N_tokens x C_per_head
+
+    def _recombine_heads(self, x: Tensor) -> Tensor:
+        b, n_heads, n_tokens, c_per_head = x.shape
+        x = x.transpose(1, 2)
+        return x.reshape(b, n_tokens, n_heads * c_per_head)  # B x N_tokens x C
+
+    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
+        # Input projections
+        q = self.q_proj(q)
+        k = self.k_proj(k)
+        v = self.v_proj(v)
+
+        # Separate into heads
+        q = self._separate_heads(q, self.num_heads)
+        k = self._separate_heads(k, self.num_heads)
+        v = self._separate_heads(v, self.num_heads)
+
+        dropout_p = self.dropout_p if self.training else 0.0
+        # Attention
+        try:
+            with sdp_kernel_context(dropout_p):
+                out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
+        except Exception as e:
+            # Fall back to all kernels if the Flash attention kernel fails
+            warnings.warn(
+                f"Flash Attention kernel failed due to: {e}\nFalling back to all available "
+                f"kernels for scaled_dot_product_attention (which may have a slower speed).",
+                category=UserWarning,
+                stacklevel=2,
+            )
+            global ALLOW_ALL_KERNELS
+            ALLOW_ALL_KERNELS = True
+            out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
+
+        out = self._recombine_heads(out)
+        out = self.out_proj(out)
+
+        return out
+
+
+class RoPEAttention(Attention):
+    """Attention with rotary position encoding."""
+
+    def __init__(
+        self,
+        *args,
+        rope_theta=10000.0,
+        # whether to repeat q rope to match k length
+        # this is needed for cross-attention to memories
+        rope_k_repeat=False,
+        feat_sizes=(32, 32),  # [w, h] for stride 16 feats at 512 resolution
+        **kwargs,
+    ):
+        super().__init__(*args, **kwargs)
+
+        self.compute_cis = partial(
+            compute_axial_cis, dim=self.internal_dim // self.num_heads, theta=rope_theta
+        )
+        freqs_cis = self.compute_cis(end_x=feat_sizes[0], end_y=feat_sizes[1])
+        self.freqs_cis = freqs_cis
+        self.rope_k_repeat = rope_k_repeat
+
+    def forward(
+        self, q: Tensor, k: Tensor, v: Tensor, num_k_exclude_rope: int = 0
+    ) -> Tensor:
+        # Input projections
+        q = self.q_proj(q)
+        k = self.k_proj(k)
+        v = self.v_proj(v)
+
+        # Separate into heads
+        q = self._separate_heads(q, self.num_heads)
+        k = self._separate_heads(k, self.num_heads)
+        v = self._separate_heads(v, self.num_heads)
+
+        # Apply rotary position encoding
+        w = h = math.sqrt(q.shape[-2])
+        self.freqs_cis = self.freqs_cis.to(q.device)
+        if self.freqs_cis.shape[0] != q.shape[-2]:
+            self.freqs_cis = self.compute_cis(end_x=w, end_y=h).to(q.device)
+        if q.shape[-2] != k.shape[-2]:
+            assert self.rope_k_repeat
+
+        num_k_rope = k.size(-2) - num_k_exclude_rope
+        q, k[:, :, :num_k_rope] = apply_rotary_enc(
+            q,
+            k[:, :, :num_k_rope],
+            freqs_cis=self.freqs_cis,
+            repeat_freqs_k=self.rope_k_repeat,
+        )
+
+        dropout_p = self.dropout_p if self.training else 0.0
+        # Attention
+        try:
+            with sdp_kernel_context(dropout_p):
+                out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
+        except Exception as e:
+            # Fall back to all kernels if the Flash attention kernel fails
+            warnings.warn(
+                f"Flash Attention kernel failed due to: {e}\nFalling back to all available "
+                f"kernels for scaled_dot_product_attention (which may have a slower speed).",
+                category=UserWarning,
+                stacklevel=2,
+            )
+            global ALLOW_ALL_KERNELS
+            ALLOW_ALL_KERNELS = True
+            out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
+
+        out = self._recombine_heads(out)
+        out = self.out_proj(out)
+
+        return out
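
For reference, a minimal usage sketch of the TwoWayTransformer defined in the diff above, exercised with dummy tensors. The import path follows the package layout listed at the top; the hyperparameter values (depth=2, embedding_dim=256, num_heads=8, mlp_dim=2048) and the tensor sizes are illustrative assumptions for this sketch, not values shipped by this package.

# Illustrative sketch only (not part of the wheel contents): run the
# TwoWayTransformer from sam2/modeling/sam/transformer.py on dummy inputs.
import torch
from sam2.modeling.sam.transformer import TwoWayTransformer

transformer = TwoWayTransformer(depth=2, embedding_dim=256, num_heads=8, mlp_dim=2048)
transformer.eval()

image_embedding = torch.randn(1, 256, 64, 64)  # B x embedding_dim x h x w
image_pe = torch.randn(1, 256, 64, 64)         # same shape as image_embedding
point_embedding = torch.randn(1, 5, 256)       # B x N_points x embedding_dim

with torch.no_grad():
    queries, keys = transformer(image_embedding, image_pe, point_embedding)

print(queries.shape)  # torch.Size([1, 5, 256])    -- processed point_embedding
print(keys.shape)     # torch.Size([1, 4096, 256]) -- processed image_embedding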