titans-pytorch 0.2.6__tar.gz → 0.2.8__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.2.6
+Version: 0.2.8
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -27,7 +27,7 @@ License: MIT License
 OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 SOFTWARE.
 License-File: LICENSE
-Keywords: artificial intelligence,deep learning,linear attention,neural memory module,test time training
+Keywords: artificial intelligence,deep learning,linear attention,memory,test time training
 Classifier: Development Status :: 4 - Beta
 Classifier: Intended Audience :: Developers
 Classifier: License :: OSI Approved :: MIT License
@@ -1,6 +1,6 @@
 [project]
 name = "titans-pytorch"
-version = "0.2.6"
+version = "0.2.8"
 description = "Titans"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -11,9 +11,9 @@ license = { file = "LICENSE" }
 keywords = [
     'artificial intelligence',
     'deep learning',
-    'neural memory module',
     'test time training',
-    'linear attention'
+    'linear attention',
+    'memory',
 ]
 
 classifiers=[
@@ -24,9 +24,9 @@ def torch_default_dtype(dtype):
 
 # main test
 
-@pytest.mark.parametrize('seq_len', (32, 1024, 77))
+@pytest.mark.parametrize('seq_len', (32, 512, 77))
 @pytest.mark.parametrize('silu', (False, True))
-@pytest.mark.parametrize('attn_pool_chunks', (False, True))
+@pytest.mark.parametrize('chunk_size, attn_pool_chunks', ((64, True), (64, False), (1, False)))
 @pytest.mark.parametrize('momentum', (False, True))
 @pytest.mark.parametrize('qk_rmsnorm', (False, True))
 @pytest.mark.parametrize('max_grad_norm', (None, 2.))
@@ -35,6 +35,7 @@ def test_titans(
     seq_len,
     silu,
     attn_pool_chunks,
+    chunk_size,
     momentum,
     qk_rmsnorm,
     max_grad_norm,
@@ -42,7 +43,7 @@ def test_titans(
 ):
     mem = NeuralMemory(
         dim = 384,
-        chunk_size = 64,
+        chunk_size = chunk_size,
         activation = nn.SiLU() if silu else None,
         attn_pool_chunks = attn_pool_chunks,
         max_grad_norm = max_grad_norm,
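
For context, the new parametrization above also covers chunk_size = 1 (attention pooling is only ever paired with chunk_size = 64 in the test). Below is a minimal sketch of the per-token case, assuming the public import path and using only the constructor arguments visible in this hunk; the forward call and its return are assumptions, not taken from the diff.

import torch
from torch import nn
from titans_pytorch import NeuralMemory

# per-token updates, the case newly exercised by the (1, False) parametrization
mem = NeuralMemory(
    dim = 384,
    chunk_size = 1,
    activation = nn.SiLU(),
    attn_pool_chunks = False,   # the test only pairs attention pooling with chunk_size = 64
    max_grad_norm = 2.
)

seq = torch.randn(2, 512, 384)
retrieved = mem(seq)            # forward signature and return shape assumed, not shown in this diff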
@@ -3,11 +3,33 @@ from torch import nn, cat
 import torch.nn.functional as F
 from torch.nn import Module, ModuleList, Parameter, ParameterList
 
+from einops import rearrange
+
 # functions
 
 def l2norm(t):
     return F.normalize(t, dim = -1)
 
+# norms
+
+class LayerNorm(Module):
+    def __init__(
+        self,
+        dim
+    ):
+        super().__init__()
+
+        self.ln = nn.LayerNorm(dim, elementwise_affine = False)
+        self.gamma = Parameter(torch.zeros(dim))
+
+    def forward(self, x):
+        gamma = self.gamma
+
+        if gamma.ndim == 2:
+            gamma = rearrange(gamma, 'b d -> b 1 d')
+
+        return self.ln(x) * (gamma + 1.)
+
 # memory mlp proposed in TTT
 
 class MemoryMLP(Module):
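
The LayerNorm added above zero-initializes gamma and scales by (gamma + 1.), so at initialization it reduces to a plain non-affine LayerNorm; the ndim == 2 branch presumably handles gamma arriving with a leading batch dimension when the memory model's parameters are carried per sample. A standalone check of the init behavior (it mirrors the class in this hunk rather than importing it from the package):

import torch
from torch import nn

ln = nn.LayerNorm(384, elementwise_affine = False)
gamma = torch.zeros(384)            # Parameter(torch.zeros(dim)) in the diff

x = torch.randn(2, 16, 384)
out = ln(x) * (gamma + 1.)          # gamma starts at zero, so the scale starts at one

assert torch.allclose(out, ln(x))   # identical to the non-affine norm at init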
@@ -19,6 +41,8 @@ class MemoryMLP(Module):
         super().__init__()
         self.weights = ParameterList([Parameter(torch.randn(dim, dim)) for _ in range(depth)])
 
+        self.ln = LayerNorm(dim)
+
         for weight in self.weights:
             nn.init.xavier_uniform_(weight)
 
@@ -36,7 +60,7 @@ class MemoryMLP(Module):
 
             x = x @ weight
 
-        return x + residual
+        return self.ln(x) + residual
 
 # memory mlp, but with gated residual + final projection
 
@@ -60,6 +84,8 @@ class GatedResidualMemoryMLP(Module):
 
         self.final_proj = Parameter(torch.randn(dim, dim))
 
+        self.ln = LayerNorm(dim)
+
         for param in self.parameters():
             nn.init.xavier_uniform_(param)
 
@@ -83,7 +109,7 @@ class GatedResidualMemoryMLP(Module):
 
         out = x @ self.final_proj
 
-        return out + residual
+        return self.ln(out) + residual
 
 # memory mlp with factorized weights
 # so can tradeoff capacity for smaller chunk sizes
@@ -103,6 +129,8 @@ class FactorizedMemoryMLP(Module):
             ]) for _ in range(depth)
         ])
 
+        self.ln = LayerNorm(dim)
+
         for weight1, weight2 in self.weights:
             nn.init.xavier_uniform_(weight1)
             nn.init.xavier_uniform_(weight2)
@@ -121,7 +149,7 @@ class FactorizedMemoryMLP(Module):
 
             x = x @ weight1 @ weight2
 
-        return x + residual
+        return self.ln(x) + residual
 
 # improvised attention as memory module
 
@@ -144,6 +172,8 @@ class MemoryAttention(Module):
             nn.Parameter(torch.randn(dim_ff_hidden, dim)), # ff w2
         ])
 
+        self.ln = LayerNorm(dim)
+
         for weight in self.weights:
             nn.init.xavier_uniform_(weight)
 
@@ -168,4 +198,4 @@ class MemoryAttention(Module):
         h = F.gelu(x @ ffw1)
         ff_out = h @ ffw2
 
-        return attn_out + ff_out + residual
+        return self.ln(attn_out + ff_out) + residual
@@ -51,6 +51,9 @@ def default(*args):
             return arg
     return None
 
+def identity(t):
+    return t
+
 def xnor(x, y):
     return not (x ^ y)
 
@@ -64,9 +67,6 @@ def safe_cat(inputs, dim = -2):
 
     return cat(inputs, dim = dim)
 
-def identity(t):
-    return t
-
 def dict_get_shape(td):
     return {k: v.shape for k, v in td.items()}
 
@@ -454,14 +454,14 @@ class NeuralMemory(Module):
 
         weights = TensorDict(weights)
 
-        # allow for neural memory of a previous layer and the past to produce gradients that become the weights of the current one generating the surprise
-        # think this is necessary otherwise the memory model is static (unless if paper is misunderstood)
-        # improvise (or perhaps correcting to) a solution
+        # allow for neural memory of a previous layer to influence surprise of current layer
+
+        weights_for_surprise = weights
 
         if exists(prev_layer_updates):
             prev_layer_updates = TensorDict(prev_layer_updates)
 
-            weights = weights + prev_layer_updates
+            weights_for_surprise = weights_for_surprise + prev_layer_updates
 
         per_sample_grad_fn = self.per_sample_grad_fn_expanded_weights # the weights will now have a batch * chunk dimension
 
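
The rename above means the previous layer's updates are folded in only for the surprise (gradient) computation, while weights itself is left untouched for the steps that follow. A sketch of the distinction with plain dicts, assuming TensorDict addition combines matching keys elementwise; the concrete values are made up:

weights = {'w1': 1.0, 'w2': 2.0}
prev_layer_updates = {'w1': 0.1, 'w2': -0.2}

# only the copy used for the surprise computation absorbs the previous layer's updates
weights_for_surprise = {k: weights[k] + prev_layer_updates[k] for k in weights}

assert weights == {'w1': 1.0, 'w2': 2.0}   # unchanged, unlike before this refactor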
@@ -506,11 +506,11 @@ class NeuralMemory(Module):
         # flatten batch and time if surprise depends on previous layer memory model
 
         if exists(prev_layer_updates):
-            weights = weights.apply(lambda t: rearrange(t, 'b n ... -> (b n) ...'))
+            weights_for_surprise = weights_for_surprise.apply(lambda t: rearrange(t, 'b n ... -> (b n) ...'))
 
         # get grads and extra auxiliary loss (for backwarding through qkv projection in base neural memory module)
 
-        grads, aux_kv_recon_loss = per_sample_grad_fn(dict(weights), keys, adaptive_lr, values)
+        grads, aux_kv_recon_loss = per_sample_grad_fn(dict(weights_for_surprise), keys, adaptive_lr, values)
 
         grads = TensorDict(grads)
 
 
@@ -536,7 +536,15 @@ class NeuralMemory(Module):
 
         if not exists(past_state):
             empty_dict = {key: None for key in weights.keys()}
-            past_state = (empty_dict, empty_dict)
+
+            # minibatch_init_weight corresponds to W0 in figure 7 of TTT paper
+
+            minibatch_init_weight = weights
+
+            if dict_get_shape(weights) == self.init_weight_shape:
+                minibatch_init_weight = weights.apply(lambda t: repeat(t, '... -> b 1 (...)', b = batch * heads))
+
+            past_state = (minibatch_init_weight, empty_dict)
 
         past_last_update, past_last_momentum = past_state
 
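
A quick shape check of the broadcast used for minibatch_init_weight above, with an un-batched (dim, dim) memory weight and made-up batch/heads values; only the einops repeat pattern comes from the diff:

import torch
from einops import repeat

w = torch.randn(384, 384)        # one un-batched memory weight
batch, heads = 2, 4

w0 = repeat(w, '... -> b 1 (...)', b = batch * heads)

print(w0.shape)                  # torch.Size([8, 1, 147456]), i.e. (batch * heads, 1, dim * dim)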
@@ -734,7 +742,7 @@ class NeuralMemory(Module):
 
         # retrieve
 
-        retrieved = self.retrieve_memories(token, updates + weights, chunk_size = 1)
+        retrieved = self.retrieve_memories(token, weights, chunk_size = 1)
 
         # next state tuple
 
@@ -796,7 +804,7 @@ class NeuralMemory(Module):
 
         retrieved = self.retrieve_memories(
             seq,
-            mem_model_weights + updates,
+            mem_model_weights,
             chunk_size = chunk_size,
            prev_layer_updates = prev_layer_updates
        )
4 files without changes