titans-pytorch 0.0.29.tar.gz → 0.0.30.tar.gz

PKG-INFO:

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.0.29
+Version: 0.0.30
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
pyproject.toml:

@@ -1,6 +1,6 @@
 [project]
 name = "titans-pytorch"
-version = "0.0.29"
+version = "0.0.30"
 description = "Titans"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
test suite:

@@ -34,14 +34,20 @@ def test_titans_attn_memory():
 
     assert seq.shape == retrieved.shape
 
-def test_mac():
+@pytest.mark.parametrize('num_persist_mem_tokens', (0, 16))
+@pytest.mark.parametrize('num_longterm_mem_tokens', (0, 16))
+def test_mac(
+    num_persist_mem_tokens,
+    num_longterm_mem_tokens
+):
     from titans_pytorch.mac_transformer import MemoryAsContextTransformer
 
     transformer = MemoryAsContextTransformer(
         num_tokens = 256,
         dim = 256,
         depth = 2,
-        num_persist_mem_tokens = 16,
+        num_persist_mem_tokens = num_persist_mem_tokens,
+        num_longterm_mem_tokens = num_longterm_mem_tokens,
         segment_len = 128,
     )
 
titans_pytorch/mac_transformer.py:

@@ -7,10 +7,9 @@ from torch import nn, cat
 import torch.nn.functional as F
 from torch.nn import Module, ModuleList, Linear
 
-from einops import repeat
+from einops import repeat, rearrange
 from einops.layers.torch import Rearrange
 
-
 from hyper_connections import get_init_and_expand_reduce_stream_functions
 
 # absolute and relative positions
@@ -55,7 +54,8 @@ class SegmentedAttention(Module):
         self,
         dim,
         segment_len,
-        num_persist_mem_tokens,
+        num_persist_mem_tokens = 0,
+        num_longterm_mem_tokens = 0,
         dim_head = 64,
         heads = 8,
     ):
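Both memory counts now default to zero, so existing call sites that only pass dim and segment_len keep working, and the two kinds of memory can be opted into independently. A minimal construction sketch (the values are illustrative, not taken from the library):

    import torch
    from titans_pytorch.mac_transformer import SegmentedAttention

    # defaults: no persistent and no long-term memory tokens
    attn = SegmentedAttention(dim = 256, segment_len = 128)

    # opting in to both kinds of memory
    attn_with_mem = SegmentedAttention(
        dim = 256,
        segment_len = 128,
        num_persist_mem_tokens = 16,
        num_longterm_mem_tokens = 16
    )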
@@ -70,25 +70,31 @@ class SegmentedAttention(Module):
         self.to_out = LinearNoBias(dim_inner, dim)
 
         self.segment_len = segment_len
+        self.num_longterm_mem_tokens = num_longterm_mem_tokens
+
+        total_segment_len = segment_len + num_longterm_mem_tokens
 
         self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
         self.merge_heads = Rearrange('b h n d -> b n (h d)')
 
-        self.segment_seq = Rearrange('b (n w) d -> (b n) w d', n = segment_len)
-        self.merge_seq_back = Rearrange('(b n) w d -> b (n w) d', n = segment_len)
+        self.segment_seq = Rearrange('b (n w) d -> (b n) w d', n = total_segment_len)
+        self.merge_seq_back = Rearrange('(b n) w d -> b (n w) d', n = total_segment_len)
 
         self.persistent_memory = nn.Parameter(torch.zeros(2, heads, num_persist_mem_tokens, dim_head))
 
     def forward(self, seq):
+        segment_len, num_longterm_mem_tokens = self.segment_len, self.num_longterm_mem_tokens
+        total_segment_len = segment_len + num_longterm_mem_tokens
+
         batch, seq_len = seq.shape[:2]
 
         # auto pad to multiple
         # todo - get rid of logic with flex attention
 
-        need_segment = seq_len >= self.segment_len
+        need_segment = seq_len >= total_segment_len
 
         if need_segment:
-            next_seq_len = round_up_multiple(seq_len, self.segment_len)
+            next_seq_len = round_up_multiple(seq_len, total_segment_len)
             padding = next_seq_len - seq_len
 
             if padding > 0:
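With long-term memory tokens interspersed, each attention window now spans total_segment_len positions, so the padding target becomes a multiple of that combined length. A minimal sketch of just this padding step; round_up_multiple below is a stand-in with assumed semantics, not imported from the library:

    import torch
    import torch.nn.functional as F

    def round_up_multiple(n, mult):
        # assumed behaviour of the library helper of the same name
        return ((n + mult - 1) // mult) * mult

    segment_len, num_longterm_mem_tokens = 128, 16
    total_segment_len = segment_len + num_longterm_mem_tokens     # 144

    seq = torch.randn(1, 1000, 512)
    seq_len = seq.shape[1]

    if seq_len >= total_segment_len:
        next_seq_len = round_up_multiple(seq_len, total_segment_len)   # 1008
        padding = next_seq_len - seq_len                               # 8

        if padding > 0:
            # pad zeros at the end of the sequence (second-to-last) dimension
            seq = F.pad(seq, (0, 0, 0, padding))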
@@ -139,7 +145,8 @@ class MemoryAsContextTransformer(Module):
         dim,
         depth,
         segment_len,
-        num_persist_mem_tokens,
+        num_longterm_mem_tokens = 0,
+        num_persist_mem_tokens = 0,
         dim_head = 64,
         heads = 8,
         ff_mult = 4,
@@ -147,10 +154,18 @@ class MemoryAsContextTransformer(Module):
     ):
         super().__init__()
 
-        self.segment_len = segment_len
+        self.token_emb = nn.Embedding(num_tokens, dim)
+
         self.axial_pos_emb = ContinuousAxialPositionalEmbedding(dim = dim, num_axial_dims = 2)
 
-        self.token_emb = nn.Embedding(num_tokens, dim)
+        # long term mem tokens
+
+        self.segment_len = segment_len
+        self.num_longterm_mem_tokens = num_longterm_mem_tokens
+
+        self.longterm_mems = nn.Parameter(torch.randn(num_longterm_mem_tokens, dim) * 0.02)
+
+        # hyper connection
 
         init_hyper_conn, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(num_residual_streams, disable = num_residual_streams == 1)
 
@@ -162,6 +177,7 @@ class MemoryAsContextTransformer(Module):
                 dim_head = dim_head,
                 heads = heads,
                 segment_len = segment_len,
+                num_longterm_mem_tokens = num_longterm_mem_tokens,
                 num_persist_mem_tokens = num_persist_mem_tokens
             )
 
@@ -177,15 +193,42 @@ class MemoryAsContextTransformer(Module):
         self.to_logits = LinearNoBias(dim, num_tokens)
 
     def forward(self, x):
-        seq_len, segment_len = x.shape[-1], self.segment_len
+
+        # math
+
+        batch, seq_len, segment_len, num_longterm_mem_tokens = *x.shape, self.segment_len, self.num_longterm_mem_tokens
+
         windows = ceil(seq_len / segment_len)
+        total_segment_len = segment_len + num_longterm_mem_tokens
+
+        # token embedding
 
         x = self.token_emb(x)
 
+        # intersperse longterm memory
+
+        need_segment = seq_len >= segment_len
+
+        if need_segment:
+            next_seq_len = round_up_multiple(seq_len, segment_len)
+            padding = next_seq_len - seq_len
+
+            if padding > 0:
+                x = F.pad(x, (0, 0, 0, padding))
+
+            x = rearrange(x, 'b (w n) d -> (b w) n d', n = segment_len)
+
+        mems = repeat(self.longterm_mems, 'n d -> b n d', b = x.shape[0])
+        x = torch.cat((mems, x), dim = -2)
+
+        if need_segment:
+            x = rearrange(x, '(b w) n d -> b (w n) d', b = batch)
+            x = x[:, :seq_len]
+
         # apply axial positional embedding
         # so intra and inter segment can be more easily discerned by the network
 
-        pos_emb = self.axial_pos_emb((windows, segment_len), flatten = True)
+        pos_emb = self.axial_pos_emb((windows, total_segment_len), flatten = True)
         x = x + pos_emb[:seq_len]
 
         # expand and reduce streams for hyper connections
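The new intersperse block folds the (padded) token sequence into per-window batches, prepends the same learned long-term memory tokens to every window, and merges the windows back into one sequence. A minimal sketch of that pattern in isolation, with mems_param standing in for self.longterm_mems and the shapes chosen arbitrarily:

    import torch
    from einops import rearrange, repeat

    batch, seq_len, dim = 2, 256, 64
    segment_len, num_longterm_mem_tokens = 128, 16

    x = torch.randn(batch, seq_len, dim)            # already a multiple of segment_len here
    mems_param = torch.randn(num_longterm_mem_tokens, dim) * 0.02

    # fold into per-window batches: (batch * windows, segment_len, dim)
    x = rearrange(x, 'b (w n) d -> (b w) n d', n = segment_len)

    # prepend the same memories to every window: (batch * windows, segment_len + 16, dim)
    mems = repeat(mems_param, 'n d -> b n d', b = x.shape[0])
    x = torch.cat((mems, x), dim = -2)

    # merge back to one sequence per batch element: (batch, windows * (segment_len + 16), dim)
    x = rearrange(x, '(b w) n d -> b (w n) d', b = batch)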
@@ -198,21 +241,8 @@ class MemoryAsContextTransformer(Module):
 
         x = self.reduce_streams(x)
 
+        # to logits
+
         x = self.norm(x)
 
         return self.to_logits(x)
-
-# main
-
-if __name__ == '__main__':
-    transformer = MemoryAsContextTransformer(
-        num_tokens = 256,
-        dim = 256,
-        depth = 2,
-        num_persist_mem_tokens = 16,
-        segment_len = 128,
-    )
-
-    x = torch.randint(0, 256, (1, 1023))
-
-    logits = transformer(x)
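The quick demo that lived under if __name__ == '__main__': is removed; the parameterized test above exercises the same path. For reference, a usage sketch of the new argument, mirroring the removed demo and the updated test (the values are illustrative):

    import torch
    from titans_pytorch.mac_transformer import MemoryAsContextTransformer

    transformer = MemoryAsContextTransformer(
        num_tokens = 256,
        dim = 256,
        depth = 2,
        segment_len = 128,
        num_persist_mem_tokens = 16,
        num_longterm_mem_tokens = 16    # new in 0.0.30
    )

    x = torch.randint(0, 256, (1, 1023))
    logits = transformer(x)             # expected to have shape (1, 1023, 256)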