titans-pytorch 0.0.26__py3-none-any.whl → 0.0.27__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- titans_pytorch/mac_transformer.py +23 -2
- {titans_pytorch-0.0.26.dist-info → titans_pytorch-0.0.27.dist-info}/METADATA +1 -1
- {titans_pytorch-0.0.26.dist-info → titans_pytorch-0.0.27.dist-info}/RECORD +5 -5
- {titans_pytorch-0.0.26.dist-info → titans_pytorch-0.0.27.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.0.26.dist-info → titans_pytorch-0.0.27.dist-info}/licenses/LICENSE +0 -0
| @@ -3,10 +3,11 @@ import math | |
| 3 3 | 
             
            from functools import partial
         | 
| 4 4 |  | 
| 5 5 | 
             
            import torch
         | 
| 6 | 
            -
            from torch import nn
         | 
| 6 | 
            +
            from torch import nn, cat
         | 
| 7 7 | 
             
            import torch.nn.functional as F
         | 
| 8 8 | 
             
            from torch.nn import Module, ModuleList, Linear
         | 
| 9 9 |  | 
| 10 | 
            +
            from einops import repeat
         | 
| 10 11 | 
             
            from einops.layers.torch import Rearrange
         | 
| 11 12 |  | 
| 12 13 | 
             
            from hyper_connections import get_init_and_expand_reduce_stream_functions
         | 
| @@ -48,6 +49,7 @@ class SegmentedAttention(Module): | |
| 48 49 | 
             
                    self,
         | 
| 49 50 | 
             
                    dim,
         | 
| 50 51 | 
             
                    segment_len,
         | 
| 52 | 
            +
                    num_persist_mem_tokens,
         | 
| 51 53 | 
             
                    dim_head = 64,
         | 
| 52 54 | 
             
                    heads = 8,
         | 
| 53 55 | 
             
                ):
         | 
| @@ -67,6 +69,7 @@ class SegmentedAttention(Module): | |
| 67 69 | 
             
                    self.segment_seq = Rearrange('b (n w) d -> (b n) w d', n = segment_len)
         | 
| 68 70 | 
             
                    self.merge_seq_back = Rearrange('(b n) w d -> b (n w) d', n = segment_len)
         | 
| 69 71 |  | 
| 72 | 
            +
                    self.persistent_memory = nn.Parameter(torch.zeros(2, heads, num_persist_mem_tokens, dim_head))
         | 
| 70 73 |  | 
| 71 74 | 
             
                def forward(self, seq):
         | 
| 72 75 | 
             
                    batch, seq_len = seq.shape[:2]
         | 
| @@ -92,6 +95,15 @@ class SegmentedAttention(Module): | |
| 92 95 | 
             
                    q, k, v = self.to_qkv(seq).chunk(3, dim = -1)
         | 
| 93 96 | 
             
                    q, k, v = map(self.split_heads, (q, k, v))
         | 
| 94 97 |  | 
| 98 | 
            +
                    # take care of persistent memory key / values
         | 
| 99 | 
            +
             | 
| 100 | 
            +
                    pmk, pmv = tuple(repeat(t, 'h n d -> b h n d', b = seq.shape[0]) for t in self.persistent_memory)
         | 
| 101 | 
            +
             | 
| 102 | 
            +
                    k = cat((pmk, k), dim = -2)
         | 
| 103 | 
            +
                    v = cat((pmv, v), dim = -2)
         | 
| 104 | 
            +
             | 
| 105 | 
            +
                    # sdpa
         | 
| 106 | 
            +
             | 
| 95 107 | 
             
                    out = F.scaled_dot_product_attention(q, k, v, is_causal = True)
         | 
| 96 108 |  | 
| 97 109 | 
             
                    out = self.merge_heads(out)
         | 
| @@ -113,6 +125,7 @@ class MemoryAsContextTransformer(Module): | |
| 113 125 | 
             
                    dim,
         | 
| 114 126 | 
             
                    depth,
         | 
| 115 127 | 
             
                    segment_len,
         | 
| 128 | 
            +
                    num_persist_mem_tokens,
         | 
| 116 129 | 
             
                    dim_head = 64,
         | 
| 117 130 | 
             
                    heads = 8,
         | 
| 118 131 | 
             
                    ff_mult = 4,
         | 
| @@ -127,7 +140,14 @@ class MemoryAsContextTransformer(Module): | |
| 127 140 | 
             
                    self.layers = ModuleList([])
         | 
| 128 141 |  | 
| 129 142 | 
             
                    for _ in range(depth):
         | 
| 130 | 
            -
                        attn = SegmentedAttention( | 
| 143 | 
            +
                        attn = SegmentedAttention(
         | 
| 144 | 
            +
                            dim = dim,
         | 
| 145 | 
            +
                            dim_head = dim_head,
         | 
| 146 | 
            +
                            heads = heads,
         | 
| 147 | 
            +
                            segment_len = segment_len,
         | 
| 148 | 
            +
                            num_persist_mem_tokens = num_persist_mem_tokens
         | 
| 149 | 
            +
                        )
         | 
| 150 | 
            +
             | 
| 131 151 | 
             
                        ff = FeedForward(dim = dim, mult = ff_mult)
         | 
| 132 152 |  | 
| 133 153 | 
             
                        self.layers.append(ModuleList([
         | 
| @@ -162,6 +182,7 @@ if __name__ == '__main__': | |
| 162 182 | 
             
                    num_tokens = 256,
         | 
| 163 183 | 
             
                    dim = 256,
         | 
| 164 184 | 
             
                    depth = 2,
         | 
| 185 | 
            +
                    num_persist_mem_tokens = 16,
         | 
| 165 186 | 
             
                    segment_len = 128,
         | 
| 166 187 | 
             
                )
         | 
| 167 188 |  | 
| @@ -1,9 +1,9 @@ | |
| 1 1 | 
             
            titans_pytorch/__init__.py,sha256=nB0873FZ_OyCda3qFeWTdpO4LbbrXDEvtAefVLzh6g0,71
         | 
| 2 2 | 
             
            titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
         | 
| 3 | 
            -
            titans_pytorch/mac_transformer.py,sha256= | 
| 3 | 
            +
            titans_pytorch/mac_transformer.py,sha256=4xUSfGDdVsR-WmeXX7yXoFfybROvNCjOxL_EHDJ_Wlk,4681
         | 
| 4 4 | 
             
            titans_pytorch/titans.py,sha256=Kx_tl_QkeDccvkMwPZ0xQ_saYjZfbKzDNPTTSHNWYcA,14304
         | 
| 5 5 | 
             
            titans_pytorch/titans_attn_memory.py,sha256=Rwx_-riGeISBefZg5Kjic8jzmmRRys-u93D2Kgb7Mos,12691
         | 
| 6 | 
            -
            titans_pytorch-0.0. | 
| 7 | 
            -
            titans_pytorch-0.0. | 
| 8 | 
            -
            titans_pytorch-0.0. | 
| 9 | 
            -
            titans_pytorch-0.0. | 
| 6 | 
            +
            titans_pytorch-0.0.27.dist-info/METADATA,sha256=qJp1IDbphEUfW7EyNvQ7RfmHuvB7SH5h_tlnCVwV4EY,3851
         | 
| 7 | 
            +
            titans_pytorch-0.0.27.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
         | 
| 8 | 
            +
            titans_pytorch-0.0.27.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
         | 
| 9 | 
            +
            titans_pytorch-0.0.27.dist-info/RECORD,,
         | 
| 
            File without changes
         | 
| 
            File without changes
         |