titans-pytorch 0.0.36__tar.gz → 0.0.38__tar.gz


Note: this release of titans-pytorch has been flagged as potentially problematic.

Files changed (20)
  1. {titans_pytorch-0.0.36 → titans_pytorch-0.0.38}/PKG-INFO +1 -1
  2. {titans_pytorch-0.0.36 → titans_pytorch-0.0.38}/pyproject.toml +1 -1
  3. {titans_pytorch-0.0.36 → titans_pytorch-0.0.38}/titans_pytorch/mac_transformer.py +10 -18
  4. {titans_pytorch-0.0.36 → titans_pytorch-0.0.38}/titans_pytorch/titans.py +9 -6
  5. {titans_pytorch-0.0.36 → titans_pytorch-0.0.38}/train_mac.py +5 -5
  6. {titans_pytorch-0.0.36 → titans_pytorch-0.0.38}/.github/workflows/python-publish.yml +0 -0
  7. {titans_pytorch-0.0.36 → titans_pytorch-0.0.38}/.github/workflows/test.yaml +0 -0
  8. {titans_pytorch-0.0.36 → titans_pytorch-0.0.38}/.gitignore +0 -0
  9. {titans_pytorch-0.0.36 → titans_pytorch-0.0.38}/LICENSE +0 -0
  10. {titans_pytorch-0.0.36 → titans_pytorch-0.0.38}/README.md +0 -0
  11. {titans_pytorch-0.0.36 → titans_pytorch-0.0.38}/data/README.md +0 -0
  12. {titans_pytorch-0.0.36 → titans_pytorch-0.0.38}/data/enwik8.gz +0 -0
  13. {titans_pytorch-0.0.36 → titans_pytorch-0.0.38}/fig1.png +0 -0
  14. {titans_pytorch-0.0.36 → titans_pytorch-0.0.38}/fig2.png +0 -0
  15. {titans_pytorch-0.0.36 → titans_pytorch-0.0.38}/requirements.txt +0 -0
  16. {titans_pytorch-0.0.36 → titans_pytorch-0.0.38}/tests/test_titans.py +0 -0
  17. {titans_pytorch-0.0.36 → titans_pytorch-0.0.38}/titans_pytorch/__init__.py +0 -0
  18. {titans_pytorch-0.0.36 → titans_pytorch-0.0.38}/titans_pytorch/associative_scan.py +0 -0
  19. {titans_pytorch-0.0.36 → titans_pytorch-0.0.38}/titans_pytorch/titans_attn_memory.py +0 -0
  20. {titans_pytorch-0.0.36 → titans_pytorch-0.0.38}/train.py +0 -0
{titans_pytorch-0.0.36 → titans_pytorch-0.0.38}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.0.36
+Version: 0.0.38
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
{titans_pytorch-0.0.36 → titans_pytorch-0.0.38}/pyproject.toml

@@ -1,6 +1,6 @@
 [project]
 name = "titans-pytorch"
-version = "0.0.36"
+version = "0.0.38"
 description = "Titans"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
{titans_pytorch-0.0.36 → titans_pytorch-0.0.38}/titans_pytorch/mac_transformer.py

@@ -7,7 +7,7 @@ from torch import nn, cat
 import torch.nn.functional as F
 from torch.nn import Module, ModuleList, Linear
 
-from einops import repeat, rearrange
+from einops import repeat, rearrange, pack, unpack
 from einops.layers.torch import Rearrange
 
 from hyper_connections import get_init_and_expand_reduce_stream_functions
@@ -214,7 +214,12 @@ class MemoryAsContextTransformer(Module):
             if layer in neural_memory_layers:
                 assert has_longterm_mems, '`num_longterm_mem_tokens` must be greater than 0'
 
-                mem = NeuralMemory(dim = dim, chunk_size = num_longterm_mem_tokens)
+                mem = NeuralMemory(
+                    dim = dim,
+                    chunk_size = num_longterm_mem_tokens + segment_len,
+                    **neural_memory_kwargs
+                )
+
                 mem = init_hyper_conn(dim = dim, branch = mem)
 
             self.neural_mem_layers.append(mem)
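
The rebuilt constructor call above ties the neural memory's chunk size to a full windowed segment (attention window plus its long-term memory tokens) and forwards neural_memory_kwargs straight through. A minimal sketch with illustrative sizes, assuming NeuralMemory is importable from the package root:

    from titans_pytorch import NeuralMemory

    segment_len = 32               # attention window length (illustrative value)
    num_longterm_mem_tokens = 4    # long-term memory tokens per window (illustrative value)

    # chunk_size now spans a whole windowed segment rather than only the memory tokens
    mem = NeuralMemory(
        dim = 384,
        chunk_size = num_longterm_mem_tokens + segment_len   # 36 tokens per memory update chunk
    )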
@@ -266,7 +271,7 @@ class MemoryAsContextTransformer(Module):
         x, inverse_segment = pad_and_segment_with_inverse(x, segment_len)
 
         mems = repeat(self.longterm_mems, 'n d -> b n d', b = x.shape[0])
-        x = cat((mems, x), dim = -2)
+        x, mem_ps = pack((x, mems), 'b * d')
 
         x = inverse_segment(x)
 
@@ -283,21 +288,8 @@ class MemoryAsContextTransformer(Module):
         for (attn, ff), maybe_neural_mem in zip(self.layers, self.neural_mem_layers):
 
             if exists(maybe_neural_mem):
-                batch_streams = x.shape[0]
-
-                x, inverse_segment = pad_and_segment_with_inverse(x, total_segment_len)
-
-                longterm_mems, x = x[:, :num_longterm_mem_tokens], x[:, num_longterm_mem_tokens:]
-
-                longterm_mems = rearrange(longterm_mems, '(b w) n d -> b (w n) d', b = batch_streams)
-
-                longterm_mems = maybe_neural_mem(longterm_mems)
-
-                longterm_mems = rearrange(longterm_mems, 'b (w n) d -> (b w) n d', n = num_longterm_mem_tokens)
-
-                x = cat((longterm_mems, x), dim = -2)
+                x = maybe_neural_mem(x)
 
-                x = inverse_segment(x)
 
             x = attn(x)
 
@@ -309,7 +301,7 @@ class MemoryAsContextTransformer(Module):
 
         x, inverse_segment = pad_and_segment_with_inverse(x, total_segment_len)
 
-        x = x[:, num_longterm_mem_tokens:]
+        x, _ = unpack(x, mem_ps, 'b * d')
 
         x = inverse_segment(x)
 
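Together with the pack call earlier in forward, this replaces manual cat / slicing with einops pack and unpack: pack concatenates along the wildcard axis and records the split sizes, and unpack restores the original tensors afterwards. A small standalone round-trip illustration (shapes are made up for the example):

    import torch
    from einops import pack, unpack

    x = torch.randn(2, 32, 384)      # (batch, segment tokens, dim)
    mems = torch.randn(2, 4, 384)    # (batch, long-term memory tokens, dim)

    # pack concatenates along the '*' axis and records each tensor's packed shape
    packed, packed_shapes = pack((x, mems), 'b * d')    # packed has shape (2, 36, 384)

    # unpack splits the concatenated axis back using the recorded shapes
    x_out, mems_out = unpack(packed, packed_shapes, 'b * d')

    assert x_out.shape == x.shape and mems_out.shape == mems.shape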
{titans_pytorch-0.0.36 → titans_pytorch-0.0.38}/titans_pytorch/titans.py

@@ -27,9 +27,7 @@ n - sequence
 d - feature dimension
 c - intra-chunk
 """
-
-# constants
-
+
 LinearNoBias = partial(Linear, bias = False)
 
 # functions
@@ -132,7 +130,7 @@ class NeuralMemory(Module):
        max_grad_norm: float | None = None,
        use_accelerated_scan = False,
        default_mlp_kwargs: dict = dict(
-            depth = 4
+            depth = 2
        )
    ):
        super().__init__()
@@ -390,7 +388,10 @@
 
         padding = next_seq_len - curtailed_seq_len
 
-        seq = pad_at_dim(seq, (0, padding), dim = 1)
+        needs_pad = padding > 0
+
+        if needs_pad:
+            seq = pad_at_dim(seq, (0, padding), dim = 1)
 
         # the parameters of the memory model stores the memories of the key / values
         # when the MLP has only 1 weight matrix, it is equivalent to `kv` fast weight memories from linear attention literature (recall fetching of memories is q @ (kv)) / schmidhuber's paper
@@ -442,7 +443,9 @@
         empty_memory_embeds = self.init_empty_memory_embed(values.shape[0], chunk_size - 1)
         values = torch.cat((empty_memory_embeds, values), dim = -2)
 
-        values = values[:, :-padding]
+        if needs_pad:
+            values = values[:, :-padding]
+
         return values
 
     def forward(
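
This guard, together with the one added around pad_at_dim above, avoids a zero-padding pitfall: negative slicing with zero drops everything, since values[:, :-0] is values[:, :0]. A minimal illustration of the behavior the needs_pad check protects against:

    import torch

    values = torch.randn(1, 10, 384)
    padding = 0

    # with padding == 0 the unguarded slice silently returns an empty sequence
    assert values[:, :-padding].shape[1] == 0

    # guarding on padding > 0 leaves the sequence untouched when nothing was padded
    if padding > 0:
        values = values[:, :-padding]

    assert values.shape[1] == 10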
{titans_pytorch-0.0.36 → titans_pytorch-0.0.38}/train_mac.py

@@ -24,13 +24,13 @@ SHOULD_GENERATE = False
 SEQ_LEN = 512
 
 PROJECT_NAME = 'titans-mac-transformer'
-WANDB_ONLINE = False # turn this on to pipe experiment to cloud
+WANDB_ONLINE = True # turn this on to pipe experiment to cloud
 NEURAL_MEMORY_DEPTH = 2
 NUM_PERSIST_MEM = 4
 NUM_LONGTERM_MEM = 4
-NEURAL_MEM_LAYERS = (2, 4, 6)
+NEURAL_MEM_LAYERS = (4,)
 WINDOW_SIZE = 32
-RUN_NAME = 'mac - 4 longterm mems, layers (2, 4, 6)'
+RUN_NAME = 'mac - 4 longterm mems, layers (4,)'
 
 # wandb experiment tracker
 
@@ -63,10 +63,10 @@ model = MemoryAsContextTransformer(
     num_longterm_mem_tokens = NUM_LONGTERM_MEM,
     neural_memory_layers = NEURAL_MEM_LAYERS,
     neural_memory_kwargs = dict(
+        dim_head = 64,
+        heads = 4,
         default_mlp_kwargs = dict(
             depth = NEURAL_MEMORY_DEPTH,
-            dim_head = 64,
-            heads = 4
         )
     )
 ).cuda()
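
Because the transformer now forwards neural_memory_kwargs directly into each per-layer NeuralMemory(**neural_memory_kwargs) (see the mac_transformer.py change above), dim_head and heads are consumed by NeuralMemory itself, while default_mlp_kwargs configures only the memory MLP. A sketch of the updated dict layout, reusing the constants defined in this script:

    NEURAL_MEMORY_DEPTH = 2

    neural_memory_kwargs = dict(
        dim_head = 64,                        # per-head dimension of the neural memory
        heads = 4,                            # number of memory heads
        default_mlp_kwargs = dict(
            depth = NEURAL_MEMORY_DEPTH       # depth of the memory MLP only
        )
    )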