titans-pytorch 0.0.26.tar.gz → 0.0.29.tar.gz

@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: titans-pytorch
- Version: 0.0.26
+ Version: 0.0.29
  Summary: Titans
  Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
  Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -35,10 +35,12 @@ Classifier: Programming Language :: Python :: 3.9
  Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
  Requires-Python: >=3.9
  Requires-Dist: accelerated-scan>=0.2.0
+ Requires-Dist: axial-positional-embedding>=0.3.5
  Requires-Dist: einops>=0.8.0
  Requires-Dist: einx>=0.3.0
  Requires-Dist: hyper-connections>=0.1.8
  Requires-Dist: ninja
+ Requires-Dist: rotary-embedding-torch
  Requires-Dist: tensordict
  Requires-Dist: torch>=2.2
  Provides-Extra: examples
@@ -1,6 +1,6 @@
  [project]
  name = "titans-pytorch"
- version = "0.0.26"
+ version = "0.0.29"
  description = "Titans"
  authors = [
      { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -26,10 +26,12 @@ classifiers=[

  dependencies = [
      "accelerated-scan>=0.2.0",
+     "axial_positional_embedding>=0.3.5",
      "einx>=0.3.0",
      "einops>=0.8.0",
      "hyper-connections>=0.1.8",
      "Ninja",
+     "rotary-embedding-torch",
      "tensordict",
      "torch>=2.2",
  ]
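
Version 0.0.29 pulls in two new positional-embedding dependencies alongside the version bump above: axial-positional-embedding for absolute positions over the segment grid and rotary-embedding-torch for relative positions inside attention; both are used by the mac_transformer changes further down. The snippet below is only a sanity check that an installed environment resolves both packages, mirroring the imports added in this release:

# minimal import check for the two dependencies added in 0.0.29
from axial_positional_embedding import ContinuousAxialPositionalEmbedding
from rotary_embedding_torch import RotaryEmbedding

print(ContinuousAxialPositionalEmbedding.__name__, RotaryEmbedding.__name__)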
@@ -33,3 +33,19 @@ def test_titans_attn_memory():
      retrieved = mem(seq)

      assert seq.shape == retrieved.shape
+
+ def test_mac():
+     from titans_pytorch.mac_transformer import MemoryAsContextTransformer
+
+     transformer = MemoryAsContextTransformer(
+         num_tokens = 256,
+         dim = 256,
+         depth = 2,
+         num_persist_mem_tokens = 16,
+         segment_len = 128,
+     )
+
+     x = torch.randint(0, 256, (1, 1023))
+
+     logits = transformer(x)
+     assert logits.shape == (1, 1023, 256)
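
The new test_mac drives MemoryAsContextTransformer end to end: 1023 token ids, deliberately not a multiple of segment_len = 128 so the segmenting logic has to cope with a ragged final window, come back as per-position logits over the 256-token vocabulary. A minimal next-token training step on top of that interface could look like the sketch below; the optimizer and loss wiring are illustrative assumptions, not part of the package.

import torch
import torch.nn.functional as F
from titans_pytorch.mac_transformer import MemoryAsContextTransformer

# same configuration as the test above
model = MemoryAsContextTransformer(
    num_tokens = 256,
    dim = 256,
    depth = 2,
    num_persist_mem_tokens = 16,
    segment_len = 128,
)

optim = torch.optim.Adam(model.parameters(), lr = 3e-4)

seq = torch.randint(0, 256, (2, 1024))   # fake byte-level batch

# next-token prediction: predict seq[:, 1:] from seq[:, :-1]
logits = model(seq[:, :-1])                            # (2, 1023, 256)
loss = F.cross_entropy(logits.transpose(1, 2), seq[:, 1:])

loss.backward()
optim.step()
optim.zero_grad()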
@@ -1,16 +1,23 @@
  from __future__ import annotations
- import math
+ from math import ceil
  from functools import partial

  import torch
- from torch import nn
+ from torch import nn, cat
  import torch.nn.functional as F
  from torch.nn import Module, ModuleList, Linear

+ from einops import repeat
  from einops.layers.torch import Rearrange

+
  from hyper_connections import get_init_and_expand_reduce_stream_functions

+ # absolute and relative positions
+
+ from axial_positional_embedding import ContinuousAxialPositionalEmbedding
+ from rotary_embedding_torch import RotaryEmbedding
+
  # constants

  LinearNoBias = partial(Linear, bias = False)
@@ -24,7 +31,7 @@ def default(v, d):
      return v if exists(v) else d

  def round_up_multiple(seq, mult):
-     return math.ceil(seq / mult) * mult
+     return ceil(seq / mult) * mult

  # feedforward and attention

@@ -48,6 +55,7 @@ class SegmentedAttention(Module):
          self,
          dim,
          segment_len,
+         num_persist_mem_tokens,
          dim_head = 64,
          heads = 8,
      ):
@@ -56,6 +64,8 @@ class SegmentedAttention(Module):

          dim_inner = dim_head * heads

+         self.rotary_emb = RotaryEmbedding(dim_head)
+
          self.to_qkv = LinearNoBias(dim, dim_inner * 3)
          self.to_out = LinearNoBias(dim_inner, dim)

@@ -67,6 +77,7 @@ class SegmentedAttention(Module):
          self.segment_seq = Rearrange('b (n w) d -> (b n) w d', n = segment_len)
          self.merge_seq_back = Rearrange('(b n) w d -> b (n w) d', n = segment_len)

+         self.persistent_memory = nn.Parameter(torch.zeros(2, heads, num_persist_mem_tokens, dim_head))

      def forward(self, seq):
          batch, seq_len = seq.shape[:2]
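
The persistent_memory parameter added here packs the learned persistent memory tokens as a single tensor: a leading axis of 2 for keys versus values, then one set of num_persist_mem_tokens vectors of size dim_head per attention head. A small standalone sketch of that layout (the numbers are illustrative; only dim_head = 64 and heads = 8 correspond to the constructor defaults shown above):

import torch
from torch import nn

heads, num_persist_mem_tokens, dim_head = 8, 16, 64

# one learned key set and one learned value set per head, zero-initialized as above
persistent_memory = nn.Parameter(torch.zeros(2, heads, num_persist_mem_tokens, dim_head))

pmk, pmv = persistent_memory          # unbind the leading key / value axis
assert pmk.shape == (heads, num_persist_mem_tokens, dim_head)
assert pmv.shape == (heads, num_persist_mem_tokens, dim_head)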
@@ -92,6 +103,21 @@ class SegmentedAttention(Module):
          q, k, v = self.to_qkv(seq).chunk(3, dim = -1)
          q, k, v = map(self.split_heads, (q, k, v))

+         # take care of persistent memory key / values
+
+         pmk, pmv = tuple(repeat(t, 'h n d -> b h n d', b = seq.shape[0]) for t in self.persistent_memory)
+
+         # relative positions
+
+         q, k = self.rotary_emb.rotate_queries_with_cached_keys(q, k)
+
+         # persistent memory
+
+         k = cat((pmk, k), dim = -2)
+         v = cat((pmv, v), dim = -2)
+
+         # sdpa
+
          out = F.scaled_dot_product_attention(q, k, v, is_causal = True)

          out = self.merge_heads(out)
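
Taken together, the forward changes do three things: broadcast the learned persistent key / value tokens across the batch with einops.repeat, rotate queries and keys with the rotary embedding before the persistent tokens are prepended (so the persistent keys receive no rotation), and then concatenate the persistent keys and values along the sequence axis ahead of F.scaled_dot_product_attention. The standalone sketch below replays only the concatenation and attention step to make the shapes explicit; it omits the rotary embedding and is not the library's own code path.

import torch
import torch.nn.functional as F
from einops import repeat

batch, heads, seq, dim_head = 2, 8, 128, 64
num_persist = 16

q, k, v = (torch.randn(batch, heads, seq, dim_head) for _ in range(3))
persistent_memory = torch.zeros(2, heads, num_persist, dim_head)

# broadcast the learned (heads, num_persist, dim_head) tokens over the batch
pmk, pmv = tuple(repeat(t, 'h n d -> b h n d', b = batch) for t in persistent_memory)

# prepend along the key / value sequence axis; the queries are left untouched
k = torch.cat((pmk, k), dim = -2)     # (batch, heads, num_persist + seq, dim_head)
v = torch.cat((pmv, v), dim = -2)

out = F.scaled_dot_product_attention(q, k, v, is_causal = True)
assert out.shape == (batch, heads, seq, dim_head)   # output length follows the queries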
@@ -113,6 +139,7 @@ class MemoryAsContextTransformer(Module):
          dim,
          depth,
          segment_len,
+         num_persist_mem_tokens,
          dim_head = 64,
          heads = 8,
          ff_mult = 4,
@@ -120,6 +147,9 @@ class MemoryAsContextTransformer(Module):
      ):
          super().__init__()

+         self.segment_len = segment_len
+         self.axial_pos_emb = ContinuousAxialPositionalEmbedding(dim = dim, num_axial_dims = 2)
+
          self.token_emb = nn.Embedding(num_tokens, dim)

          init_hyper_conn, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(num_residual_streams, disable = num_residual_streams == 1)
@@ -127,7 +157,14 @@ class MemoryAsContextTransformer(Module):
          self.layers = ModuleList([])

          for _ in range(depth):
-             attn = SegmentedAttention(dim = dim, dim_head = dim_head, heads = heads, segment_len = segment_len)
+             attn = SegmentedAttention(
+                 dim = dim,
+                 dim_head = dim_head,
+                 heads = heads,
+                 segment_len = segment_len,
+                 num_persist_mem_tokens = num_persist_mem_tokens
+             )
+
              ff = FeedForward(dim = dim, mult = ff_mult)

              self.layers.append(ModuleList([
@@ -140,9 +177,19 @@ class MemoryAsContextTransformer(Module):
          self.to_logits = LinearNoBias(dim, num_tokens)

      def forward(self, x):
+         seq_len, segment_len = x.shape[-1], self.segment_len
+         windows = ceil(seq_len / segment_len)

          x = self.token_emb(x)

+         # apply axial positional embedding
+         # so intra and inter segment can be more easily discerned by the network
+
+         pos_emb = self.axial_pos_emb((windows, segment_len), flatten = True)
+         x = x + pos_emb[:seq_len]
+
+         # expand and reduce streams for hyper connections
+
          x = self.expand_streams(x)

          for attn, ff in self.layers:
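
The forward pass now lays positions out on a two-axis grid of windows = ceil(seq_len / segment_len) by segment_len, flattens it, and truncates to seq_len, so every token's embedding reflects both which segment it sits in and its offset within that segment. The toy sketch below imitates that bookkeeping with two plain learned embedding tables instead of ContinuousAxialPositionalEmbedding; it only illustrates the shape handling, not the library's implementation.

from math import ceil

import torch
from torch import nn

# toy stand-in for a 2-axis positional embedding: one table for the window index
# (which segment) and one for the offset inside the window, summed over the grid
class ToyAxialPosEmb(nn.Module):
    def __init__(self, dim, max_windows, segment_len):
        super().__init__()
        self.inter = nn.Embedding(max_windows, dim)
        self.intra = nn.Embedding(segment_len, dim)

    def forward(self, windows, segment_len):
        w = self.inter(torch.arange(windows))             # (windows, dim)
        p = self.intra(torch.arange(segment_len))         # (segment_len, dim)
        grid = w[:, None, :] + p[None, :, :]              # (windows, segment_len, dim)
        return grid.reshape(windows * segment_len, -1)    # flattened, as with flatten = True

dim, segment_len, seq_len = 256, 128, 1023
windows = ceil(seq_len / segment_len)                     # 8 windows cover 1024 grid positions

pos_emb = ToyAxialPosEmb(dim, max_windows = 64, segment_len = segment_len)(windows, segment_len)

tokens = torch.randn(1, seq_len, dim)
tokens = tokens + pos_emb[:seq_len]                       # drop the padded tail of the last window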
@@ -162,6 +209,7 @@ if __name__ == '__main__':
          num_tokens = 256,
          dim = 256,
          depth = 2,
+         num_persist_mem_tokens = 16,
          segment_len = 128,
      )