titans-pytorch 0.0.35__tar.gz → 0.0.37__tar.gz

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (20)
  1. {titans_pytorch-0.0.35 → titans_pytorch-0.0.37}/PKG-INFO +1 -1
  2. {titans_pytorch-0.0.35 → titans_pytorch-0.0.37}/pyproject.toml +1 -1
  3. {titans_pytorch-0.0.35 → titans_pytorch-0.0.37}/titans_pytorch/mac_transformer.py +20 -21
  4. {titans_pytorch-0.0.35 → titans_pytorch-0.0.37}/titans_pytorch/titans.py +1 -1
  5. {titans_pytorch-0.0.35 → titans_pytorch-0.0.37}/train_mac.py +12 -8
  6. {titans_pytorch-0.0.35 → titans_pytorch-0.0.37}/.github/workflows/python-publish.yml +0 -0
  7. {titans_pytorch-0.0.35 → titans_pytorch-0.0.37}/.github/workflows/test.yaml +0 -0
  8. {titans_pytorch-0.0.35 → titans_pytorch-0.0.37}/.gitignore +0 -0
  9. {titans_pytorch-0.0.35 → titans_pytorch-0.0.37}/LICENSE +0 -0
  10. {titans_pytorch-0.0.35 → titans_pytorch-0.0.37}/README.md +0 -0
  11. {titans_pytorch-0.0.35 → titans_pytorch-0.0.37}/data/README.md +0 -0
  12. {titans_pytorch-0.0.35 → titans_pytorch-0.0.37}/data/enwik8.gz +0 -0
  13. {titans_pytorch-0.0.35 → titans_pytorch-0.0.37}/fig1.png +0 -0
  14. {titans_pytorch-0.0.35 → titans_pytorch-0.0.37}/fig2.png +0 -0
  15. {titans_pytorch-0.0.35 → titans_pytorch-0.0.37}/requirements.txt +0 -0
  16. {titans_pytorch-0.0.35 → titans_pytorch-0.0.37}/tests/test_titans.py +0 -0
  17. {titans_pytorch-0.0.35 → titans_pytorch-0.0.37}/titans_pytorch/__init__.py +0 -0
  18. {titans_pytorch-0.0.35 → titans_pytorch-0.0.37}/titans_pytorch/associative_scan.py +0 -0
  19. {titans_pytorch-0.0.35 → titans_pytorch-0.0.37}/titans_pytorch/titans_attn_memory.py +0 -0
  20. {titans_pytorch-0.0.35 → titans_pytorch-0.0.37}/train.py +0 -0
--- titans_pytorch-0.0.35/PKG-INFO
+++ titans_pytorch-0.0.37/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.0.35
+Version: 0.0.37
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
--- titans_pytorch-0.0.35/pyproject.toml
+++ titans_pytorch-0.0.37/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "titans-pytorch"
-version = "0.0.35"
+version = "0.0.37"
 description = "Titans"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
--- titans_pytorch-0.0.35/titans_pytorch/mac_transformer.py
+++ titans_pytorch-0.0.37/titans_pytorch/mac_transformer.py
@@ -7,7 +7,7 @@ from torch import nn, cat
 import torch.nn.functional as F
 from torch.nn import Module, ModuleList, Linear
 
-from einops import repeat, rearrange
+from einops import repeat, rearrange, pack, unpack
 from einops.layers.torch import Rearrange
 
 from hyper_connections import get_init_and_expand_reduce_stream_functions
@@ -185,7 +185,9 @@ class MemoryAsContextTransformer(Module):
         # long term mem tokens
 
         self.segment_len = segment_len
+
         self.num_longterm_mem_tokens = num_longterm_mem_tokens
+        has_longterm_mems = num_longterm_mem_tokens > 0
 
         self.longterm_mems = nn.Parameter(torch.randn(num_longterm_mem_tokens, dim) * 0.02)
 
@@ -197,7 +199,11 @@ class MemoryAsContextTransformer(Module):
         self.neural_mem_layers = ModuleList([])
 
         layers = tuple(range(1, depth + 1))
-        neural_memory_layers = set(default(neural_memory_layers, layers))
+
+        if not exists(neural_memory_layers):
+            neural_memory_layers = layers if has_longterm_mems else ()
+
+        assert not (num_longterm_mem_tokens > 0 and len(neural_memory_layers) == 0), 'empty `neural_memory_layers` when longterm memory tokens are present'
 
         for layer in layers:
 
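In plain terms, the default for `neural_memory_layers` now depends on whether longterm memory tokens are configured. A minimal sketch of the new behaviour with made-up numbers (not part of the package):

    # illustrative only: how `neural_memory_layers` now defaults
    depth = 8
    num_longterm_mem_tokens = 4
    neural_memory_layers = None                 # user did not pass anything

    layers = tuple(range(1, depth + 1))         # (1, 2, ..., 8)

    if neural_memory_layers is None:
        # neural memory on every layer when longterm mem tokens exist, otherwise on none
        neural_memory_layers = layers if num_longterm_mem_tokens > 0 else ()

    assert not (num_longterm_mem_tokens > 0 and len(neural_memory_layers) == 0)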
@@ -205,8 +211,15 @@ class MemoryAsContextTransformer(Module):
 
             mem = None
 
-            if num_longterm_mem_tokens > 0 and layer in neural_memory_layers:
-                mem = NeuralMemory(dim = dim, chunk_size = num_longterm_mem_tokens)
+            if layer in neural_memory_layers:
+                assert has_longterm_mems, '`num_longterm_mem_tokens` must be greater than 0'
+
+                mem = NeuralMemory(
+                    dim = dim,
+                    chunk_size = num_longterm_mem_tokens + segment_len,
+                    **neural_memory_kwargs
+                )
+
                 mem = init_hyper_conn(dim = dim, branch = mem)
 
             self.neural_mem_layers.append(mem)
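The memory module for each selected layer now spans a full segment plus its memory tokens, and receives `neural_memory_kwargs`. A rough sketch using the values from `train_mac.py` below, assuming `NeuralMemory` is exported by the package and accepts the keyword arguments the training script forwards:

    from titans_pytorch import NeuralMemory

    # sketch only: chunk_size was previously just num_longterm_mem_tokens (e.g. 4);
    # it is now num_longterm_mem_tokens + segment_len (e.g. 4 + 32 = 36)
    mem = NeuralMemory(
        dim = 384,
        chunk_size = 4 + 32,
        dim_head = 64,                         # forwarded via **neural_memory_kwargs
        heads = 4,
        default_mlp_kwargs = dict(depth = 2)
    )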
@@ -258,7 +271,7 @@ class MemoryAsContextTransformer(Module):
         x, inverse_segment = pad_and_segment_with_inverse(x, segment_len)
 
         mems = repeat(self.longterm_mems, 'n d -> b n d', b = x.shape[0])
-        x = cat((mems, x), dim = -2)
+        x, mem_ps = pack((x, mems), 'b * d')
 
         x = inverse_segment(x)
 
@@ -275,21 +288,7 @@ class MemoryAsContextTransformer(Module):
         for (attn, ff), maybe_neural_mem in zip(self.layers, self.neural_mem_layers):
 
             if exists(maybe_neural_mem):
-                batch_streams = x.shape[0]
-
-                x, inverse_segment = pad_and_segment_with_inverse(x, total_segment_len)
-
-                longterm_mems, x = x[:, :num_longterm_mem_tokens], x[:, num_longterm_mem_tokens:]
-
-                longterm_mems = rearrange(longterm_mems, '(b w) n d -> b (w n) d', b = batch_streams)
-
-                longterm_mems = maybe_neural_mem(longterm_mems)
-
-                longterm_mems = rearrange(longterm_mems, 'b (w n) d -> (b w) n d', n = num_longterm_mem_tokens)
-
-                x = cat((longterm_mems, x), dim = -2)
-
-                x = inverse_segment(x)
+                mems = maybe_neural_mem(mems)
 
             x = attn(x)
 
@@ -301,7 +300,7 @@ class MemoryAsContextTransformer(Module):
 
         x, inverse_segment = pad_and_segment_with_inverse(x, total_segment_len)
 
-        x = x[:, num_longterm_mem_tokens:]
+        x, mem = unpack(x, mem_ps, 'b * d')
 
         x = inverse_segment(x)
 
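The forward path now uses einops `pack`/`unpack` instead of `cat` plus manual slicing: `pack` concatenates the segment tokens and the memory tokens along the sequence axis and records the split sizes in `mem_ps`, and `unpack` later splits them back apart. A small self-contained sketch of that round trip (shapes are made up for illustration):

    import torch
    from einops import pack, unpack

    x    = torch.randn(2, 32, 384)   # segment tokens
    mems = torch.randn(2, 4, 384)    # longterm memory tokens

    packed, ps = pack((x, mems), 'b * d')          # concat along '*', remember split sizes
    assert packed.shape == (2, 36, 384)

    x_out, mems_out = unpack(packed, ps, 'b * d')  # invert the pack
    assert x_out.shape == (2, 32, 384) and mems_out.shape == (2, 4, 384)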
--- titans_pytorch-0.0.35/titans_pytorch/titans.py
+++ titans_pytorch-0.0.37/titans_pytorch/titans.py
@@ -132,7 +132,7 @@ class NeuralMemory(Module):
         max_grad_norm: float | None = None,
         use_accelerated_scan = False,
         default_mlp_kwargs: dict = dict(
-            depth = 4
+            depth = 2
         )
     ):
         super().__init__()
--- titans_pytorch-0.0.35/train_mac.py
+++ titans_pytorch-0.0.37/train_mac.py
@@ -24,11 +24,13 @@ SHOULD_GENERATE = False
 SEQ_LEN = 512
 
 PROJECT_NAME = 'titans-mac-transformer'
-WANDB_ONLINE = False # turn this on to pipe experiment to cloud
-GLOBAL_LAYERS = (2, 4)
+WANDB_ONLINE = True # turn this on to pipe experiment to cloud
 NEURAL_MEMORY_DEPTH = 2
-WINDOW_SIZE = 64
-RUN_NAME = 'mac'
+NUM_PERSIST_MEM = 4
+NUM_LONGTERM_MEM = 4
+NEURAL_MEM_LAYERS = (4,)
+WINDOW_SIZE = 32
+RUN_NAME = 'mac - 4 longterm mems, layers (4,)'
 
 # wandb experiment tracker
 
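With these settings, each window processed by the model covers WINDOW_SIZE + NUM_LONGTERM_MEM tokens, which lines up with the `chunk_size` computed in `mac_transformer.py` above. A quick check (illustrative arithmetic only, not code from the package):

    WINDOW_SIZE = 32
    NUM_LONGTERM_MEM = 4

    total_segment_len = WINDOW_SIZE + NUM_LONGTERM_MEM   # 36 tokens per window
    assert total_segment_len == 36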
@@ -57,12 +59,14 @@ model = MemoryAsContextTransformer(
     dim = 384,
     depth = 8,
     segment_len = WINDOW_SIZE,
-    num_persist_mem_tokens = 16,
-    num_longterm_mem_tokens = 16,
-    neural_memory_layers = (3, 4),
+    num_persist_mem_tokens = NUM_PERSIST_MEM,
+    num_longterm_mem_tokens = NUM_LONGTERM_MEM,
+    neural_memory_layers = NEURAL_MEM_LAYERS,
     neural_memory_kwargs = dict(
+        dim_head = 64,
+        heads = 4,
         default_mlp_kwargs = dict(
-            depth = NEURAL_MEMORY_DEPTH
+            depth = NEURAL_MEMORY_DEPTH,
         )
     )
 ).cuda()