titans-pytorch 0.0.36.tar.gz → 0.0.38.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (20)
  1. {titans_pytorch-0.0.36 → titans_pytorch-0.0.38}/PKG-INFO +1 -1
  2. {titans_pytorch-0.0.36 → titans_pytorch-0.0.38}/pyproject.toml +1 -1
  3. {titans_pytorch-0.0.36 → titans_pytorch-0.0.38}/titans_pytorch/mac_transformer.py +10 -18
  4. {titans_pytorch-0.0.36 → titans_pytorch-0.0.38}/titans_pytorch/titans.py +9 -6
  5. {titans_pytorch-0.0.36 → titans_pytorch-0.0.38}/train_mac.py +5 -5
  6. {titans_pytorch-0.0.36 → titans_pytorch-0.0.38}/.github/workflows/python-publish.yml +0 -0
  7. {titans_pytorch-0.0.36 → titans_pytorch-0.0.38}/.github/workflows/test.yaml +0 -0
  8. {titans_pytorch-0.0.36 → titans_pytorch-0.0.38}/.gitignore +0 -0
  9. {titans_pytorch-0.0.36 → titans_pytorch-0.0.38}/LICENSE +0 -0
  10. {titans_pytorch-0.0.36 → titans_pytorch-0.0.38}/README.md +0 -0
  11. {titans_pytorch-0.0.36 → titans_pytorch-0.0.38}/data/README.md +0 -0
  12. {titans_pytorch-0.0.36 → titans_pytorch-0.0.38}/data/enwik8.gz +0 -0
  13. {titans_pytorch-0.0.36 → titans_pytorch-0.0.38}/fig1.png +0 -0
  14. {titans_pytorch-0.0.36 → titans_pytorch-0.0.38}/fig2.png +0 -0
  15. {titans_pytorch-0.0.36 → titans_pytorch-0.0.38}/requirements.txt +0 -0
  16. {titans_pytorch-0.0.36 → titans_pytorch-0.0.38}/tests/test_titans.py +0 -0
  17. {titans_pytorch-0.0.36 → titans_pytorch-0.0.38}/titans_pytorch/__init__.py +0 -0
  18. {titans_pytorch-0.0.36 → titans_pytorch-0.0.38}/titans_pytorch/associative_scan.py +0 -0
  19. {titans_pytorch-0.0.36 → titans_pytorch-0.0.38}/titans_pytorch/titans_attn_memory.py +0 -0
  20. {titans_pytorch-0.0.36 → titans_pytorch-0.0.38}/train.py +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.0.36
+Version: 0.0.38
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch

pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "titans-pytorch"
-version = "0.0.36"
+version = "0.0.38"
 description = "Titans"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }

titans_pytorch/mac_transformer.py
@@ -7,7 +7,7 @@ from torch import nn, cat
 import torch.nn.functional as F
 from torch.nn import Module, ModuleList, Linear

-from einops import repeat, rearrange
+from einops import repeat, rearrange, pack, unpack
 from einops.layers.torch import Rearrange

 from hyper_connections import get_init_and_expand_reduce_stream_functions

@@ -214,7 +214,12 @@ class MemoryAsContextTransformer(Module):
             if layer in neural_memory_layers:
                 assert has_longterm_mems, '`num_longterm_mem_tokens` must be greater than 0'

-                mem = NeuralMemory(dim = dim, chunk_size = num_longterm_mem_tokens)
+                mem = NeuralMemory(
+                    dim = dim,
+                    chunk_size = num_longterm_mem_tokens + segment_len,
+                    **neural_memory_kwargs
+                )
+
                 mem = init_hyper_conn(dim = dim, branch = mem)

             self.neural_mem_layers.append(mem)

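Note on the hunk above: the neural memory's chunk size now covers a full attention window plus its long-term memory tokens, matching the pack-based token layout introduced in the hunks below. A rough sketch of the arithmetic, assuming the transformer's segment_len is set from WINDOW_SIZE as configured in train_mac.py (that wiring is not shown in this diff):

    segment_len = 32                # WINDOW_SIZE in train_mac.py
    num_longterm_mem_tokens = 4     # NUM_LONGTERM_MEM in train_mac.py
    chunk_size = num_longterm_mem_tokens + segment_len   # 36 tokens per neural-memory chunk
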
@@ -266,7 +271,7 @@ class MemoryAsContextTransformer(Module):
         x, inverse_segment = pad_and_segment_with_inverse(x, segment_len)

         mems = repeat(self.longterm_mems, 'n d -> b n d', b = x.shape[0])
-        x = cat((mems, x), dim = -2)
+        x, mem_ps = pack((x, mems), 'b * d')

         x = inverse_segment(x)

@@ -283,21 +288,8 @@ class MemoryAsContextTransformer(Module):
         for (attn, ff), maybe_neural_mem in zip(self.layers, self.neural_mem_layers):

             if exists(maybe_neural_mem):
-                batch_streams = x.shape[0]
-
-                x, inverse_segment = pad_and_segment_with_inverse(x, total_segment_len)
-
-                longterm_mems, x = x[:, :num_longterm_mem_tokens], x[:, num_longterm_mem_tokens:]
-
-                longterm_mems = rearrange(longterm_mems, '(b w) n d -> b (w n) d', b = batch_streams)
-
-                longterm_mems = maybe_neural_mem(longterm_mems)
-
-                longterm_mems = rearrange(longterm_mems, 'b (w n) d -> (b w) n d', n = num_longterm_mem_tokens)
-
-                x = cat((longterm_mems, x), dim = -2)
+                x = maybe_neural_mem(x)

-                x = inverse_segment(x)

             x = attn(x)

@@ -309,7 +301,7 @@ class MemoryAsContextTransformer(Module):

         x, inverse_segment = pad_and_segment_with_inverse(x, total_segment_len)

-        x = x[:, num_longterm_mem_tokens:]
+        x, _ = unpack(x, mem_ps, 'b * d')

         x = inverse_segment(x)

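Note on the two mac_transformer.py hunks above: the long-term memory tokens are now joined to each segment with einops pack, which records the split sizes so unpack can later recover just the token part, replacing the manual cat and slice. A minimal, self-contained sketch of that round trip (shapes are illustrative, not taken from the package):

    import torch
    from einops import pack, unpack

    x = torch.randn(2, 32, 512)      # (batch, segment tokens, dim)
    mems = torch.randn(2, 4, 512)    # (batch, long-term memory tokens, dim)

    # pack concatenates along the wildcard axis and records each part's length
    packed, ps = pack((x, mems), 'b * d')          # packed: (2, 36, 512)

    # unpack splits that axis back apart using the recorded lengths
    x_out, mems_out = unpack(packed, ps, 'b * d')
    assert x_out.shape == (2, 32, 512) and mems_out.shape == (2, 4, 512)
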
titans_pytorch/titans.py
@@ -27,9 +27,7 @@ n - sequence
 d - feature dimension
 c - intra-chunk
 """
-
-# constants
-
+
 LinearNoBias = partial(Linear, bias = False)

 # functions

@@ -132,7 +130,7 @@ class NeuralMemory(Module):
         max_grad_norm: float | None = None,
         use_accelerated_scan = False,
         default_mlp_kwargs: dict = dict(
-            depth = 4
+            depth = 2
         )
     ):
         super().__init__()

@@ -390,7 +388,10 @@ class NeuralMemory(Module):

         padding = next_seq_len - curtailed_seq_len

-        seq = pad_at_dim(seq, (0, padding), dim = 1)
+        needs_pad = padding > 0
+
+        if needs_pad:
+            seq = pad_at_dim(seq, (0, padding), dim = 1)

         # the parameters of the memory model stores the memories of the key / values
         # when the MLP has only 1 weight matrix, it is equivalent to `kv` fast weight memories from linear attention literature (recall fetching of memories is q @ (kv)) / schmidhuber's paper

@@ -442,7 +443,9 @@ class NeuralMemory(Module):
         empty_memory_embeds = self.init_empty_memory_embed(values.shape[0], chunk_size - 1)
         values = torch.cat((empty_memory_embeds, values), dim = -2)

-        values = values[:, :-padding]
+        if needs_pad:
+            values = values[:, :-padding]
+
         return values

     def forward(

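Note on the two padding hunks above: the needs_pad guard avoids a zero-padding edge case, since values[:, :-padding] with padding == 0 evaluates to values[:, :0] and would drop the entire sequence. A minimal sketch of the pitfall (shapes are illustrative):

    import torch

    values = torch.randn(1, 8, 4)
    padding = 0
    print(values[:, :-padding].shape)   # torch.Size([1, 0, 4]) -- everything sliced away
    print(values[:, :-2].shape)         # torch.Size([1, 6, 4]) -- a positive padding trims correctly
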
train_mac.py
@@ -24,13 +24,13 @@ SHOULD_GENERATE = False
 SEQ_LEN = 512

 PROJECT_NAME = 'titans-mac-transformer'
-WANDB_ONLINE = False # turn this on to pipe experiment to cloud
+WANDB_ONLINE = True # turn this on to pipe experiment to cloud
 NEURAL_MEMORY_DEPTH = 2
 NUM_PERSIST_MEM = 4
 NUM_LONGTERM_MEM = 4
-NEURAL_MEM_LAYERS = (2, 4, 6)
+NEURAL_MEM_LAYERS = (4,)
 WINDOW_SIZE = 32
-RUN_NAME = 'mac - 4 longterm mems, layers (2, 4, 6)'
+RUN_NAME = 'mac - 4 longterm mems, layers (4,)'

 # wandb experiment tracker

@@ -63,10 +63,10 @@ model = MemoryAsContextTransformer(
     num_longterm_mem_tokens = NUM_LONGTERM_MEM,
     neural_memory_layers = NEURAL_MEM_LAYERS,
     neural_memory_kwargs = dict(
+        dim_head = 64,
+        heads = 4,
         default_mlp_kwargs = dict(
             depth = NEURAL_MEMORY_DEPTH,
-            dim_head = 64,
-            heads = 4
         )
     )
 ).cuda()

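Note on the train_mac.py hunk above: dim_head and heads now sit at the top level of neural_memory_kwargs, which the mac_transformer.py change forwards directly into each NeuralMemory, while default_mlp_kwargs only configures the inner memory MLP. A hedged sketch of roughly what this produces per neural-memory layer, assuming NeuralMemory is importable from the package root and its forward returns a tensor of the same shape as its input (as in the project README); dimensions here are illustrative:

    import torch
    from titans_pytorch import NeuralMemory   # assumed package-root export

    mem = NeuralMemory(
        dim = 384,
        chunk_size = 36,                       # segment_len + num_longterm_mem_tokens, e.g. 32 + 4
        dim_head = 64,                         # from neural_memory_kwargs
        heads = 4,                             # from neural_memory_kwargs
        default_mlp_kwargs = dict(depth = 2)   # NEURAL_MEMORY_DEPTH
    )

    seq = torch.randn(1, 128, 384)
    retrieved = mem(seq)                       # same shape as seq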