titans-pytorch 0.0.34.tar.gz → 0.0.36.tar.gz

Files changed (20):
  1. {titans_pytorch-0.0.34 → titans_pytorch-0.0.36}/PKG-INFO +1 -1
  2. {titans_pytorch-0.0.34 → titans_pytorch-0.0.36}/pyproject.toml +1 -1
  3. {titans_pytorch-0.0.34 → titans_pytorch-0.0.36}/titans_pytorch/__init__.py +2 -0
  4. {titans_pytorch-0.0.34 → titans_pytorch-0.0.36}/titans_pytorch/mac_transformer.py +10 -2
  5. {titans_pytorch-0.0.34 → titans_pytorch-0.0.36}/train_mac.py +11 -7
  6. {titans_pytorch-0.0.34 → titans_pytorch-0.0.36}/.github/workflows/python-publish.yml +0 -0
  7. {titans_pytorch-0.0.34 → titans_pytorch-0.0.36}/.github/workflows/test.yaml +0 -0
  8. {titans_pytorch-0.0.34 → titans_pytorch-0.0.36}/.gitignore +0 -0
  9. {titans_pytorch-0.0.34 → titans_pytorch-0.0.36}/LICENSE +0 -0
  10. {titans_pytorch-0.0.34 → titans_pytorch-0.0.36}/README.md +0 -0
  11. {titans_pytorch-0.0.34 → titans_pytorch-0.0.36}/data/README.md +0 -0
  12. {titans_pytorch-0.0.34 → titans_pytorch-0.0.36}/data/enwik8.gz +0 -0
  13. {titans_pytorch-0.0.34 → titans_pytorch-0.0.36}/fig1.png +0 -0
  14. {titans_pytorch-0.0.34 → titans_pytorch-0.0.36}/fig2.png +0 -0
  15. {titans_pytorch-0.0.34 → titans_pytorch-0.0.36}/requirements.txt +0 -0
  16. {titans_pytorch-0.0.34 → titans_pytorch-0.0.36}/tests/test_titans.py +0 -0
  17. {titans_pytorch-0.0.34 → titans_pytorch-0.0.36}/titans_pytorch/associative_scan.py +0 -0
  18. {titans_pytorch-0.0.34 → titans_pytorch-0.0.36}/titans_pytorch/titans.py +0 -0
  19. {titans_pytorch-0.0.34 → titans_pytorch-0.0.36}/titans_pytorch/titans_attn_memory.py +0 -0
  20. {titans_pytorch-0.0.34 → titans_pytorch-0.0.36}/train.py +0 -0
--- titans_pytorch-0.0.34/PKG-INFO
+++ titans_pytorch-0.0.36/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.0.34
+Version: 0.0.36
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch

--- titans_pytorch-0.0.34/pyproject.toml
+++ titans_pytorch-0.0.36/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "titans-pytorch"
-version = "0.0.34"
+version = "0.0.36"
 description = "Titans"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }

--- titans_pytorch-0.0.34/titans_pytorch/__init__.py
+++ titans_pytorch-0.0.36/titans_pytorch/__init__.py
@@ -2,3 +2,5 @@ from titans_pytorch.titans import (
     NeuralMemory,
     MemoryMLP,
 )
+
+from titans_pytorch.mac_transformer import MemoryAsContextTransformer
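
The `__init__.py` hunk re-exports `MemoryAsContextTransformer` at the package top level. A minimal sketch of what this enables, assuming titans-pytorch >= 0.0.36 is installed (the old submodule path keeps working):

# before 0.0.36, the class was only importable from its submodule
from titans_pytorch.mac_transformer import MemoryAsContextTransformer

# from 0.0.36 on, it is re-exported alongside the existing names
from titans_pytorch import NeuralMemory, MemoryMLP, MemoryAsContextTransformer
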
--- titans_pytorch-0.0.34/titans_pytorch/mac_transformer.py
+++ titans_pytorch-0.0.36/titans_pytorch/mac_transformer.py
@@ -185,7 +185,9 @@ class MemoryAsContextTransformer(Module):
         # long term mem tokens

         self.segment_len = segment_len
+
         self.num_longterm_mem_tokens = num_longterm_mem_tokens
+        has_longterm_mems = num_longterm_mem_tokens > 0

         self.longterm_mems = nn.Parameter(torch.randn(num_longterm_mem_tokens, dim) * 0.02)

@@ -197,7 +199,11 @@ class MemoryAsContextTransformer(Module):
         self.neural_mem_layers = ModuleList([])

         layers = tuple(range(1, depth + 1))
-        neural_memory_layers = set(default(neural_memory_layers, layers))
+
+        if not exists(neural_memory_layers):
+            neural_memory_layers = layers if has_longterm_mems else ()
+
+        assert not (num_longterm_mem_tokens > 0 and len(neural_memory_layers) == 0), 'empty `neural_memory_layers` when longterm memory tokens are present'

         for layer in layers:

@@ -205,7 +211,9 @@ class MemoryAsContextTransformer(Module):

             mem = None

-            if num_longterm_mem_tokens > 0 and layer in neural_memory_layers:
+            if layer in neural_memory_layers:
+                assert has_longterm_mems, '`num_longterm_mem_tokens` must be greater than 0'
+
                 mem = NeuralMemory(dim = dim, chunk_size = num_longterm_mem_tokens)
                 mem = init_hyper_conn(dim = dim, branch = mem)

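Taken together, the three `mac_transformer.py` hunks change how `neural_memory_layers` is resolved: when the argument is omitted, neural memory is attached to every layer only if long-term memory tokens are in use, and naming memory layers without any long-term memory tokens now fails fast instead of being silently skipped. A standalone sketch of that resolution rule, restated outside the class (the `exists` helper mirrors the repository's own):

def exists(v):
    return v is not None

def resolve_neural_memory_layers(neural_memory_layers, depth, num_longterm_mem_tokens):
    layers = tuple(range(1, depth + 1))
    has_longterm_mems = num_longterm_mem_tokens > 0

    # default: neural memory on every layer, but only when
    # long-term memory tokens are actually present
    if not exists(neural_memory_layers):
        neural_memory_layers = layers if has_longterm_mems else ()

    # long-term memory tokens with an explicitly empty layer tuple is an error
    assert not (has_longterm_mems and len(neural_memory_layers) == 0), \
        'empty `neural_memory_layers` when longterm memory tokens are present'

    return neural_memory_layers

assert resolve_neural_memory_layers(None, 8, 0) == ()                  # no mem tokens, no memory layers
assert resolve_neural_memory_layers(None, 8, 4) == tuple(range(1, 9))  # default: every layer
assert resolve_neural_memory_layers((2, 4, 6), 8, 4) == (2, 4, 6)      # explicit value passes through
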
--- titans_pytorch-0.0.34/train_mac.py
+++ titans_pytorch-0.0.36/train_mac.py
@@ -25,10 +25,12 @@ SEQ_LEN = 512

 PROJECT_NAME = 'titans-mac-transformer'
 WANDB_ONLINE = False # turn this on to pipe experiment to cloud
-GLOBAL_LAYERS = (2, 4)
 NEURAL_MEMORY_DEPTH = 2
-WINDOW_SIZE = 64
-RUN_NAME = 'mac'
+NUM_PERSIST_MEM = 4
+NUM_LONGTERM_MEM = 4
+NEURAL_MEM_LAYERS = (2, 4, 6)
+WINDOW_SIZE = 32
+RUN_NAME = 'mac - 4 longterm mems, layers (2, 4, 6)'

 # wandb experiment tracker

@@ -57,12 +59,14 @@ model = MemoryAsContextTransformer(
     dim = 384,
     depth = 8,
     segment_len = WINDOW_SIZE,
-    num_persist_mem_tokens = 16,
-    num_longterm_mem_tokens = 16,
-    neural_memory_layers = (3, 4),
+    num_persist_mem_tokens = NUM_PERSIST_MEM,
+    num_longterm_mem_tokens = NUM_LONGTERM_MEM,
+    neural_memory_layers = NEURAL_MEM_LAYERS,
     neural_memory_kwargs = dict(
         default_mlp_kwargs = dict(
-            depth = NEURAL_MEMORY_DEPTH
+            depth = NEURAL_MEMORY_DEPTH,
+            dim_head = 64,
+            heads = 4
         )
     )
 ).cuda()
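
For reference, a minimal sketch of instantiating the model under the new training configuration. Everything mirrors the updated `train_mac.py` except `num_tokens` and the forward call, which do not appear in these hunks and are assumptions (a byte-level vocabulary and lucidrains' usual `return_loss` convention); `.cuda()` is dropped so the sketch runs on CPU:

import torch
from titans_pytorch import MemoryAsContextTransformer

model = MemoryAsContextTransformer(
    num_tokens = 256,                  # assumed byte-level vocab, not shown in this diff
    dim = 384,
    depth = 8,
    segment_len = 32,                  # WINDOW_SIZE
    num_persist_mem_tokens = 4,        # NUM_PERSIST_MEM
    num_longterm_mem_tokens = 4,       # NUM_LONGTERM_MEM
    neural_memory_layers = (2, 4, 6),  # NEURAL_MEM_LAYERS
    neural_memory_kwargs = dict(
        default_mlp_kwargs = dict(
            depth = 2,                 # NEURAL_MEMORY_DEPTH
            dim_head = 64,
            heads = 4
        )
    )
)

ids = torch.randint(0, 256, (1, 512))  # (batch, SEQ_LEN)
loss = model(ids, return_loss = True)  # assumed training API, not confirmed by this diff
loss.backward()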