titans-pytorch 0.0.20__py3-none-any.whl → 0.0.21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- titans_pytorch/titans.py +11 -1
 - {titans_pytorch-0.0.20.dist-info → titans_pytorch-0.0.21.dist-info}/METADATA +1 -1
 - titans_pytorch-0.0.21.dist-info/RECORD +8 -0
 - titans_pytorch-0.0.20.dist-info/RECORD +0 -8
 - {titans_pytorch-0.0.20.dist-info → titans_pytorch-0.0.21.dist-info}/WHEEL +0 -0
 - {titans_pytorch-0.0.20.dist-info → titans_pytorch-0.0.21.dist-info}/licenses/LICENSE +0 -0
 
    
        titans_pytorch/titans.py
    CHANGED
    
    | 
         @@ -40,6 +40,9 @@ def exists(v): 
     | 
|
| 
       40 
40 
     | 
    
         
             
            def default(v, d):
         
     | 
| 
       41 
41 
     | 
    
         
             
                return v if exists(v) else d
         
     | 
| 
       42 
42 
     | 
    
         | 
| 
      
 43 
     | 
    
         
            +
            def identity(t):
         
     | 
| 
      
 44 
     | 
    
         
            +
                return t
         
     | 
| 
      
 45 
     | 
    
         
            +
             
     | 
| 
       43 
46 
     | 
    
         
             
            def round_down_multiple(seq, mult):
         
     | 
| 
       44 
47 
     | 
    
         
             
                return seq // mult * mult
         
     | 
| 
       45 
48 
     | 
    
         | 
| 
         @@ -97,6 +100,9 @@ class MemoryMLP(Module): 
     | 
|
| 
       97 
100 
     | 
    
         | 
| 
       98 
101 
     | 
    
         
             
            # main neural memory
         
     | 
| 
       99 
102 
     | 
    
         | 
| 
      
 103 
     | 
    
         
            +
            def default_adaptive_step_transform(adaptive_step):
         
     | 
| 
      
 104 
     | 
    
         
            +
                return torch.exp(adaptive_step.sigmoid() * -15) # from 1. - 1e-7
         
     | 
| 
      
 105 
     | 
    
         
            +
             
     | 
| 
       100 
106 
     | 
    
         
             
            def default_loss_fn(pred, target):
         
     | 
| 
       101 
107 
     | 
    
         
             
                return (pred - target).pow(2).mean(dim = -1).sum()
         
     | 
| 
       102 
108 
     | 
    
         | 
| 
         @@ -109,6 +115,7 @@ class NeuralMemory(Module): 
     | 
|
| 
       109 
115 
     | 
    
         
             
                    heads = 1,
         
     | 
| 
       110 
116 
     | 
    
         
             
                    model: Module | None = None,
         
     | 
| 
       111 
117 
     | 
    
         
             
                    store_memory_loss_fn: Callable = default_loss_fn,
         
     | 
| 
      
 118 
     | 
    
         
            +
                    adaptive_step_transform: Callable = default_adaptive_step_transform,
         
     | 
| 
       112 
119 
     | 
    
         
             
                    pre_rmsnorm = True,
         
     | 
| 
       113 
120 
     | 
    
         
             
                    post_rmsnorm = True,
         
     | 
| 
       114 
121 
     | 
    
         
             
                    max_grad_norm: float | None = None,
         
     | 
| 
         @@ -188,6 +195,8 @@ class NeuralMemory(Module): 
     | 
|
| 
       188 
195 
     | 
    
         
             
                        Rearrange('b n h -> (b h) n')
         
     | 
| 
       189 
196 
     | 
    
         
             
                    )
         
     | 
| 
       190 
197 
     | 
    
         | 
| 
      
 198 
     | 
    
         
            +
                    self.adaptive_step_transform = adaptive_step_transform
         
     | 
| 
      
 199 
     | 
    
         
            +
             
     | 
| 
       191 
200 
     | 
    
         
             
                    # allow for softclamp the gradient norms for storing memories
         
     | 
| 
       192 
201 
     | 
    
         | 
| 
       193 
202 
     | 
    
         
             
                    self.max_grad_norm = max_grad_norm
         
     | 
| 
         @@ -242,7 +251,8 @@ class NeuralMemory(Module): 
     | 
|
| 
       242 
251 
     | 
    
         | 
| 
       243 
252 
     | 
    
         
             
                    # pack batch and sequence dimension
         
     | 
| 
       244 
253 
     | 
    
         | 
| 
       245 
     | 
    
         
            -
                    adaptive_lr =  
     | 
| 
      
 254 
     | 
    
         
            +
                    adaptive_lr = self.to_adaptive_step(seq)
         
     | 
| 
      
 255 
     | 
    
         
            +
                    adaptive_lr = self.adaptive_step_transform(adaptive_lr)
         
     | 
| 
       246 
256 
     | 
    
         | 
| 
       247 
257 
     | 
    
         
             
                    adaptive_momentum = self.to_momentum(seq).sigmoid()
         
     | 
| 
       248 
258 
     | 
    
         
             
                    decay_factor = self.to_decay_factor(seq).sigmoid()
         
     | 
| 
         @@ -0,0 +1,8 @@ 
     | 
|
| 
      
 1 
     | 
    
         
            +
            titans_pytorch/__init__.py,sha256=nB0873FZ_OyCda3qFeWTdpO4LbbrXDEvtAefVLzh6g0,71
         
     | 
| 
      
 2 
     | 
    
         
            +
            titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
         
     | 
| 
      
 3 
     | 
    
         
            +
            titans_pytorch/titans.py,sha256=rsgpGs-nnqvsvF6Mu2GfmfNlRgYfHAV9MCEmd1ohUUI,13687
         
     | 
| 
      
 4 
     | 
    
         
            +
            titans_pytorch/titans_attn_memory.py,sha256=Rwx_-riGeISBefZg5Kjic8jzmmRRys-u93D2Kgb7Mos,12691
         
     | 
| 
      
 5 
     | 
    
         
            +
            titans_pytorch-0.0.21.dist-info/METADATA,sha256=3z2LGviHUicQq4_ThlRasOUCDXIe0xX2aIhxtKLWI0Q,3811
         
     | 
| 
      
 6 
     | 
    
         
            +
            titans_pytorch-0.0.21.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
         
     | 
| 
      
 7 
     | 
    
         
            +
            titans_pytorch-0.0.21.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
         
     | 
| 
      
 8 
     | 
    
         
            +
            titans_pytorch-0.0.21.dist-info/RECORD,,
         
     | 
| 
         @@ -1,8 +0,0 @@ 
     | 
|
| 
       1 
     | 
    
         
            -
            titans_pytorch/__init__.py,sha256=nB0873FZ_OyCda3qFeWTdpO4LbbrXDEvtAefVLzh6g0,71
         
     | 
| 
       2 
     | 
    
         
            -
            titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
         
     | 
| 
       3 
     | 
    
         
            -
            titans_pytorch/titans.py,sha256=HpVjFy6jyzLGB_ilqjcYWGE-VtYmUrUwkXzmzqPrCXc,13370
         
     | 
| 
       4 
     | 
    
         
            -
            titans_pytorch/titans_attn_memory.py,sha256=Rwx_-riGeISBefZg5Kjic8jzmmRRys-u93D2Kgb7Mos,12691
         
     | 
| 
       5 
     | 
    
         
            -
            titans_pytorch-0.0.20.dist-info/METADATA,sha256=9qJWG-hwJ8IK9auhQV2XyEs54T0-LMvBAArF-iQ21IE,3811
         
     | 
| 
       6 
     | 
    
         
            -
            titans_pytorch-0.0.20.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
         
     | 
| 
       7 
     | 
    
         
            -
            titans_pytorch-0.0.20.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
         
     | 
| 
       8 
     | 
    
         
            -
            titans_pytorch-0.0.20.dist-info/RECORD,,
         
     | 
| 
         
            File without changes
         
     | 
| 
         
            File without changes
         
     |