titans-pytorch 0.0.36__tar.gz → 0.0.38__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {titans_pytorch-0.0.36 → titans_pytorch-0.0.38}/PKG-INFO +1 -1
- {titans_pytorch-0.0.36 → titans_pytorch-0.0.38}/pyproject.toml +1 -1
- {titans_pytorch-0.0.36 → titans_pytorch-0.0.38}/titans_pytorch/mac_transformer.py +10 -18
- {titans_pytorch-0.0.36 → titans_pytorch-0.0.38}/titans_pytorch/titans.py +9 -6
- {titans_pytorch-0.0.36 → titans_pytorch-0.0.38}/train_mac.py +5 -5
- {titans_pytorch-0.0.36 → titans_pytorch-0.0.38}/.github/workflows/python-publish.yml +0 -0
- {titans_pytorch-0.0.36 → titans_pytorch-0.0.38}/.github/workflows/test.yaml +0 -0
- {titans_pytorch-0.0.36 → titans_pytorch-0.0.38}/.gitignore +0 -0
- {titans_pytorch-0.0.36 → titans_pytorch-0.0.38}/LICENSE +0 -0
- {titans_pytorch-0.0.36 → titans_pytorch-0.0.38}/README.md +0 -0
- {titans_pytorch-0.0.36 → titans_pytorch-0.0.38}/data/README.md +0 -0
- {titans_pytorch-0.0.36 → titans_pytorch-0.0.38}/data/enwik8.gz +0 -0
- {titans_pytorch-0.0.36 → titans_pytorch-0.0.38}/fig1.png +0 -0
- {titans_pytorch-0.0.36 → titans_pytorch-0.0.38}/fig2.png +0 -0
- {titans_pytorch-0.0.36 → titans_pytorch-0.0.38}/requirements.txt +0 -0
- {titans_pytorch-0.0.36 → titans_pytorch-0.0.38}/tests/test_titans.py +0 -0
- {titans_pytorch-0.0.36 → titans_pytorch-0.0.38}/titans_pytorch/__init__.py +0 -0
- {titans_pytorch-0.0.36 → titans_pytorch-0.0.38}/titans_pytorch/associative_scan.py +0 -0
- {titans_pytorch-0.0.36 → titans_pytorch-0.0.38}/titans_pytorch/titans_attn_memory.py +0 -0
- {titans_pytorch-0.0.36 → titans_pytorch-0.0.38}/train.py +0 -0
titans_pytorch/mac_transformer.py

@@ -7,7 +7,7 @@ from torch import nn, cat
 import torch.nn.functional as F
 from torch.nn import Module, ModuleList, Linear

-from einops import repeat, rearrange
+from einops import repeat, rearrange, pack, unpack
 from einops.layers.torch import Rearrange

 from hyper_connections import get_init_and_expand_reduce_stream_functions
@@ -214,7 +214,12 @@ class MemoryAsContextTransformer(Module):
             if layer in neural_memory_layers:
                 assert has_longterm_mems, '`num_longterm_mem_tokens` must be greater than 0'

-                mem = NeuralMemory(
+                mem = NeuralMemory(
+                    dim = dim,
+                    chunk_size = num_longterm_mem_tokens + segment_len,
+                    **neural_memory_kwargs
+                )
+
                 mem = init_hyper_conn(dim = dim, branch = mem)

             self.neural_mem_layers.append(mem)
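The per-layer NeuralMemory is now built with an explicit chunk_size spanning one attention window plus its long-term memory tokens. With the constants used in train_mac.py further down (values copied from that script purely for illustration, not package code):

# illustrative arithmetic only
segment_len = 32              # WINDOW_SIZE in train_mac.py
num_longterm_mem_tokens = 4   # NUM_LONGTERM_MEM

chunk_size = num_longterm_mem_tokens + segment_len
assert chunk_size == 36       # each memory chunk covers a window and its memory tokens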
@@ -266,7 +271,7 @@ class MemoryAsContextTransformer(Module):
         x, inverse_segment = pad_and_segment_with_inverse(x, segment_len)

         mems = repeat(self.longterm_mems, 'n d -> b n d', b = x.shape[0])
-        x =
+        x, mem_ps = pack((x, mems), 'b * d')

         x = inverse_segment(x)
@@ -283,21 +288,8 @@ class MemoryAsContextTransformer(Module):
         for (attn, ff), maybe_neural_mem in zip(self.layers, self.neural_mem_layers):

             if exists(maybe_neural_mem):
-
-
-                x, inverse_segment = pad_and_segment_with_inverse(x, total_segment_len)
-
-                longterm_mems, x = x[:, :num_longterm_mem_tokens], x[:, num_longterm_mem_tokens:]
-
-                longterm_mems = rearrange(longterm_mems, '(b w) n d -> b (w n) d', b = batch_streams)
-
-                longterm_mems = maybe_neural_mem(longterm_mems)
-
-                longterm_mems = rearrange(longterm_mems, 'b (w n) d -> (b w) n d', n = num_longterm_mem_tokens)
-
-                x = cat((longterm_mems, x), dim = -2)
+                x = maybe_neural_mem(x)

-                x = inverse_segment(x)

             x = attn(x)
@@ -309,7 +301,7 @@ class MemoryAsContextTransformer(Module):

         x, inverse_segment = pad_and_segment_with_inverse(x, total_segment_len)

-        x = x
+        x, _ = unpack(x, mem_ps, 'b * d')

         x = inverse_segment(x)
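Taken together, the mac_transformer.py changes replace the hand-rolled splitting, rearranging and re-concatenation of the long-term memory tokens with einops pack / unpack: the memory tokens are packed onto the sequence once, every layer (including the neural memory) sees the combined sequence, and the tokens are unpacked again at the end. A minimal sketch of that round trip with toy tensors, not the package's own modules:

import torch
from einops import pack, unpack

batch, seq_len, mem_len, dim = 2, 8, 4, 16

x = torch.randn(batch, seq_len, dim)     # token sequence
mems = torch.randn(batch, mem_len, dim)  # long-term memory tokens

# pack concatenates along the '*' axis and records each tensor's extent in `ps`
packed, ps = pack((x, mems), 'b * d')
assert packed.shape == (batch, seq_len + mem_len, dim)

# ... attention / neural-memory layers would operate on `packed` here ...

# unpack uses `ps` to restore the original (sequence, memory) split
x_out, mems_out = unpack(packed, ps, 'b * d')
assert x_out.shape == x.shape and mems_out.shape == mems.shape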
titans_pytorch/titans.py

@@ -27,9 +27,7 @@ n - sequence
 d - feature dimension
 c - intra-chunk
 """
-
-# constants
-
+7
 LinearNoBias = partial(Linear, bias = False)

 # functions
@@ -132,7 +130,7 @@ class NeuralMemory(Module):
         max_grad_norm: float | None = None,
         use_accelerated_scan = False,
         default_mlp_kwargs: dict = dict(
-            depth =
+            depth = 2
         )
     ):
         super().__init__()
@@ -390,7 +388,10 @@ class NeuralMemory(Module):

         padding = next_seq_len - curtailed_seq_len

-
+        needs_pad = padding > 0
+
+        if needs_pad:
+            seq = pad_at_dim(seq, (0, padding), dim = 1)

         # the parameters of the memory model stores the memories of the key / values
         # when the MLP has only 1 weight matrix, it is equivalent to `kv` fast weight memories from linear attention literature (recall fetching of memories is q @ (kv)) / schmidhuber's paper
@@ -442,7 +443,9 @@ class NeuralMemory(Module):
         empty_memory_embeds = self.init_empty_memory_embed(values.shape[0], chunk_size - 1)
         values = torch.cat((empty_memory_embeds, values), dim = -2)

-
+        if needs_pad:
+            values = values[:, :-padding]
+
         return values

     def forward(
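The needs_pad guard added to titans.py matters because the later un-padding step slices with a negative offset: when no padding was applied, values[:, :-0] would silently drop the whole sequence dimension instead of keeping it. A standalone illustration of the pitfall, not package code:

import torch

t = torch.randn(2, 5, 3)
padding = 0

print(t[:, :-padding].shape)   # torch.Size([2, 0, 3]) -- slicing by -0 keeps nothing

needs_pad = padding > 0        # the guard introduced in 0.0.38
if needs_pad:
    t = t[:, :-padding]
print(t.shape)                 # torch.Size([2, 5, 3]) -- untouched when nothing was padded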
train_mac.py

@@ -24,13 +24,13 @@ SHOULD_GENERATE = False
 SEQ_LEN = 512

 PROJECT_NAME = 'titans-mac-transformer'
-WANDB_ONLINE =
+WANDB_ONLINE = True # turn this on to pipe experiment to cloud
 NEURAL_MEMORY_DEPTH = 2
 NUM_PERSIST_MEM = 4
 NUM_LONGTERM_MEM = 4
-NEURAL_MEM_LAYERS = (
+NEURAL_MEM_LAYERS = (4,)
 WINDOW_SIZE = 32
-RUN_NAME = 'mac - 4 longterm mems, layers (
+RUN_NAME = 'mac - 4 longterm mems, layers (4,)'

 # wandb experiment tracker
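The diff does not show how these constants reach the experiment tracker; a typical wiring would look like the sketch below, using the standard wandb.init arguments (project, name, mode). The call itself is an assumption for illustration, not a line from train_mac.py:

import wandb

PROJECT_NAME = 'titans-mac-transformer'
RUN_NAME = 'mac - 4 longterm mems, layers (4,)'
WANDB_ONLINE = True

wandb.init(
    project = PROJECT_NAME,
    name = RUN_NAME,
    mode = 'online' if WANDB_ONLINE else 'offline'  # offline keeps the run local
)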
@@ -63,10 +63,10 @@ model = MemoryAsContextTransformer(
     num_longterm_mem_tokens = NUM_LONGTERM_MEM,
     neural_memory_layers = NEURAL_MEM_LAYERS,
     neural_memory_kwargs = dict(
+        dim_head = 64,
+        heads = 4,
         default_mlp_kwargs = dict(
             depth = NEURAL_MEMORY_DEPTH,
-            dim_head = 64,
-            heads = 4
         )
     )
 ).cuda()
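Putting the configuration changes together, a model would now be constructed roughly as below. The neural_memory_kwargs layout (dim_head and heads at the top level, default_mlp_kwargs holding only the MLP depth) comes directly from the diff; the remaining constructor arguments (num_tokens, dim, depth, num_persist_mem_tokens) and the return_loss flag are assumptions based on the project README and may differ slightly in this exact version:

import torch
from titans_pytorch import MemoryAsContextTransformer  # import path assumed from the README

model = MemoryAsContextTransformer(
    num_tokens = 256,                      # assumed byte-level vocabulary
    dim = 384,                             # assumed model width
    depth = 8,                             # assumed number of transformer layers
    segment_len = 32,                      # WINDOW_SIZE
    num_persist_mem_tokens = 4,            # NUM_PERSIST_MEM
    num_longterm_mem_tokens = 4,           # NUM_LONGTERM_MEM
    neural_memory_layers = (4,),           # NEURAL_MEM_LAYERS
    neural_memory_kwargs = dict(
        dim_head = 64,                     # now NeuralMemory arguments, not MLP kwargs
        heads = 4,
        default_mlp_kwargs = dict(
            depth = 2                      # NEURAL_MEMORY_DEPTH
        )
    )
)

ids = torch.randint(0, 256, (1, 512))      # SEQ_LEN = 512
loss = model(ids, return_loss = True)      # forward signature assumed from the README
loss.backward()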