titans-pytorch 0.0.42__py3-none-any.whl → 0.0.43__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -3,4 +3,6 @@ from titans_pytorch.titans import (
3
3
  MemoryMLP,
4
4
  )
5
5
 
6
- from titans_pytorch.mac_transformer import MemoryAsContextTransformer
6
+ from titans_pytorch.mac_transformer import (
7
+ MemoryAsContextTransformer
8
+ )
titans_pytorch/titans.py CHANGED
@@ -425,11 +425,7 @@ class NeuralMemory(Module):
425
425
  next_seq_len = round_up_multiple(curtailed_seq_len, chunk_size)
426
426
 
427
427
  padding = next_seq_len - curtailed_seq_len
428
-
429
- needs_pad = padding > 0
430
-
431
- if needs_pad:
432
- seq = pad_at_dim(seq, (0, padding), dim = 1)
428
+ seq = pad_at_dim(seq, (0, padding), dim = 1)
433
429
 
434
430
  # the parameters of the memory model stores the memories of the key / values
435
431
  # when the MLP has only 1 weight matrix, it is equivalent to `kv` fast weight memories from linear attention literature (recall fetching of memories is q @ (kv)) / schmidhuber's paper
@@ -481,10 +477,7 @@ class NeuralMemory(Module):
481
477
  empty_memory_embeds = self.init_empty_memory_embed(values.shape[0], chunk_size - 1)
482
478
  values = torch.cat((empty_memory_embeds, values), dim = -2)
483
479
 
484
- if needs_pad:
485
- values = values[:, :-padding]
486
-
487
- return values
480
+ return values[:, :seq_len]
488
481
 
489
482
  def forward(
490
483
  self,
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: titans-pytorch
3
- Version: 0.0.42
3
+ Version: 0.0.43
4
4
  Summary: Titans
5
5
  Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
6
6
  Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -83,6 +83,26 @@ retrieved = mem(seq)
83
83
  assert seq.shape == retrieved.shape
84
84
  ```
85
85
 
86
86 + A transformer with the `MAC` (memory-as-context) configuration can be used as follows
87
+
88
+ ```python
89
+ import torch
90
+ from titans_pytorch import MemoryAsContextTransformer
91
+
92
+ transformer = MemoryAsContextTransformer(
93
+ num_tokens = 256,
94
+ dim = 256,
95
+ depth = 2,
96
+ segment_len = 128, # local attention window size
97
+ num_persist_mem_tokens = 4,
98
+ num_longterm_mem_tokens = 16,
99
+ )
100
+
101
+ token_ids = torch.randint(0, 256, (1, 1023))
102
+
103
+ logits = transformer(token_ids) # (1, 1023, 256)
104
+ ```
105
+
86
106
  ## Experiments
87
107
 
88
108
  ```bash
@@ -0,0 +1,8 @@
1
+ titans_pytorch/__init__.py,sha256=wnv_Cjdjqh_h5IqLkQ8xrTtA2K663ITEn-1JeeHofTo,150
2
+ titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
3
+ titans_pytorch/mac_transformer.py,sha256=kSdfWGWwEk6d0lbb0WLVKQwdmG8LAzDg36QZm7aIio0,9451
4
+ titans_pytorch/titans.py,sha256=qxQ8pZCz8GEDhKeJMEaeAEzH66GAGVBNaRdNam_-czg,15260
5
+ titans_pytorch-0.0.43.dist-info/METADATA,sha256=3Rlt_5CIeDUkYEK5tcLiWTseWv48gg4OH5vMoSVNS2w,4210
6
+ titans_pytorch-0.0.43.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
7
+ titans_pytorch-0.0.43.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
8
+ titans_pytorch-0.0.43.dist-info/RECORD,,
@@ -1,8 +0,0 @@
1
- titans_pytorch/__init__.py,sha256=F6pV8BamKCsbJFVo5x2hw69vzfJNLy54SwIKIueMdp4,142
2
- titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
3
- titans_pytorch/mac_transformer.py,sha256=kSdfWGWwEk6d0lbb0WLVKQwdmG8LAzDg36QZm7aIio0,9451
4
- titans_pytorch/titans.py,sha256=eA7D9aqfGbtmC2SgGAQnfEVYp5Uza9uebEyDpVpjNQc,15372
5
- titans_pytorch-0.0.42.dist-info/METADATA,sha256=4lZBFMZPuQRDQGdTK-TWheytxEaQZQv7bdXO7MLBrwI,3744
6
- titans_pytorch-0.0.42.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
7
- titans_pytorch-0.0.42.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
8
- titans_pytorch-0.0.42.dist-info/RECORD,,