titans-pytorch 0.0.37__tar.gz → 0.0.39__tar.gz

Sign up to get free protection for your applications and to get access to all the features.

Potentially problematic release.


This version of titans-pytorch might be problematic. Click here for more details.

Files changed (20) hide show
  1. {titans_pytorch-0.0.37 → titans_pytorch-0.0.39}/PKG-INFO +2 -1
  2. {titans_pytorch-0.0.37 → titans_pytorch-0.0.39}/pyproject.toml +2 -1
  3. {titans_pytorch-0.0.37 → titans_pytorch-0.0.39}/titans_pytorch/mac_transformer.py +10 -5
  4. {titans_pytorch-0.0.37 → titans_pytorch-0.0.39}/titans_pytorch/titans.py +8 -5
  5. {titans_pytorch-0.0.37 → titans_pytorch-0.0.39}/.github/workflows/python-publish.yml +0 -0
  6. {titans_pytorch-0.0.37 → titans_pytorch-0.0.39}/.github/workflows/test.yaml +0 -0
  7. {titans_pytorch-0.0.37 → titans_pytorch-0.0.39}/.gitignore +0 -0
  8. {titans_pytorch-0.0.37 → titans_pytorch-0.0.39}/LICENSE +0 -0
  9. {titans_pytorch-0.0.37 → titans_pytorch-0.0.39}/README.md +0 -0
  10. {titans_pytorch-0.0.37 → titans_pytorch-0.0.39}/data/README.md +0 -0
  11. {titans_pytorch-0.0.37 → titans_pytorch-0.0.39}/data/enwik8.gz +0 -0
  12. {titans_pytorch-0.0.37 → titans_pytorch-0.0.39}/fig1.png +0 -0
  13. {titans_pytorch-0.0.37 → titans_pytorch-0.0.39}/fig2.png +0 -0
  14. {titans_pytorch-0.0.37 → titans_pytorch-0.0.39}/requirements.txt +0 -0
  15. {titans_pytorch-0.0.37 → titans_pytorch-0.0.39}/tests/test_titans.py +0 -0
  16. {titans_pytorch-0.0.37 → titans_pytorch-0.0.39}/titans_pytorch/__init__.py +0 -0
  17. {titans_pytorch-0.0.37 → titans_pytorch-0.0.39}/titans_pytorch/associative_scan.py +0 -0
  18. {titans_pytorch-0.0.37 → titans_pytorch-0.0.39}/titans_pytorch/titans_attn_memory.py +0 -0
  19. {titans_pytorch-0.0.37 → titans_pytorch-0.0.39}/train.py +0 -0
  20. {titans_pytorch-0.0.37 → titans_pytorch-0.0.39}/train_mac.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: titans-pytorch
3
- Version: 0.0.37
3
+ Version: 0.0.39
4
4
  Summary: Titans
5
5
  Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
6
6
  Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -43,6 +43,7 @@ Requires-Dist: ninja
43
43
  Requires-Dist: rotary-embedding-torch
44
44
  Requires-Dist: tensordict
45
45
  Requires-Dist: torch>=2.2
46
+ Requires-Dist: x-transformers
46
47
  Provides-Extra: examples
47
48
  Requires-Dist: local-attention>=1.10.1; extra == 'examples'
48
49
  Requires-Dist: taylor-series-linear-attention; extra == 'examples'
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "titans-pytorch"
3
- version = "0.0.37"
3
+ version = "0.0.39"
4
4
  description = "Titans"
5
5
  authors = [
6
6
  { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -34,6 +34,7 @@ dependencies = [
34
34
  "rotary-embedding-torch",
35
35
  "tensordict",
36
36
  "torch>=2.2",
37
+ "x-transformers"
37
38
  ]
38
39
 
39
40
  [project.urls]
@@ -7,7 +7,7 @@ from torch import nn, cat
7
7
  import torch.nn.functional as F
8
8
  from torch.nn import Module, ModuleList, Linear
9
9
 
10
- from einops import repeat, rearrange, pack, unpack
10
+ from einops import einsum, repeat, rearrange, pack, unpack
11
11
  from einops.layers.torch import Rearrange
12
12
 
13
13
  from hyper_connections import get_init_and_expand_reduce_stream_functions
@@ -16,6 +16,7 @@ from hyper_connections import get_init_and_expand_reduce_stream_functions
16
16
 
17
17
  from axial_positional_embedding import ContinuousAxialPositionalEmbedding
18
18
  from rotary_embedding_torch import RotaryEmbedding
19
+ from x_transformers.attend import Attend
19
20
 
20
21
  # proposed neural memory
21
22
 
@@ -93,6 +94,7 @@ class SegmentedAttention(Module):
93
94
  num_longterm_mem_tokens = 0,
94
95
  dim_head = 64,
95
96
  heads = 8,
97
+ attend_kwargs: dict = dict()
96
98
  ):
97
99
  super().__init__()
98
100
  self.norm = nn.RMSNorm(dim)
@@ -101,6 +103,8 @@ class SegmentedAttention(Module):
101
103
 
102
104
  self.rotary_emb = RotaryEmbedding(dim_head)
103
105
 
106
+ self.attend = Attend(causal = True, **attend_kwargs)
107
+
104
108
  self.to_qkv = LinearNoBias(dim, dim_inner * 3)
105
109
  self.to_out = LinearNoBias(dim_inner, dim)
106
110
 
@@ -145,9 +149,9 @@ class SegmentedAttention(Module):
145
149
  k = cat((pmk, k), dim = -2)
146
150
  v = cat((pmv, v), dim = -2)
147
151
 
148
- # sdpa
152
+ # attention
149
153
 
150
- out = F.scaled_dot_product_attention(q, k, v, is_causal = True)
154
+ out, _ = self.attend(q, k, v)
151
155
 
152
156
  out = self.merge_heads(out)
153
157
 
@@ -288,7 +292,8 @@ class MemoryAsContextTransformer(Module):
288
292
  for (attn, ff), maybe_neural_mem in zip(self.layers, self.neural_mem_layers):
289
293
 
290
294
  if exists(maybe_neural_mem):
291
- mems = maybe_neural_mem(mems)
295
+ x = maybe_neural_mem(x)
296
+
292
297
 
293
298
  x = attn(x)
294
299
 
@@ -300,7 +305,7 @@ class MemoryAsContextTransformer(Module):
300
305
 
301
306
  x, inverse_segment = pad_and_segment_with_inverse(x, total_segment_len)
302
307
 
303
- x, mem = unpack(x, mem_ps, 'b * d')
308
+ x, _ = unpack(x, mem_ps, 'b * d')
304
309
 
305
310
  x = inverse_segment(x)
306
311
 
@@ -27,9 +27,7 @@ n - sequence
27
27
  d - feature dimension
28
28
  c - intra-chunk
29
29
  """
30
-
31
- # constants
32
-
30
+ 7
33
31
  LinearNoBias = partial(Linear, bias = False)
34
32
 
35
33
  # functions
@@ -390,7 +388,10 @@ class NeuralMemory(Module):
390
388
 
391
389
  padding = next_seq_len - curtailed_seq_len
392
390
 
393
- seq = pad_at_dim(seq, (0, padding), dim = 1)
391
+ needs_pad = padding > 0
392
+
393
+ if needs_pad:
394
+ seq = pad_at_dim(seq, (0, padding), dim = 1)
394
395
 
395
396
  # the parameters of the memory model stores the memories of the key / values
396
397
  # when the MLP has only 1 weight matrix, it is equivalent to `kv` fast weight memories from linear attention literature (recall fetching of memories is q @ (kv)) / schmidhuber's paper
@@ -442,7 +443,9 @@ class NeuralMemory(Module):
442
443
  empty_memory_embeds = self.init_empty_memory_embed(values.shape[0], chunk_size - 1)
443
444
  values = torch.cat((empty_memory_embeds, values), dim = -2)
444
445
 
445
- values = values[:, :-padding]
446
+ if needs_pad:
447
+ values = values[:, :-padding]
448
+
446
449
  return values
447
450
 
448
451
  def forward(
File without changes