titans-pytorch 0.0.29.tar.gz → 0.0.31.tar.gz

This diff compares the contents of two publicly released versions of the package as they appear on the public registry; it is provided for informational purposes only.
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: titans-pytorch
- Version: 0.0.29
+ Version: 0.0.31
  Summary: Titans
  Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
  Project-URL: Repository, https://github.com/lucidrains/titans-pytorch

@@ -1,6 +1,6 @@
  [project]
  name = "titans-pytorch"
- version = "0.0.29"
+ version = "0.0.31"
  description = "Titans"
  authors = [
      { name = "Phil Wang", email = "lucidrains@gmail.com" }

@@ -34,14 +34,20 @@ def test_titans_attn_memory():

      assert seq.shape == retrieved.shape

- def test_mac():
+ @pytest.mark.parametrize('num_persist_mem_tokens', (0, 16))
+ @pytest.mark.parametrize('num_longterm_mem_tokens', (0, 16))
+ def test_mac(
+     num_persist_mem_tokens,
+     num_longterm_mem_tokens
+ ):
      from titans_pytorch.mac_transformer import MemoryAsContextTransformer

      transformer = MemoryAsContextTransformer(
          num_tokens = 256,
          dim = 256,
          depth = 2,
-         num_persist_mem_tokens = 16,
+         num_persist_mem_tokens = num_persist_mem_tokens,
+         num_longterm_mem_tokens = num_longterm_mem_tokens,
          segment_len = 128,
      )

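The parametrized test exercises the new long-term memory option end to end. For reference, a minimal standalone sketch of the same call pattern, using one of the parameter combinations the test covers (values are illustrative):

    import torch
    from titans_pytorch.mac_transformer import MemoryAsContextTransformer

    transformer = MemoryAsContextTransformer(
        num_tokens = 256,
        dim = 256,
        depth = 2,
        segment_len = 128,
        num_persist_mem_tokens = 16,    # 0 disables persistent (attention-level) memory
        num_longterm_mem_tokens = 16,   # new in 0.0.31; 0 disables the interspersed memory tokens
    )

    x = torch.randint(0, 256, (1, 1023))
    logits = transformer(x)             # (1, 1023, 256) - one row of logits per input token
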
@@ -7,10 +7,9 @@ from torch import nn, cat
  import torch.nn.functional as F
  from torch.nn import Module, ModuleList, Linear

- from einops import repeat
+ from einops import repeat, rearrange
  from einops.layers.torch import Rearrange

-
  from hyper_connections import get_init_and_expand_reduce_stream_functions

  # absolute and relative positions
@@ -30,9 +29,33 @@ def exists(v):
  def default(v, d):
      return v if exists(v) else d

+ def identity(t):
+     return t
+
  def round_up_multiple(seq, mult):
      return ceil(seq / mult) * mult

+ def pad_and_segment_with_inverse(seq, segment_len):
+     batch, seq_len = seq.shape[:2]
+
+     need_segment = seq_len >= segment_len
+
+     if not need_segment:
+         return seq, identity
+
+     next_seq_len_mult = round_up_multiple(seq_len, segment_len)
+
+     padding = next_seq_len_mult - seq_len
+     seq = F.pad(seq, (0, 0, 0, padding))
+
+     seq = rearrange(seq, 'b (w n) d -> (b w) n d', n = segment_len)
+
+     def inverse(out):
+         out = rearrange(out, '(b w) n d -> b (w n) d', b = batch)
+         return out[:, :-padding]
+
+     return seq, inverse
+
  # feedforward and attention

  class GEGLU(Module):
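The new pad_and_segment_with_inverse helper replaces the fixed segment_seq / merge_seq_back Rearrange pair removed further down: it right-pads the sequence to a multiple of segment_len, folds the windows into the batch dimension, and hands back a closure that undoes both steps. An illustrative call, with example shapes that are not taken from the diff:

    import torch
    from titans_pytorch.mac_transformer import pad_and_segment_with_inverse

    seq = torch.randn(2, 1000, 512)                           # (batch, seq_len, dim)
    windows, inverse = pad_and_segment_with_inverse(seq, 128)

    print(windows.shape)            # torch.Size([16, 128, 512]) - 1000 padded to 1024, 8 windows per batch element
    print(inverse(windows).shape)   # torch.Size([2, 1000, 512]) - padding trimmed again

The example deliberately uses a length that is not an exact multiple of the segment length, matching the assumption baked into the inverse's out[:, :-padding] slice (a padding of zero would make that slice empty).
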
@@ -55,7 +78,8 @@ class SegmentedAttention(Module):
          self,
          dim,
          segment_len,
-         num_persist_mem_tokens,
+         num_persist_mem_tokens = 0,
+         num_longterm_mem_tokens = 0,
          dim_head = 64,
          heads = 8,
      ):
@@ -70,31 +94,25 @@ class SegmentedAttention(Module):
          self.to_out = LinearNoBias(dim_inner, dim)

          self.segment_len = segment_len
+         self.num_longterm_mem_tokens = num_longterm_mem_tokens
+
+         total_segment_len = segment_len + num_longterm_mem_tokens

          self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
          self.merge_heads = Rearrange('b h n d -> b n (h d)')

-         self.segment_seq = Rearrange('b (n w) d -> (b n) w d', n = segment_len)
-         self.merge_seq_back = Rearrange('(b n) w d -> b (n w) d', n = segment_len)
-
          self.persistent_memory = nn.Parameter(torch.zeros(2, heads, num_persist_mem_tokens, dim_head))

      def forward(self, seq):
+         segment_len, num_longterm_mem_tokens = self.segment_len, self.num_longterm_mem_tokens
+         total_segment_len = segment_len + num_longterm_mem_tokens
+
          batch, seq_len = seq.shape[:2]

          # auto pad to multiple
          # todo - get rid of logic with flex attention

-         need_segment = seq_len >= self.segment_len
-
-         if need_segment:
-             next_seq_len = round_up_multiple(seq_len, self.segment_len)
-             padding = next_seq_len - seq_len
-
-             if padding > 0:
-                 seq = F.pad(seq, (0, 0, 0, padding))
-
-             seq = self.segment_seq(seq)
+         seq, inverse_segment = pad_and_segment_with_inverse(seq, total_segment_len)

          # attention

@@ -124,10 +142,9 @@ class SegmentedAttention(Module):

          out = self.to_out(out)

-         if need_segment:
-             out = self.merge_seq_back(out)
+         out = inverse_segment(out)

-         return out[:, :seq_len]
+         return out

  # MAC transformer

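With the helper in place, SegmentedAttention now windows the sequence on segment_len + num_longterm_mem_tokens rather than segment_len alone, since the long-term memory tokens have already been interleaved by the time the sequence reaches attention. A rough shape walk-through under the hyperparameters used elsewhere in this diff (segment_len = 128, num_longterm_mem_tokens = 16; the sequence length is illustrative):

    from math import ceil

    segment_len = 128
    num_longterm_mem_tokens = 16
    total_segment_len = segment_len + num_longterm_mem_tokens    # 144 positions attended per window

    seq_len = 1151                                # example length after memories were interspersed upstream
    windows = ceil(seq_len / total_segment_len)   # 8 windows
    padded_len = windows * total_segment_len      # 1152, i.e. one position of padding
    # attention runs on (batch * 8, 144, dim); inverse_segment then restores (batch, 1151, dim)
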
@@ -139,7 +156,8 @@ class MemoryAsContextTransformer(Module):
          dim,
          depth,
          segment_len,
-         num_persist_mem_tokens,
+         num_longterm_mem_tokens = 0,
+         num_persist_mem_tokens = 0,
          dim_head = 64,
          heads = 8,
          ff_mult = 4,
@@ -147,10 +165,18 @@ class MemoryAsContextTransformer(Module):
      ):
          super().__init__()

-         self.segment_len = segment_len
+         self.token_emb = nn.Embedding(num_tokens, dim)
+
          self.axial_pos_emb = ContinuousAxialPositionalEmbedding(dim = dim, num_axial_dims = 2)

-         self.token_emb = nn.Embedding(num_tokens, dim)
+         # long term mem tokens
+
+         self.segment_len = segment_len
+         self.num_longterm_mem_tokens = num_longterm_mem_tokens
+
+         self.longterm_mems = nn.Parameter(torch.randn(num_longterm_mem_tokens, dim) * 0.02)
+
+         # hyper conection

          init_hyper_conn, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(num_residual_streams, disable = num_residual_streams == 1)

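The constructor now holds two distinct kinds of memory. A minimal sketch contrasting them, with parameter shapes copied from the code in this diff and illustrative hyperparameters:

    import torch
    from torch import nn

    heads, dim_head, dim = 8, 64, 256
    num_persist_mem_tokens, num_longterm_mem_tokens = 16, 16

    # persistent memory (lives inside SegmentedAttention): extra key/value rows shared by every window
    persistent_memory = nn.Parameter(torch.zeros(2, heads, num_persist_mem_tokens, dim_head))

    # long-term memory (lives on the transformer): learnable token embeddings that get spliced
    # into the token sequence itself, one block per segment
    longterm_mems = nn.Parameter(torch.randn(num_longterm_mem_tokens, dim) * 0.02)
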
@@ -162,6 +188,7 @@ class MemoryAsContextTransformer(Module):
                  dim_head = dim_head,
                  heads = heads,
                  segment_len = segment_len,
+                 num_longterm_mem_tokens = num_longterm_mem_tokens,
                  num_persist_mem_tokens = num_persist_mem_tokens
              )

@@ -177,16 +204,32 @@ class MemoryAsContextTransformer(Module):
          self.to_logits = LinearNoBias(dim, num_tokens)

      def forward(self, x):
-         seq_len, segment_len = x.shape[-1], self.segment_len
+
+         # math
+
+         batch, seq_len, segment_len, num_longterm_mem_tokens = *x.shape, self.segment_len, self.num_longterm_mem_tokens
+
          windows = ceil(seq_len / segment_len)
+         total_segment_len = segment_len + num_longterm_mem_tokens
+
+         # token embedding

          x = self.token_emb(x)

+         # intersperse longterm memory
+
+         x, inverse_segment = pad_and_segment_with_inverse(x, segment_len)
+
+         mems = repeat(self.longterm_mems, 'n d -> b n d', b = x.shape[0])
+         x = torch.cat((mems, x), dim = -2)
+
+         x = inverse_segment(x)
+
          # apply axial positional embedding
          # so intra and inter segment can be more easily discerned by the network

-         pos_emb = self.axial_pos_emb((windows, segment_len), flatten = True)
-         x = x + pos_emb[:seq_len]
+         pos_emb = self.axial_pos_emb((windows, total_segment_len), flatten = True)
+         x = x + pos_emb[:x.shape[-2]]

          # expand and reduce streams for hyper connections

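The "intersperse longterm memory" step above prepends the same learned memory block to every segment before positions are added. A step-by-step sketch of the tensor shapes, assuming batch 1, a 1023-token input, segment_len 128 and 16 memory tokens (illustrative values only):

    import torch
    import torch.nn.functional as F
    from einops import repeat, rearrange

    batch, seq_len, dim = 1, 1023, 256
    segment_len, num_longterm_mem_tokens = 128, 16

    x = torch.randn(batch, seq_len, dim)                      # token embeddings
    longterm_mems = torch.randn(num_longterm_mem_tokens, dim)

    # pad_and_segment_with_inverse(x, segment_len): pad 1023 -> 1024, fold the 8 windows into the batch
    x = F.pad(x, (0, 0, 0, 1))
    x = rearrange(x, 'b (w n) d -> (b w) n d', n = segment_len)    # (8, 128, 256)

    # prepend the shared memory block to every window
    mems = repeat(longterm_mems, 'n d -> b n d', b = x.shape[0])   # (8, 16, 256)
    x = torch.cat((mems, x), dim = -2)                             # (8, 144, 256)

    # inverse_segment: unfold the windows and trim the single padded position
    x = rearrange(x, '(b w) n d -> b (w n) d', b = batch)[:, :-1]  # (1, 1151, 256)
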
@@ -198,21 +241,16 @@

          x = self.reduce_streams(x)

-         x = self.norm(x)
+         # excise out the memories

-         return self.to_logits(x)
+         x, inverse_segment = pad_and_segment_with_inverse(x, total_segment_len)

- # main
+         x = x[:, self.num_longterm_mem_tokens:]

- if __name__ == '__main__':
-     transformer = MemoryAsContextTransformer(
-         num_tokens = 256,
-         dim = 256,
-         depth = 2,
-         num_persist_mem_tokens = 16,
-         segment_len = 128,
-     )
+         x = inverse_segment(x)
+
+         # to logits

-     x = torch.randint(0, 256, (1, 1023))
+         x = self.norm(x)

-     logits = transformer(x)
+         return self.to_logits(x)
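At the other end of the forward pass the memory positions are excised again before the final norm and logits projection. Continuing the sketch above with the same illustrative shapes:

    import torch
    import torch.nn.functional as F
    from einops import rearrange

    batch, dim = 1, 256
    segment_len, num_longterm_mem_tokens = 128, 16
    total_segment_len = segment_len + num_longterm_mem_tokens          # 144

    x = torch.randn(batch, 1151, dim)                                  # output of the residual stream reduction

    # pad_and_segment_with_inverse(x, total_segment_len): pad 1151 -> 1152, fold into full windows
    x = F.pad(x, (0, 0, 0, 1))
    x = rearrange(x, 'b (w n) d -> (b w) n d', n = total_segment_len)  # (8, 144, 256)

    # drop the memory positions at the front of every window
    x = x[:, num_longterm_mem_tokens:]                                 # (8, 128, 256)

    # inverse_segment: unfold and trim, leaving one row per original input token
    x = rearrange(x, '(b w) n d -> b (w n) d', b = batch)[:, :-1]      # (1, 1023, 256)
    # this then goes through self.norm and self.to_logits to produce the logits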