titans-pytorch 0.0.32__py3-none-any.whl → 0.0.35__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- titans_pytorch/__init__.py +2 -0
- titans_pytorch/mac_transformer.py +25 -4
- {titans_pytorch-0.0.32.dist-info → titans_pytorch-0.0.35.dist-info}/METADATA +1 -1
- titans_pytorch-0.0.35.dist-info/RECORD +9 -0
- titans_pytorch-0.0.32.dist-info/RECORD +0 -9
- {titans_pytorch-0.0.32.dist-info → titans_pytorch-0.0.35.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.0.32.dist-info → titans_pytorch-0.0.35.dist-info}/licenses/LICENSE +0 -0
 
    
titans_pytorch/mac_transformer.py CHANGED

```diff
@@ -50,13 +50,20 @@ def pad_and_segment_with_inverse(seq, segment_len):
     next_seq_len_mult = round_up_multiple(seq_len, segment_len)
 
     padding = next_seq_len_mult - seq_len
-    seq = F.pad(seq, (0, 0, 0, padding))
+    needs_pad = padding > 0
+
+    if needs_pad:
+        seq = F.pad(seq, (0, 0, 0, padding))
 
     seq = rearrange(seq, 'b (w n) d -> (b w) n d', n = segment_len)
 
     def inverse(out):
         out = rearrange(out, '(b w) n d -> b (w n) d', b = batch)
-        return out[:, :-padding]
+
+        if needs_pad:
+            out = out[:, :-padding]
+
+        return out
 
     return seq, inverse
 
```
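The guard added here matters because of Python slice semantics: when `padding == 0`, `out[:, :-padding]` means `out[:, :0]`, which is empty, so the old `inverse` silently discarded the whole sequence whenever the input length was already a multiple of `segment_len`. A minimal sketch of the failure mode and the fix (standalone, assuming only `torch`):

```python
import torch

out = torch.randn(2, 8, 4)       # (batch, seq, dim); seq already a multiple of segment_len
padding = 0

print(out[:, :-padding].shape)   # torch.Size([2, 0, 4]) -- ':-0' slices everything away

# guarded version, as introduced in 0.0.35
needs_pad = padding > 0
if needs_pad:
    out = out[:, :-padding]
print(out.shape)                 # torch.Size([2, 8, 4]) -- sequence preserved
```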
```diff
@@ -226,7 +233,14 @@ class MemoryAsContextTransformer(Module):
 
         self.to_logits = LinearNoBias(dim, num_tokens)
 
-    def forward(self, x):
+    def forward(
+        self,
+        x,
+        return_loss = False
+    ):
+
+        if return_loss:
+            x, labels = x[:, :-1], x[:, 1:]
 
         # math
 
```
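With `return_loss = True`, the forward pass sets up standard next-token prediction: the input is the sequence without its final token and the labels are the sequence shifted left by one. A small illustration of the shift (shapes are illustrative, not taken from this diff):

```python
import torch

x = torch.randint(0, 256, (2, 9))  # (batch, seq) of token ids

inp, labels = x[:, :-1], x[:, 1:]  # predict token t+1 from tokens <= t
print(inp.shape, labels.shape)     # torch.Size([2, 8]) torch.Size([2, 8])
```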
```diff
@@ -262,6 +276,7 @@ class MemoryAsContextTransformer(Module):
 
             if exists(maybe_neural_mem):
                 batch_streams = x.shape[0]
+
                 x, inverse_segment = pad_and_segment_with_inverse(x, total_segment_len)
 
                 longterm_mems, x = x[:, :num_longterm_mem_tokens], x[:, num_longterm_mem_tokens:]
```
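For context, the call pattern around `pad_and_segment_with_inverse` is a round trip: the function returns the windowed tensor plus a closure that undoes the windowing and strips any padding. Below is a self-contained sketch assembled from the hunks above; `round_up_multiple` and the opening `batch, seq_len = ...` line are reconstructed assumptions, as their definitions are not part of this diff:

```python
import torch
import torch.nn.functional as F
from einops import rearrange

def round_up_multiple(n, mult):            # assumed helper, not shown in this diff
    return ((n + mult - 1) // mult) * mult

def pad_and_segment_with_inverse(seq, segment_len):
    batch, seq_len = seq.shape[:2]         # reconstructed; 'batch' is captured by the closure
    next_seq_len_mult = round_up_multiple(seq_len, segment_len)

    padding = next_seq_len_mult - seq_len
    needs_pad = padding > 0

    if needs_pad:
        seq = F.pad(seq, (0, 0, 0, padding))   # pad the sequence dim up to a multiple

    seq = rearrange(seq, 'b (w n) d -> (b w) n d', n = segment_len)

    def inverse(out):
        out = rearrange(out, '(b w) n d -> b (w n) d', b = batch)
        if needs_pad:
            out = out[:, :-padding]            # strip the padding added above
        return out

    return seq, inverse

x = torch.randn(2, 10, 64)                     # (batch, seq, dim)
segments, inverse_segment = pad_and_segment_with_inverse(x, segment_len = 4)
print(segments.shape)                          # torch.Size([6, 4, 64]): 10 -> 12, windows of 4
assert torch.equal(inverse_segment(segments), x)
```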
```diff
@@ -277,6 +292,7 @@ class MemoryAsContextTransformer(Module):
                 x = inverse_segment(x)
 
             x = attn(x)
+
             x = ff(x)
 
         x = self.reduce_streams(x)
```
```diff
@@ -293,4 +309,9 @@ class MemoryAsContextTransformer(Module):
 
         x = self.norm(x)
 
-        return self.to_logits(x)
+        logits = self.to_logits(x)
+
+        if not return_loss:
+            return logits
+
+        return F.cross_entropy(rearrange(logits, 'b n l -> b l n'), labels)
```
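The `rearrange` is needed because `F.cross_entropy` expects the class dimension in position 1, i.e. `(batch, num_classes, seq)`, while `to_logits` produces `(batch, seq, num_tokens)`. A self-contained sketch of that layout fix (tensor sizes are illustrative):

```python
import torch
import torch.nn.functional as F
from einops import rearrange

logits = torch.randn(2, 8, 256, requires_grad = True)  # (batch, seq, num_tokens)
labels = torch.randint(0, 256, (2, 8))                 # (batch, seq) target ids

# move the class dim next to batch: (batch, num_tokens, seq)
loss = F.cross_entropy(rearrange(logits, 'b n l -> b l n'), labels)
loss.backward()
print(loss.item())
```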
titans_pytorch-0.0.35.dist-info/RECORD ADDED

```diff
@@ -0,0 +1,9 @@
+titans_pytorch/__init__.py,sha256=F6pV8BamKCsbJFVo5x2hw69vzfJNLy54SwIKIueMdp4,142
+titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+titans_pytorch/mac_transformer.py,sha256=FGShQHD-dQQdQKKzvNS_jTC_FcikdqO_s3ZKOKfr_9E,8502
+titans_pytorch/titans.py,sha256=Kx_tl_QkeDccvkMwPZ0xQ_saYjZfbKzDNPTTSHNWYcA,14304
+titans_pytorch/titans_attn_memory.py,sha256=Rwx_-riGeISBefZg5Kjic8jzmmRRys-u93D2Kgb7Mos,12691
+titans_pytorch-0.0.35.dist-info/METADATA,sha256=jrhx-Bp1LqlOAV3jl4M70WTqq29ciz5lYWJvo2aoPE4,3938
+titans_pytorch-0.0.35.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.0.35.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.0.35.dist-info/RECORD,,
```
titans_pytorch-0.0.32.dist-info/RECORD DELETED

```diff
@@ -1,9 +0,0 @@
-titans_pytorch/__init__.py,sha256=nB0873FZ_OyCda3qFeWTdpO4LbbrXDEvtAefVLzh6g0,71
-titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/mac_transformer.py,sha256=ohZWVhMBtpm0Iz3w5g7pD3WXSXrvhwzZvfRplwhe1Qo,8149
-titans_pytorch/titans.py,sha256=Kx_tl_QkeDccvkMwPZ0xQ_saYjZfbKzDNPTTSHNWYcA,14304
-titans_pytorch/titans_attn_memory.py,sha256=Rwx_-riGeISBefZg5Kjic8jzmmRRys-u93D2Kgb7Mos,12691
-titans_pytorch-0.0.32.dist-info/METADATA,sha256=9X9nWfgIVS-9XIeLHQY53HXSMA6rMemPfyVC2bRrJOQ,3938
-titans_pytorch-0.0.32.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.0.32.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.0.32.dist-info/RECORD,,
```
{titans_pytorch-0.0.32.dist-info → titans_pytorch-0.0.35.dist-info}/WHEEL RENAMED

File without changes

{titans_pytorch-0.0.32.dist-info → titans_pytorch-0.0.35.dist-info}/licenses/LICENSE RENAMED

File without changes