titans-pytorch 0.0.19__tar.gz → 0.0.21__tar.gz
This diff shows the changes between publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
 - {titans_pytorch-0.0.19 → titans_pytorch-0.0.21}/PKG-INFO +1 -1
 - {titans_pytorch-0.0.19 → titans_pytorch-0.0.21}/pyproject.toml +1 -1
 - {titans_pytorch-0.0.19 → titans_pytorch-0.0.21}/titans_pytorch/titans.py +36 -1
 - {titans_pytorch-0.0.19 → titans_pytorch-0.0.21}/train.py +3 -3
 - {titans_pytorch-0.0.19 → titans_pytorch-0.0.21}/.github/workflows/python-publish.yml +0 -0
 - {titans_pytorch-0.0.19 → titans_pytorch-0.0.21}/.github/workflows/test.yaml +0 -0
 - {titans_pytorch-0.0.19 → titans_pytorch-0.0.21}/.gitignore +0 -0
 - {titans_pytorch-0.0.19 → titans_pytorch-0.0.21}/LICENSE +0 -0
 - {titans_pytorch-0.0.19 → titans_pytorch-0.0.21}/README.md +0 -0
 - {titans_pytorch-0.0.19 → titans_pytorch-0.0.21}/data/README.md +0 -0
 - {titans_pytorch-0.0.19 → titans_pytorch-0.0.21}/data/enwik8.gz +0 -0
 - {titans_pytorch-0.0.19 → titans_pytorch-0.0.21}/fig1.png +0 -0
 - {titans_pytorch-0.0.19 → titans_pytorch-0.0.21}/fig2.png +0 -0
 - {titans_pytorch-0.0.19 → titans_pytorch-0.0.21}/requirements.txt +0 -0
 - {titans_pytorch-0.0.19 → titans_pytorch-0.0.21}/tests/test_titans.py +0 -0
 - {titans_pytorch-0.0.19 → titans_pytorch-0.0.21}/titans_pytorch/__init__.py +0 -0
 - {titans_pytorch-0.0.19 → titans_pytorch-0.0.21}/titans_pytorch/associative_scan.py +0 -0
 - {titans_pytorch-0.0.19 → titans_pytorch-0.0.21}/titans_pytorch/titans_attn_memory.py +0 -0
 
{titans_pytorch-0.0.19 → titans_pytorch-0.0.21}/titans_pytorch/titans.py

@@ -40,6 +40,9 @@ def exists(v):
 def default(v, d):
     return v if exists(v) else d
 
+def identity(t):
+    return t
+
 def round_down_multiple(seq, mult):
     return seq // mult * mult
 
@@ -55,6 +58,21 @@ def pack_one_with_inverse(t, pattern):
 
     return packed, inverse
 
+# softclamping gradients
+
+def softclamp_max(t, max_value):
+    half_max_value = max_value / 2
+    return ((t / half_max_value).tanh() * half_max_value) + half_max_value
+
+def softclamp_grad_norm(t, max_value):
+    t, inverse = pack_one_with_inverse(t, 'bn *')
+
+    norm = t.norm(dim = -1, keepdim = True)
+    clamped_norm = softclamp_max(norm, max_value)
+
+    t = t * (clamped_norm / norm)
+    return inverse(t)
+
 # classes
 
 class MemoryMLP(Module):
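
The new softclamp_max maps any real input smoothly into the open interval (0, max_value), which lets softclamp_grad_norm rescale a gradient's norm without a hard cutoff. A minimal standalone sketch of that mapping, using plain torch rather than the package helpers:

    import torch

    def softclamp_max(t, max_value):
        # same formula as the helper above: tanh saturates at +/-1,
        # so the output always lands strictly between 0 and max_value
        half = max_value / 2
        return (t / half).tanh() * half + half

    norms = torch.tensor([0.1, 1., 10., 1000.])
    print(softclamp_max(norms, max_value = 10.))
    # ≈ tensor([5.1000, 5.9869, 9.8201, 10.0000])  (large values saturate just below 10)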
@@ -82,6 +100,9 @@ class MemoryMLP(Module):
 
 # main neural memory
 
+def default_adaptive_step_transform(adaptive_step):
+    return torch.exp(adaptive_step.sigmoid() * -15) # from 1. - 1e-7
+
 def default_loss_fn(pred, target):
     return (pred - target).pow(2).mean(dim = -1).sum()
 
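
The new default_adaptive_step_transform turns a raw per-token projection into a step size: the sigmoid lies in (0, 1), so the result lies in (exp(-15), 1), i.e. roughly 3e-7 up to 1, as the inline comment notes. A quick standalone check, assuming only torch:

    import torch

    def default_adaptive_step_transform(adaptive_step):
        # sigmoid in (0, 1) -> exponent in (-15, 0) -> step in (exp(-15), 1)
        return torch.exp(adaptive_step.sigmoid() * -15)

    raw = torch.tensor([-10., 0., 10.])          # unconstrained network outputs
    print(default_adaptive_step_transform(raw))  # ≈ tensor([9.99e-01, 5.53e-04, 3.06e-07])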
@@ -94,8 +115,10 @@ class NeuralMemory(Module):
         heads = 1,
         model: Module | None = None,
         store_memory_loss_fn: Callable = default_loss_fn,
+        adaptive_step_transform: Callable = default_adaptive_step_transform,
         pre_rmsnorm = True,
         post_rmsnorm = True,
+        max_grad_norm: float | None = None,
         use_accelerated_scan = False,
         default_mlp_kwargs: dict = dict(
             depth = 4
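
Both new keyword arguments have defaults (adaptive_step_transform falls back to the function above; a max_grad_norm of None disables clamping), so they can be adopted incrementally. A hedged construction sketch, assuming the package-level NeuralMemory import; the override values are illustrative only:

    from titans_pytorch import NeuralMemory

    mem = NeuralMemory(
        dim = 384,
        chunk_size = 4,
        max_grad_norm = 1.,                                     # new in 0.0.21: softclamp per-parameter gradient norms
        adaptive_step_transform = lambda t: t.sigmoid() * 1e-2  # new in 0.0.21: custom step mapping (hypothetical choice)
    )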
@@ -172,6 +195,12 @@ class NeuralMemory(Module):
             Rearrange('b n h -> (b h) n')
         )
 
+        self.adaptive_step_transform = adaptive_step_transform
+
+        # allow for softclamp the gradient norms for storing memories
+
+        self.max_grad_norm = max_grad_norm
+
         # weight decay factor
 
         self.to_decay_factor = nn.Sequential(
@@ -222,7 +251,8 @@ class NeuralMemory(Module):
 
         # pack batch and sequence dimension
 
-        adaptive_lr = 
+        adaptive_lr = self.to_adaptive_step(seq)
+        adaptive_lr = self.adaptive_step_transform(adaptive_lr)
 
         adaptive_momentum = self.to_momentum(seq).sigmoid()
         decay_factor = self.to_decay_factor(seq).sigmoid()
@@ -247,6 +277,11 @@ class NeuralMemory(Module):
 
         grads = TensorDict(grads)
 
+        # maybe softclamp grad norm
+
+        if exists(self.max_grad_norm):
+            grads = grads.apply(lambda t: softclamp_grad_norm(t, self.max_grad_norm))
+
         # restore batch and sequence dimension
 
         grads = grads.apply(lambda t: rearrange(t, '(b n) ... -> b n ...', b = batch))
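
With max_grad_norm set, each per-parameter gradient tensor in the TensorDict has its norm softly capped before it is used to update the memory. The sketch below reproduces the effect, with a plain reshape standing in for the pack_one_with_inverse(t, 'bn *') packing (an assumption about layout, for illustration only):

    import torch

    def softclamp_max(t, max_value):
        half = max_value / 2
        return (t / half).tanh() * half + half

    def softclamp_grad_norm(t, max_value):
        # flatten everything after the leading dim, as the 'bn *' pattern does
        flat = t.reshape(t.shape[0], -1)
        norm = flat.norm(dim = -1, keepdim = True)
        flat = flat * (softclamp_max(norm, max_value) / norm)
        return flat.reshape(t.shape)

    g = torch.randn(8, 64, 64) * 100.                # exaggerated, exploding gradients
    clamped = softclamp_grad_norm(g, max_value = 10.)
    print(clamped.reshape(8, -1).norm(dim = -1))     # every norm now sits just below 10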
{titans_pytorch-0.0.19 → titans_pytorch-0.0.21}/train.py

@@ -62,11 +62,11 @@ def decode_tokens(tokens):
 
 titans_neural_memory = NeuralMemory(
     dim = 384,
-    chunk_size = 
+    chunk_size = 4,
     pre_rmsnorm = True,
     post_rmsnorm = True,
-    dim_head = 
-    heads = 
+    dim_head = 64,
+    heads = 4,
     use_accelerated_scan = True,
     default_mlp_kwargs = dict(
         depth = NEURAL_MEMORY_DEPTH
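
train.py pins the NeuralMemory hyperparameters to chunk_size = 4, dim_head = 64 and heads = 4. A hedged smoke-test sketch in the spirit of the package README: the (batch, seq_len, dim) input shape and the plain retrieved = mem(seq) call are assumptions, not shown in this diff; see tests/test_titans.py for the authoritative usage:

    import torch
    from titans_pytorch import NeuralMemory

    mem = NeuralMemory(
        dim = 384,
        chunk_size = 4,
        dim_head = 64,
        heads = 4
    )

    seq = torch.randn(2, 64, 384)    # (batch, seq_len, dim); seq_len divisible by chunk_size
    retrieved = mem(seq)             # assumed forward usage
    assert retrieved.shape == seq.shape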
         