titans-pytorch 0.0.26__tar.gz → 0.0.29__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: titans-pytorch
- Version: 0.0.26
+ Version: 0.0.29
  Summary: Titans
  Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
  Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -35,10 +35,12 @@ Classifier: Programming Language :: Python :: 3.9
  Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
  Requires-Python: >=3.9
  Requires-Dist: accelerated-scan>=0.2.0
+ Requires-Dist: axial-positional-embedding>=0.3.5
  Requires-Dist: einops>=0.8.0
  Requires-Dist: einx>=0.3.0
  Requires-Dist: hyper-connections>=0.1.8
  Requires-Dist: ninja
+ Requires-Dist: rotary-embedding-torch
  Requires-Dist: tensordict
  Requires-Dist: torch>=2.2
  Provides-Extra: examples
@@ -1,6 +1,6 @@
  [project]
  name = "titans-pytorch"
- version = "0.0.26"
+ version = "0.0.29"
  description = "Titans"
  authors = [
      { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -26,10 +26,12 @@ classifiers=[

  dependencies = [
      "accelerated-scan>=0.2.0",
+     "axial_positional_embedding>=0.3.5",
      "einx>=0.3.0",
      "einops>=0.8.0",
      "hyper-connections>=0.1.8",
      "Ninja",
+     "rotary-embedding-torch",
      "tensordict",
      "torch>=2.2",
  ]
@@ -33,3 +33,19 @@ def test_titans_attn_memory():
      retrieved = mem(seq)

      assert seq.shape == retrieved.shape
+
+ def test_mac():
+     from titans_pytorch.mac_transformer import MemoryAsContextTransformer
+
+     transformer = MemoryAsContextTransformer(
+         num_tokens = 256,
+         dim = 256,
+         depth = 2,
+         num_persist_mem_tokens = 16,
+         segment_len = 128,
+     )
+
+     x = torch.randint(0, 256, (1, 1023))
+
+     logits = transformer(x)
+     assert logits.shape == (1, 1023, 256)
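The new test_mac test only checks that the forward pass returns logits of the expected shape. As a rough, hypothetical sketch (not part of the package or its test suite), those (batch, seq_len, num_tokens) logits would typically be paired with shifted targets for a next-token objective; the sketch below relies only on the constructor arguments and forward signature shown above.

# Hypothetical training-step sketch, not shipped with titans-pytorch 0.0.29.
import torch
import torch.nn.functional as F
from titans_pytorch.mac_transformer import MemoryAsContextTransformer

transformer = MemoryAsContextTransformer(
    num_tokens = 256,
    dim = 256,
    depth = 2,
    num_persist_mem_tokens = 16,
    segment_len = 128,
)

ids = torch.randint(0, 256, (1, 1024))
inp, target = ids[:, :-1], ids[:, 1:]          # shift by one token for next-token prediction

logits = transformer(inp)                      # (1, 1023, 256), as asserted in test_mac
loss = F.cross_entropy(logits.transpose(1, 2), target)
loss.backward()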
@@ -1,16 +1,23 @@
  from __future__ import annotations
- import math
+ from math import ceil
  from functools import partial

  import torch
- from torch import nn
+ from torch import nn, cat
  import torch.nn.functional as F
  from torch.nn import Module, ModuleList, Linear

+ from einops import repeat
  from einops.layers.torch import Rearrange

+
  from hyper_connections import get_init_and_expand_reduce_stream_functions

+ # absolute and relative positions
+
+ from axial_positional_embedding import ContinuousAxialPositionalEmbedding
+ from rotary_embedding_torch import RotaryEmbedding
+
  # constants

  LinearNoBias = partial(Linear, bias = False)
@@ -24,7 +31,7 @@ def default(v, d):
      return v if exists(v) else d

  def round_up_multiple(seq, mult):
-     return math.ceil(seq / mult) * mult
+     return ceil(seq / mult) * mult

  # feedforward and attention

@@ -48,6 +55,7 @@ class SegmentedAttention(Module):
          self,
          dim,
          segment_len,
+         num_persist_mem_tokens,
          dim_head = 64,
          heads = 8,
      ):
@@ -56,6 +64,8 @@ class SegmentedAttention(Module):

          dim_inner = dim_head * heads

+         self.rotary_emb = RotaryEmbedding(dim_head)
+
          self.to_qkv = LinearNoBias(dim, dim_inner * 3)
          self.to_out = LinearNoBias(dim_inner, dim)

@@ -67,6 +77,7 @@ class SegmentedAttention(Module):
          self.segment_seq = Rearrange('b (n w) d -> (b n) w d', n = segment_len)
          self.merge_seq_back = Rearrange('(b n) w d -> b (n w) d', n = segment_len)

+         self.persistent_memory = nn.Parameter(torch.zeros(2, heads, num_persist_mem_tokens, dim_head))

      def forward(self, seq):
          batch, seq_len = seq.shape[:2]
@@ -92,6 +103,21 @@ class SegmentedAttention(Module):
          q, k, v = self.to_qkv(seq).chunk(3, dim = -1)
          q, k, v = map(self.split_heads, (q, k, v))

+         # take care of persistent memory key / values
+
+         pmk, pmv = tuple(repeat(t, 'h n d -> b h n d', b = seq.shape[0]) for t in self.persistent_memory)
+
+         # relative positions
+
+         q, k = self.rotary_emb.rotate_queries_with_cached_keys(q, k)
+
+         # persistent memory
+
+         k = cat((pmk, k), dim = -2)
+         v = cat((pmv, v), dim = -2)
+
+         # sdpa
+
          out = F.scaled_dot_product_attention(q, k, v, is_causal = True)

          out = self.merge_heads(out)
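To make the attention changes easier to follow: the hunk above stores a (2, heads, num_persist_mem_tokens, dim_head) parameter, expands it across the batch, rotates queries and keys with rotary embeddings, and prepends the persistent-memory keys and values before scaled dot product attention. Below is a standalone, simplified sketch of just the persistent-memory prepend (rotary embedding and segmenting omitted); the shapes and variable names are illustrative only.

# Illustrative sketch of the persistent-memory pattern used above; not library code.
import torch
import torch.nn.functional as F
from einops import repeat

b, h, n, d = 2, 8, 128, 64                     # batch, heads, tokens per segment, dim_head
num_persist_mem_tokens = 16

q, k, v = (torch.randn(b, h, n, d) for _ in range(3))

# one learned set of memory keys / values, shared across the batch
persistent_memory = torch.nn.Parameter(torch.zeros(2, h, num_persist_mem_tokens, d))
pmk, pmv = tuple(repeat(t, 'h n d -> b h n d', b = b) for t in persistent_memory)

# prepend along the key / value length so every query can also attend to memory tokens
k = torch.cat((pmk, k), dim = -2)
v = torch.cat((pmv, v), dim = -2)

out = F.scaled_dot_product_attention(q, k, v, is_causal = True)
assert out.shape == (b, h, n, d)               # output length follows the queries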
@@ -113,6 +139,7 @@ class MemoryAsContextTransformer(Module):
          dim,
          depth,
          segment_len,
+         num_persist_mem_tokens,
          dim_head = 64,
          heads = 8,
          ff_mult = 4,
@@ -120,6 +147,9 @@
      ):
          super().__init__()

+         self.segment_len = segment_len
+         self.axial_pos_emb = ContinuousAxialPositionalEmbedding(dim = dim, num_axial_dims = 2)
+
          self.token_emb = nn.Embedding(num_tokens, dim)

          init_hyper_conn, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(num_residual_streams, disable = num_residual_streams == 1)
@@ -127,7 +157,14 @@ class MemoryAsContextTransformer(Module):
          self.layers = ModuleList([])

          for _ in range(depth):
-             attn = SegmentedAttention(dim = dim, dim_head = dim_head, heads = heads, segment_len = segment_len)
+             attn = SegmentedAttention(
+                 dim = dim,
+                 dim_head = dim_head,
+                 heads = heads,
+                 segment_len = segment_len,
+                 num_persist_mem_tokens = num_persist_mem_tokens
+             )
+
              ff = FeedForward(dim = dim, mult = ff_mult)

              self.layers.append(ModuleList([
@@ -140,9 +177,19 @@ class MemoryAsContextTransformer(Module):
          self.to_logits = LinearNoBias(dim, num_tokens)

      def forward(self, x):
+         seq_len, segment_len = x.shape[-1], self.segment_len
+         windows = ceil(seq_len / segment_len)

          x = self.token_emb(x)

+         # apply axial positional embedding
+         # so intra and inter segment can be more easily discerned by the network
+
+         pos_emb = self.axial_pos_emb((windows, segment_len), flatten = True)
+         x = x + pos_emb[:seq_len]
+
+         # expand and reduce streams for hyper connections
+
          x = self.expand_streams(x)

          for attn, ff in self.layers:
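On the forward changes: each absolute position is factored into a (window index, position within segment) pair, the axial positional embedding produces one vector per pair over a rounded-up (windows, segment_len) grid, and the flattened result is trimmed back to the true sequence length before being added to the token embeddings. Here is a rough pure-torch sketch of that idea, using two ordinary nn.Embedding tables rather than the ContinuousAxialPositionalEmbedding the package actually imports.

# Conceptual sketch only - approximates the axial position scheme above with
# plain embedding tables instead of axial-positional-embedding's continuous version.
from math import ceil
import torch
from torch import nn

dim, segment_len, seq_len = 256, 128, 1023
windows = ceil(seq_len / segment_len)          # 8 windows cover 1023 tokens

window_emb = nn.Embedding(windows, dim)        # which segment a token falls into
offset_emb = nn.Embedding(segment_len, dim)    # where it sits inside that segment

pos = torch.arange(windows * segment_len)      # positions on the rounded-up grid
pos_emb = window_emb(pos // segment_len) + offset_emb(pos % segment_len)

pos_emb = pos_emb[:seq_len]                    # trim to the real length, as in the hunk
assert pos_emb.shape == (seq_len, dim)         # added onto the (batch, seq_len, dim) token embeddings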
@@ -162,6 +209,7 @@ if __name__ == '__main__':
          num_tokens = 256,
          dim = 256,
          depth = 2,
+         num_persist_mem_tokens = 16,
          segment_len = 128,
      )
