titans-pytorch 0.0.16__tar.gz → 0.0.18__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {titans_pytorch-0.0.16 → titans_pytorch-0.0.18}/PKG-INFO +1 -1
 - {titans_pytorch-0.0.16 → titans_pytorch-0.0.18}/pyproject.toml +1 -1
 - {titans_pytorch-0.0.16 → titans_pytorch-0.0.18}/titans_pytorch/titans.py +1 -1
 - titans_pytorch-0.0.18/titans_pytorch/titans_attn_memory.py +419 -0
 - {titans_pytorch-0.0.16 → titans_pytorch-0.0.18}/train.py +0 -1
 - {titans_pytorch-0.0.16 → titans_pytorch-0.0.18}/.github/workflows/python-publish.yml +0 -0
 - {titans_pytorch-0.0.16 → titans_pytorch-0.0.18}/.github/workflows/test.yaml +0 -0
 - {titans_pytorch-0.0.16 → titans_pytorch-0.0.18}/.gitignore +0 -0
 - {titans_pytorch-0.0.16 → titans_pytorch-0.0.18}/LICENSE +0 -0
 - {titans_pytorch-0.0.16 → titans_pytorch-0.0.18}/README.md +0 -0
 - {titans_pytorch-0.0.16 → titans_pytorch-0.0.18}/data/README.md +0 -0
 - {titans_pytorch-0.0.16 → titans_pytorch-0.0.18}/data/enwik8.gz +0 -0
 - {titans_pytorch-0.0.16 → titans_pytorch-0.0.18}/fig1.png +0 -0
 - {titans_pytorch-0.0.16 → titans_pytorch-0.0.18}/fig2.png +0 -0
 - {titans_pytorch-0.0.16 → titans_pytorch-0.0.18}/requirements.txt +0 -0
 - {titans_pytorch-0.0.16 → titans_pytorch-0.0.18}/tests/test_titans.py +0 -0
 - {titans_pytorch-0.0.16 → titans_pytorch-0.0.18}/titans_pytorch/__init__.py +0 -0
 - {titans_pytorch-0.0.16 → titans_pytorch-0.0.18}/titans_pytorch/associative_scan.py +0 -0
 
titans_pytorch/titans.py
@@ -269,7 +269,7 @@ class NeuralMemory(Module):
             gates = F.pad(gates, (0, next_power_two_seq_len - seq_len))
             inputs = F.pad(inputs, (0, next_power_two_seq_len - seq_len))
 
-            outputs = scan(gates, inputs)
+            outputs = scan(gates.contiguous(), inputs.contiguous())
 
             outputs = outputs[..., :seq_len]
             outputs = rearrange(outputs, 'b d n -> b n d')
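The lone change to titans.py adds `.contiguous()` before handing the tensors to the accelerated scan. The `rearrange(t, 'b n d -> b d n')` a few lines earlier returns a permuted view whose storage is no longer row-major, and the Triton / warp kernels from `accelerated_scan` expect densely laid-out memory. A minimal sketch of the distinction (illustrative only, not part of the package):

import torch
from einops import rearrange

x = torch.randn(2, 8, 4)                 # (batch, seq, dim)
y = rearrange(x, 'b n d -> b d n')       # a permuted view; strides are no longer row-major

print(y.is_contiguous())                 # False
print(y.contiguous().is_contiguous())    # True - .contiguous() copies into dense row-major storage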
titans_pytorch/titans_attn_memory.py (new file)
@@ -0,0 +1,419 @@
+from __future__ import annotations
+import math
+from functools import partial
+
+import torch
+from torch import nn, Tensor
+import torch.nn.functional as F
+from torch.nn import Linear, Module
+from torch.func import functional_call, vmap, grad
+
+from tensordict import TensorDict
+
+from titans_pytorch.associative_scan import (
+    associative_scan,
+    binary_operator,
+    pad_at_dim
+)
+
+import einx
+from einops import rearrange, pack, unpack
+from einops.layers.torch import Rearrange, Reduce
+
+"""
+ein notation:
+b - batch
+n - sequence
+d - feature dimension
+c - intra-chunk
+"""
+
+# constants
+
+LinearNoBias = partial(Linear, bias = False)
+
+# functions
+
+def exists(v):
+    return v is not None
+
+def default(v, d):
+    return v if exists(v) else d
+
+def round_down_multiple(seq, mult):
+    return seq // mult * mult
+
+def round_up_multiple(seq, mult):
+    return math.ceil(seq / mult) * mult
+
+def pack_one_with_inverse(t, pattern):
+    packed, packed_shape = pack([t], pattern)
+
+    def inverse(out, inv_pattern = None):
+        inv_pattern = default(inv_pattern, pattern)
+        return unpack(out, packed_shape, inv_pattern)[0]
+
+    return packed, inverse
+
+# classes
+
+# improvised attention as memory module
+# todo - expand if see signal in experiments (update: not seeing it)
+
+class MemoryAttention(Module):
+    def __init__(
+        self,
+        dim
+    ):
+        super().__init__()
+        self.weights = nn.ParameterList([
+            nn.Parameter(torch.randn(dim, dim)), # queries
+            nn.Parameter(torch.randn(dim, dim)), # keys
+            nn.Parameter(torch.randn(dim, dim)), # values weight 1
+            nn.Parameter(torch.randn(dim, dim)), # values weight 2
+        ])
+
+    def forward(self, x):
+
+        assert x.shape[-2] > 1, 'chunk size needs to be greater than 1 for using attention as memory'
+
+        wq, wk, wv1, wv2 = self.weights
+
+        q = x @ wq
+        k = x @ wk
+        v = x @ wv1
+
+        hidden = F.scaled_dot_product_attention(
+            q, k, v,
+            is_causal = True
+        )
+
+        return F.silu(hidden) @ wv2
+
+# main neural memory
+
+def default_loss_fn(pred, target):
+    return (pred - target).pow(2).mean(dim = -1).sum()
+
+class NeuralMemory(Module):
+    def __init__(
+        self,
+        dim,
+        chunk_size = 1,
+        dim_head = None,
+        heads = 1,
+        model: MemoryAttention | None = None,
+        store_memory_loss_fn: Callable = default_loss_fn,
+        pre_rmsnorm = True,
+        post_rmsnorm = True,
+        use_accelerated_scan = False,
+        default_model_kwargs: dict = dict()
+    ):
+        super().__init__()
+
+        # norms
+
+        self.retrieve_norm = nn.RMSNorm(dim) if pre_rmsnorm else nn.Identity()
+        self.store_norm = nn.RMSNorm(dim) if pre_rmsnorm else nn.Identity()
+
+        self.post_rmsnorm = nn.RMSNorm(dim) if post_rmsnorm else nn.Identity()
+
+        # maybe multi-headed
+
+        dim_head = default(dim_head, dim)
+        dim_inner = dim_head * heads
+
+        self.split_heads = Rearrange('b n (h d) -> (b h) n d', h = heads)
+        self.merge_heads = Rearrange('(b h) n d -> b n (h d)', h = heads)
+        self.combine_heads = LinearNoBias(dim_inner, dim) if heads > 1 else nn.Identity()
+
+        # memory mlp
+
+        if not exists(model):
+            model = MemoryAttention(dim_head, **default_model_kwargs)
+
+        assert not exists(next(model.buffers(), None)), 'model cannot have buffers for now'
+
+        # the memory is the weights of the model
+
+        self.memory_model = model
+
+        # the chunk size within the paper where adaptive step, momentum, weight decay are shared
+
+        self.chunk_size = chunk_size
+
+        # prepare function for per sample gradients from model above, using torch.func
+
+        def forward_and_loss(params, inputs, target):
+            pred = functional_call(self.memory_model, params, inputs)
+            loss = self.store_memory_loss_fn(pred, target) # simple mse loss in paper - eq (12) - |M(k) - v|²
+            return loss
+
+        self.per_sample_grad_fn = vmap(grad(forward_and_loss), in_dims = (None, 0, 0))
+
+        # queries for retrieving from the model
+
+        self.to_queries = LinearNoBias(dim, dim_inner)
+
+        # keys and values for storing to the model
+
+        self.to_keys_values = LinearNoBias(dim, dim_inner * 2)
+        self.store_memory_loss_fn = store_memory_loss_fn
+
+        # learned adaptive learning rate and momentum
+        # todo - explore mlp layerwise learned lr / momentum
+
+        self.to_momentum = nn.Sequential(
+            Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size),
+            LinearNoBias(dim, heads),
+            Rearrange('b n h -> (b h) n 1')
+        )
+
+        self.to_adaptive_step = nn.Sequential(
+            Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size),
+            LinearNoBias(dim, heads),
+            Rearrange('b n h -> (b h) n')
+        )
+
+        # weight decay factor
+
+        self.to_decay_factor = nn.Sequential(
+            Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size),
+            LinearNoBias(dim, heads),
+            Rearrange('b n h -> (b h) n 1')
+        )
+
+        # maybe use accelerated scan
+
+        self.use_accelerated_scan = use_accelerated_scan
+
+    def init_weights_and_momentum(self):
+        params = TensorDict(dict(self.memory_model.named_parameters()))
+
+        init_weights = params.clone().zero_()
+        init_momentum = params.clone().zero_()
+
+        return init_weights, init_momentum
+
+    def store_memories(
+        self,
+        seq,
+        past_state: tuple[dict[str, Tensor], dict[str, Tensor]]
+    ):
+
+        seq = self.store_norm(seq)
+
+        # curtail sequence by multiple of the chunk size
+        # only a complete chunk of the sequence provides the memory for the next chunk
+
+        seq_len, chunk_size = seq.shape[-2], self.chunk_size
+        round_down_seq_len = round_down_multiple(seq_len, self.chunk_size)
+
+        seq = seq[:, :round_down_seq_len]
+
+        # curr weights + past weights, in the case that the initial weights are learned
+
+        curr_weights = TensorDict(dict(self.memory_model.named_parameters()))
+
+        past_state = tuple(TensorDict(d) for d in past_state)
+        past_weights, past_momentum = past_state
+
+        curr_weights = curr_weights + past_weights
+
+        # pack batch and sequence dimension
+
+        adaptive_lr = (self.to_adaptive_step(seq).sigmoid() * -15).exp() # from 1. - 1e-7
+
+        adaptive_momentum = self.to_momentum(seq).sigmoid()
+        decay_factor = self.to_decay_factor(seq).sigmoid()
+
+        # keys and values
+
+        keys, values = self.to_keys_values(seq).chunk(2, dim = -1)
+
+        # maybe multi head
+
+        keys, values = map(self.split_heads, (keys, values))
+
+        batch = keys.shape[0]
+
+        # take care of chunking
+
+        keys, values = tuple(rearrange(t, 'b (n c) d -> (b n) c d', c = self.chunk_size) for t in (keys, values))
+
+        # get grads and extra auxiliary loss (for backwarding through qkv projection in base neural memory module)
+
+        grads = self.per_sample_grad_fn(dict(curr_weights), keys, values)
+
+        grads = TensorDict(grads)
+
+        # restore batch and sequence dimension
+
+        grads = grads.apply(lambda t: rearrange(t, '(b n) ... -> b n ...', b = batch))
+
+        # multiply gradients with learned adaptive step size
+
+        surprises = grads.apply(lambda t: einx.multiply('b n ..., b n -> b n ...', t, -adaptive_lr))
+
+        # determine scan function
+
+        def default_associative_scan(gates, inputs):
+            _, outputs = associative_scan(binary_operator, (gates, inputs))
+            return outputs
+
+        if self.use_accelerated_scan:
+            from accelerated_scan.triton import scan as triton_scan
+            from accelerated_scan.warp import scan as warp_scan
+
+            scan = triton_scan if seq.is_cuda else warp_scan
+
+            def accelerate_scan_fn(gates, inputs):
+                gates = gates.expand_as(inputs)
+                gates, inputs = tuple(rearrange(t, 'b n d -> b d n') for t in (gates, inputs))
+
+                seq_len = gates.shape[-1]
+                next_power_two_seq_len = 2 ** max(5, int(math.ceil(math.log2(seq_len))))
+
+                gates = F.pad(gates, (0, next_power_two_seq_len - seq_len))
+                inputs = F.pad(inputs, (0, next_power_two_seq_len - seq_len))
+
+                outputs = scan(gates, inputs)
+
+                outputs = outputs[..., :seq_len]
+                outputs = rearrange(outputs, 'b d n -> b n d')
+                return outputs
+
+            scan_fn = accelerate_scan_fn
+        else:
+            scan_fn = default_associative_scan
+
+        # momentum + weight decay - momentum is the new contribution, as most linear RNNs have learned forgetting gates
+
+        next_momentum = TensorDict()
+        updates = TensorDict()
+
+        for param_name, surprise in surprises.items():
+
+            surprise, inverse_pack = pack_one_with_inverse(surprise, 'b n *')
+
+            # derive momentum with associative scan - eq (10)
+
+            momentum = scan_fn(adaptive_momentum, surprise) # momentum is S / surprise in the paper
+
+            # use associative scan again for learned forgetting (weight decay) - eq (13)
+
+            update = scan_fn(1. - decay_factor, momentum) # momentum is S / surprise in the paper
+
+            updates[param_name] = inverse_pack(update)
+            next_momentum[param_name] = inverse_pack(momentum)
+
+        # compute the next weight per batch
+
+        last_update = updates.apply(lambda t: t[:, -1])
+
+        next_state = (curr_weights + last_update, next_momentum)
+
+        return updates, next_state
+
+    def retrieve_memories(
+        self,
+        seq,
+        past_weights: dict[str, Tensor] | None = None,
+    ):
+        chunk_size = self.chunk_size
+        seq_len = seq.shape[1]
+
+        seq = self.retrieve_norm(seq)
+
+        assert seq_len > chunk_size
+
+        seq = seq[:, chunk_size:]
+        curtailed_seq_len = seq.shape[-2]
+
+        next_seq_len = round_up_multiple(curtailed_seq_len + 1, chunk_size)
+
+        padding = next_seq_len - curtailed_seq_len
+
+        seq = pad_at_dim(seq, (0, padding), dim = 1)
+
+        # the parameters of the memory model stores the memories of the key / values
+        # when the MLP has only 1 weight matrix, it is equivalent to `kv` fast weight memories from linear attention literature (recall fetching of memories is q @ (kv)) / schmidhuber's paper
+
+        curr_weights = TensorDict(dict(self.memory_model.named_parameters()))
+
+        if exists(past_weights):
+            past_weights = TensorDict(past_weights)
+            assert past_weights.keys() == curr_weights.keys()
+
+            curr_weights = curr_weights + past_weights
+
+        # sequence Float['b n d'] to queries
+
+        queries = self.to_queries(seq)
+
+        # maybe multihead
+
+        queries = self.split_heads(queries)
+
+        batch = queries.shape[0]
+
+        # fetch values from memory model
+
+        curr_weights = curr_weights.apply(lambda t: rearrange(t, 'b n ... -> (b n) ...'))
+        queries = rearrange(queries, 'b (n c) d -> (b n) c d', c = chunk_size)
+
+        # forward functional call
+
+        values = functional_call(self.memory_model, dict(curr_weights), queries)
+
+        # reconstitute batch dimension
+
+        values = rearrange(values, '(b n) c d -> b (n c) d', b = batch)
+
+        # maybe merge heads and combine
+
+        values = self.merge_heads(values)
+
+        values = self.combine_heads(values)
+
+        # post norm, somehow could not stabilize this without it, not in paper
+
+        values = self.post_rmsnorm(values)
+
+        # restore
+
+        values = pad_at_dim(values, (chunk_size, 0), dim = 1, value = 0.) # todo, used a learned null memory embedding instead of 0s for retrieving from empty neural memory
+        values = values[:, :-padding]
+
+        return values
+
+    def forward(
+        self,
+        seq,
+        store_seq = None,
+        past_state: tuple[dict[str, Tensor], dict[str, Tensor]] | None = None,
+        return_next_memories = False
+    ):
+        batch, seq_len = seq.shape[:2]
+
+        if seq_len <= self.chunk_size:
+            return torch.zeros_like(seq)
+
+        if exists(past_state):
+            past_state = tuple(TensorDict(d) for d in past_state)
+
+        if not exists(past_state):
+            past_state = self.init_weights_and_momentum()
+
+        store_seq = default(store_seq, seq)
+
+        updates, next_memories = self.store_memories(store_seq, past_state)
+
+        past_weights, _ = past_state
+
+        retrieved = self.retrieve_memories(seq, past_weights + updates)
+
+        if not return_next_memories:
+            return retrieved
+
+        return retrieved, next_memories
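Both the momentum update (eq. (10) in the comments) and the weight-decay / forgetting step (eq. (13)) in store_memories are evaluated with a first-order linear-recurrence scan: `associative_scan(binary_operator, (gates, inputs))`, or the accelerated_scan kernels, realize out_t = gate_t * out_(t-1) + input_t along the sequence dimension. A plain-loop reference of that recurrence, as a sketch for intuition rather than how the package actually computes it:

import torch

def reference_scan(gates, inputs):
    # gates, inputs: (batch, seq, dim)
    # returns out with out[:, t] = gates[:, t] * out[:, t - 1] + inputs[:, t]
    out = torch.zeros_like(inputs[:, 0])
    outs = []
    for t in range(inputs.shape[1]):
        out = gates[:, t] * out + inputs[:, t]
        outs.append(out)
    return torch.stack(outs, dim = 1)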
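As with the MLP-based NeuralMemory in titans.py, retrieval returns a tensor of the same shape as the input sequence, and `return_next_memories = True` additionally returns the updated (weights, momentum) state. A rough usage sketch of the new module, untested here, assuming the package's dependencies (tensordict, einx, einops) are installed; shapes and hyperparameters are illustrative only:

import torch
from titans_pytorch.titans_attn_memory import NeuralMemory

# chunk_size should be greater than 1, since MemoryAttention asserts more than one token per chunk
mem = NeuralMemory(dim = 64, chunk_size = 4)

seq = torch.randn(1, 32, 64)   # (batch, seq, dim); sequence length must exceed chunk_size

retrieved = mem(seq)           # retrieved memories, same shape as seq

retrieved, (next_weights, next_momentum) = mem(seq, return_next_memories = True)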