titans-pytorch 0.0.26.tar.gz → 0.0.29.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {titans_pytorch-0.0.26 → titans_pytorch-0.0.29}/PKG-INFO +3 -1
- {titans_pytorch-0.0.26 → titans_pytorch-0.0.29}/pyproject.toml +3 -1
- {titans_pytorch-0.0.26 → titans_pytorch-0.0.29}/tests/test_titans.py +16 -0
- {titans_pytorch-0.0.26 → titans_pytorch-0.0.29}/titans_pytorch/mac_transformer.py +52 -4
- {titans_pytorch-0.0.26 → titans_pytorch-0.0.29}/.github/workflows/python-publish.yml +0 -0
- {titans_pytorch-0.0.26 → titans_pytorch-0.0.29}/.github/workflows/test.yaml +0 -0
- {titans_pytorch-0.0.26 → titans_pytorch-0.0.29}/.gitignore +0 -0
- {titans_pytorch-0.0.26 → titans_pytorch-0.0.29}/LICENSE +0 -0
- {titans_pytorch-0.0.26 → titans_pytorch-0.0.29}/README.md +0 -0
- {titans_pytorch-0.0.26 → titans_pytorch-0.0.29}/data/README.md +0 -0
- {titans_pytorch-0.0.26 → titans_pytorch-0.0.29}/data/enwik8.gz +0 -0
- {titans_pytorch-0.0.26 → titans_pytorch-0.0.29}/fig1.png +0 -0
- {titans_pytorch-0.0.26 → titans_pytorch-0.0.29}/fig2.png +0 -0
- {titans_pytorch-0.0.26 → titans_pytorch-0.0.29}/requirements.txt +0 -0
- {titans_pytorch-0.0.26 → titans_pytorch-0.0.29}/titans_pytorch/__init__.py +0 -0
- {titans_pytorch-0.0.26 → titans_pytorch-0.0.29}/titans_pytorch/associative_scan.py +0 -0
- {titans_pytorch-0.0.26 → titans_pytorch-0.0.29}/titans_pytorch/titans.py +0 -0
- {titans_pytorch-0.0.26 → titans_pytorch-0.0.29}/titans_pytorch/titans_attn_memory.py +0 -0
- {titans_pytorch-0.0.26 → titans_pytorch-0.0.29}/train.py +0 -0
| @@ -1,6 +1,6 @@ | |
| 1 1 | 
             
            Metadata-Version: 2.4
         | 
| 2 2 | 
             
            Name: titans-pytorch
         | 
| 3 | 
            -
            Version: 0.0.26
         | 
| 3 | 
            +
            Version: 0.0.29
         | 
| 4 4 | 
             
            Summary: Titans
         | 
| 5 5 | 
             
            Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
         | 
| 6 6 | 
             
            Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
         | 
| @@ -35,10 +35,12 @@ Classifier: Programming Language :: Python :: 3.9 | |
| 35 35 | 
             
            Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
         | 
| 36 36 | 
             
            Requires-Python: >=3.9
         | 
| 37 37 | 
             
            Requires-Dist: accelerated-scan>=0.2.0
         | 
| 38 | 
            +
            Requires-Dist: axial-positional-embedding>=0.3.5
         | 
| 38 39 | 
             
            Requires-Dist: einops>=0.8.0
         | 
| 39 40 | 
             
            Requires-Dist: einx>=0.3.0
         | 
| 40 41 | 
             
            Requires-Dist: hyper-connections>=0.1.8
         | 
| 41 42 | 
             
            Requires-Dist: ninja
         | 
| 43 | 
            +
            Requires-Dist: rotary-embedding-torch
         | 
| 42 44 | 
             
            Requires-Dist: tensordict
         | 
| 43 45 | 
             
            Requires-Dist: torch>=2.2
         | 
| 44 46 | 
             
            Provides-Extra: examples
         | 
| @@ -1,6 +1,6 @@ | |
| 1 1 | 
             
            [project]
         | 
| 2 2 | 
             
            name = "titans-pytorch"
         | 
| 3 | 
            -
            version = "0.0.26"
         | 
| 3 | 
            +
            version = "0.0.29"
         | 
| 4 4 | 
             
            description = "Titans"
         | 
| 5 5 | 
             
            authors = [
         | 
| 6 6 | 
             
                { name = "Phil Wang", email = "lucidrains@gmail.com" }
         | 
| @@ -26,10 +26,12 @@ classifiers=[ | |
| 26 26 |  | 
| 27 27 | 
             
            dependencies = [
         | 
| 28 28 | 
             
                "accelerated-scan>=0.2.0",
         | 
| 29 | 
            +
                "axial_positional_embedding>=0.3.5",
         | 
| 29 30 | 
             
                "einx>=0.3.0",
         | 
| 30 31 | 
             
                "einops>=0.8.0",
         | 
| 31 32 | 
             
                "hyper-connections>=0.1.8",
         | 
| 32 33 | 
             
                "Ninja",
         | 
| 34 | 
            +
                "rotary-embedding-torch",
         | 
| 33 35 | 
             
                "tensordict",
         | 
| 34 36 | 
             
                "torch>=2.2",
         | 
| 35 37 | 
             
            ]
         | 
| @@ -33,3 +33,19 @@ def test_titans_attn_memory(): | |
| 33 33 | 
             
                retrieved = mem(seq)
         | 
| 34 34 |  | 
| 35 35 | 
             
                assert seq.shape == retrieved.shape
         | 
| 36 | 
            +
             | 
| 37 | 
            +
            def test_mac():
         | 
| 38 | 
            +
                from titans_pytorch.mac_transformer import MemoryAsContextTransformer
         | 
| 39 | 
            +
             | 
| 40 | 
            +
                transformer = MemoryAsContextTransformer(
         | 
| 41 | 
            +
                    num_tokens = 256,
         | 
| 42 | 
            +
                    dim = 256,
         | 
| 43 | 
            +
                    depth = 2,
         | 
| 44 | 
            +
                    num_persist_mem_tokens = 16,
         | 
| 45 | 
            +
                    segment_len = 128,
         | 
| 46 | 
            +
                )
         | 
| 47 | 
            +
             | 
| 48 | 
            +
                x = torch.randint(0, 256, (1, 1023))
         | 
| 49 | 
            +
             | 
| 50 | 
            +
                logits = transformer(x)
         | 
| 51 | 
            +
                assert logits.shape == (1, 1023, 256)
         | 
| @@ -1,16 +1,23 @@ | |
| 1 1 | 
             
            from __future__ import annotations
         | 
| 2 | 
            -
            import  | 
| 2 | 
            +
            from math import ceil
         | 
| 3 3 | 
             
            from functools import partial
         | 
| 4 4 |  | 
| 5 5 | 
             
            import torch
         | 
| 6 | 
            -
            from torch import nn
         | 
| 6 | 
            +
            from torch import nn, cat
         | 
| 7 7 | 
             
            import torch.nn.functional as F
         | 
| 8 8 | 
             
            from torch.nn import Module, ModuleList, Linear
         | 
| 9 9 |  | 
| 10 | 
            +
            from einops import repeat
         | 
| 10 11 | 
             
            from einops.layers.torch import Rearrange
         | 
| 11 12 |  | 
| 13 | 
            +
             | 
| 12 14 | 
             
            from hyper_connections import get_init_and_expand_reduce_stream_functions
         | 
| 13 15 |  | 
| 16 | 
            +
            # absolute and relative positions
         | 
| 17 | 
            +
             | 
| 18 | 
            +
            from axial_positional_embedding import ContinuousAxialPositionalEmbedding
         | 
| 19 | 
            +
            from rotary_embedding_torch import RotaryEmbedding
         | 
| 20 | 
            +
             | 
| 14 21 | 
             
            # constants
         | 
| 15 22 |  | 
| 16 23 | 
             
            LinearNoBias = partial(Linear, bias = False)
         | 
| @@ -24,7 +31,7 @@ def default(v, d): | |
| 24 31 | 
             
                return v if exists(v) else d
         | 
| 25 32 |  | 
| 26 33 | 
             
            def round_up_multiple(seq, mult):
         | 
| 27 | 
            -
                return  | 
| 34 | 
            +
                return ceil(seq / mult) * mult
         | 
| 28 35 |  | 
| 29 36 | 
             
            # feedforward and attention
         | 
| 30 37 |  | 
| @@ -48,6 +55,7 @@ class SegmentedAttention(Module): | |
| 48 55 | 
             
                    self,
         | 
| 49 56 | 
             
                    dim,
         | 
| 50 57 | 
             
                    segment_len,
         | 
| 58 | 
            +
                    num_persist_mem_tokens,
         | 
| 51 59 | 
             
                    dim_head = 64,
         | 
| 52 60 | 
             
                    heads = 8,
         | 
| 53 61 | 
             
                ):
         | 
| @@ -56,6 +64,8 @@ class SegmentedAttention(Module): | |
| 56 64 |  | 
| 57 65 | 
             
                    dim_inner = dim_head * heads
         | 
| 58 66 |  | 
| 67 | 
            +
                    self.rotary_emb = RotaryEmbedding(dim_head)
         | 
| 68 | 
            +
             | 
| 59 69 | 
             
                    self.to_qkv = LinearNoBias(dim, dim_inner * 3)
         | 
| 60 70 | 
             
                    self.to_out = LinearNoBias(dim_inner, dim)
         | 
| 61 71 |  | 
| @@ -67,6 +77,7 @@ class SegmentedAttention(Module): | |
| 67 77 | 
             
                    self.segment_seq = Rearrange('b (n w) d -> (b n) w d', n = segment_len)
         | 
| 68 78 | 
             
                    self.merge_seq_back = Rearrange('(b n) w d -> b (n w) d', n = segment_len)
         | 
| 69 79 |  | 
| 80 | 
            +
                    self.persistent_memory = nn.Parameter(torch.zeros(2, heads, num_persist_mem_tokens, dim_head))
         | 
| 70 81 |  | 
| 71 82 | 
             
                def forward(self, seq):
         | 
| 72 83 | 
             
                    batch, seq_len = seq.shape[:2]
         | 
| @@ -92,6 +103,21 @@ class SegmentedAttention(Module): | |
| 92 103 | 
             
                    q, k, v = self.to_qkv(seq).chunk(3, dim = -1)
         | 
| 93 104 | 
             
                    q, k, v = map(self.split_heads, (q, k, v))
         | 
| 94 105 |  | 
| 106 | 
            +
                    # take care of persistent memory key / values
         | 
| 107 | 
            +
             | 
| 108 | 
            +
                    pmk, pmv = tuple(repeat(t, 'h n d -> b h n d', b = seq.shape[0]) for t in self.persistent_memory)
         | 
| 109 | 
            +
             | 
| 110 | 
            +
                    # relative positions
         | 
| 111 | 
            +
             | 
| 112 | 
            +
                    q, k = self.rotary_emb.rotate_queries_with_cached_keys(q, k)
         | 
| 113 | 
            +
             | 
| 114 | 
            +
                    # persistent memory
         | 
| 115 | 
            +
             | 
| 116 | 
            +
                    k = cat((pmk, k), dim = -2)
         | 
| 117 | 
            +
                    v = cat((pmv, v), dim = -2)
         | 
| 118 | 
            +
             | 
| 119 | 
            +
                    # sdpa
         | 
| 120 | 
            +
             | 
| 95 121 | 
             
                    out = F.scaled_dot_product_attention(q, k, v, is_causal = True)
         | 
| 96 122 |  | 
| 97 123 | 
             
                    out = self.merge_heads(out)
         | 
| @@ -113,6 +139,7 @@ class MemoryAsContextTransformer(Module): | |
| 113 139 | 
             
                    dim,
         | 
| 114 140 | 
             
                    depth,
         | 
| 115 141 | 
             
                    segment_len,
         | 
| 142 | 
            +
                    num_persist_mem_tokens,
         | 
| 116 143 | 
             
                    dim_head = 64,
         | 
| 117 144 | 
             
                    heads = 8,
         | 
| 118 145 | 
             
                    ff_mult = 4,
         | 
| @@ -120,6 +147,9 @@ class MemoryAsContextTransformer(Module): | |
| 120 147 | 
             
                ):
         | 
| 121 148 | 
             
                    super().__init__()
         | 
| 122 149 |  | 
| 150 | 
            +
                    self.segment_len = segment_len
         | 
| 151 | 
            +
                    self.axial_pos_emb = ContinuousAxialPositionalEmbedding(dim = dim, num_axial_dims = 2)
         | 
| 152 | 
            +
             | 
| 123 153 | 
             
                    self.token_emb = nn.Embedding(num_tokens, dim)
         | 
| 124 154 |  | 
| 125 155 | 
             
                    init_hyper_conn, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(num_residual_streams, disable = num_residual_streams == 1)
         | 
| @@ -127,7 +157,14 @@ class MemoryAsContextTransformer(Module): | |
| 127 157 | 
             
                    self.layers = ModuleList([])
         | 
| 128 158 |  | 
| 129 159 | 
             
                    for _ in range(depth):
         | 
| 130 | 
            -
                        attn = SegmentedAttention( | 
| 160 | 
            +
                        attn = SegmentedAttention(
         | 
| 161 | 
            +
                            dim = dim,
         | 
| 162 | 
            +
                            dim_head = dim_head,
         | 
| 163 | 
            +
                            heads = heads,
         | 
| 164 | 
            +
                            segment_len = segment_len,
         | 
| 165 | 
            +
                            num_persist_mem_tokens = num_persist_mem_tokens
         | 
| 166 | 
            +
                        )
         | 
| 167 | 
            +
             | 
| 131 168 | 
             
                        ff = FeedForward(dim = dim, mult = ff_mult)
         | 
| 132 169 |  | 
| 133 170 | 
             
                        self.layers.append(ModuleList([
         | 
| @@ -140,9 +177,19 @@ class MemoryAsContextTransformer(Module): | |
| 140 177 | 
             
                    self.to_logits = LinearNoBias(dim, num_tokens)
         | 
| 141 178 |  | 
| 142 179 | 
             
                def forward(self, x):
         | 
| 180 | 
            +
                    seq_len, segment_len = x.shape[-1], self.segment_len
         | 
| 181 | 
            +
                    windows = ceil(seq_len / segment_len)
         | 
| 143 182 |  | 
| 144 183 | 
             
                    x = self.token_emb(x)
         | 
| 145 184 |  | 
| 185 | 
            +
                    # apply axial positional embedding
         | 
| 186 | 
            +
                    # so intra and inter segment can be more easily discerned by the network
         | 
| 187 | 
            +
             | 
| 188 | 
            +
                    pos_emb = self.axial_pos_emb((windows, segment_len), flatten = True)
         | 
| 189 | 
            +
                    x = x + pos_emb[:seq_len]
         | 
| 190 | 
            +
             | 
| 191 | 
            +
                    # expand and reduce streams for hyper connections
         | 
| 192 | 
            +
             | 
| 146 193 | 
             
                    x = self.expand_streams(x)
         | 
| 147 194 |  | 
| 148 195 | 
             
                    for attn, ff in self.layers:
         | 
| @@ -162,6 +209,7 @@ if __name__ == '__main__': | |
| 162 209 | 
             
                    num_tokens = 256,
         | 
| 163 210 | 
             
                    dim = 256,
         | 
| 164 211 | 
             
                    depth = 2,
         | 
| 212 | 
            +
                    num_persist_mem_tokens = 16,
         | 
| 165 213 | 
             
                    segment_len = 128,
         | 
| 166 214 | 
             
                )
         | 
| 167 215 |  | 
| 
            File without changes
         | 
| 
            File without changes
         | 
| 
            File without changes
         | 
| 
            File without changes
         | 
| 
            File without changes
         | 
| 
            File without changes
         | 
| 
            File without changes
         | 
| 
            File without changes
         | 
| 
            File without changes
         | 
| 
            File without changes
         | 
| 
            File without changes
         | 
| 
            File without changes
         | 
| 
            File without changes
         | 
| 
            File without changes
         | 
| 
            File without changes
         |