titans-pytorch 0.0.29__tar.gz → 0.0.30__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.0.29
+Version: 0.0.30
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -1,6 +1,6 @@
 [project]
 name = "titans-pytorch"
-version = "0.0.29"
+version = "0.0.30"
 description = "Titans"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -34,14 +34,20 @@ def test_titans_attn_memory():
 
     assert seq.shape == retrieved.shape
 
-def test_mac():
+@pytest.mark.parametrize('num_persist_mem_tokens', (0, 16))
+@pytest.mark.parametrize('num_longterm_mem_tokens', (0, 16))
+def test_mac(
+    num_persist_mem_tokens,
+    num_longterm_mem_tokens
+):
     from titans_pytorch.mac_transformer import MemoryAsContextTransformer
 
     transformer = MemoryAsContextTransformer(
         num_tokens = 256,
         dim = 256,
         depth = 2,
-        num_persist_mem_tokens = 16,
+        num_persist_mem_tokens = num_persist_mem_tokens,
+        num_longterm_mem_tokens = num_longterm_mem_tokens,
         segment_len = 128,
     )
 
@@ -7,10 +7,9 @@ from torch import nn, cat
 import torch.nn.functional as F
 from torch.nn import Module, ModuleList, Linear
 
-from einops import repeat
+from einops import repeat, rearrange
 from einops.layers.torch import Rearrange
 
-
 from hyper_connections import get_init_and_expand_reduce_stream_functions
 
 # absolute and relative positions
@@ -55,7 +54,8 @@ class SegmentedAttention(Module):
         self,
         dim,
         segment_len,
-        num_persist_mem_tokens,
+        num_persist_mem_tokens = 0,
+        num_longterm_mem_tokens = 0,
         dim_head = 64,
         heads = 8,
     ):
@@ -70,25 +70,31 @@ class SegmentedAttention(Module):
         self.to_out = LinearNoBias(dim_inner, dim)
 
         self.segment_len = segment_len
+        self.num_longterm_mem_tokens = num_longterm_mem_tokens
+
+        total_segment_len = segment_len + num_longterm_mem_tokens
 
         self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
         self.merge_heads = Rearrange('b h n d -> b n (h d)')
 
-        self.segment_seq = Rearrange('b (n w) d -> (b n) w d', n = segment_len)
-        self.merge_seq_back = Rearrange('(b n) w d -> b (n w) d', n = segment_len)
+        self.segment_seq = Rearrange('b (n w) d -> (b n) w d', n = total_segment_len)
+        self.merge_seq_back = Rearrange('(b n) w d -> b (n w) d', n = total_segment_len)
 
         self.persistent_memory = nn.Parameter(torch.zeros(2, heads, num_persist_mem_tokens, dim_head))
 
     def forward(self, seq):
+        segment_len, num_longterm_mem_tokens = self.segment_len, self.num_longterm_mem_tokens
+        total_segment_len = segment_len + num_longterm_mem_tokens
+
         batch, seq_len = seq.shape[:2]
 
         # auto pad to multiple
         # todo - get rid of logic with flex attention
 
-        need_segment = seq_len >= self.segment_len
+        need_segment = seq_len >= total_segment_len
 
         if need_segment:
-            next_seq_len = round_up_multiple(seq_len, self.segment_len)
+            next_seq_len = round_up_multiple(seq_len, total_segment_len)
             padding = next_seq_len - seq_len
 
             if padding > 0:
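For intuition on the padding change above: attention now pads and segments over windows of segment_len + num_longterm_mem_tokens rather than segment_len alone. A minimal sketch of that arithmetic, with an assumed ceil-based round_up_multiple (the package defines its own helper):

from math import ceil

def round_up_multiple(n, mult):
    # assumed behavior of the helper used above: round n up to a multiple of mult
    return ceil(n / mult) * mult

segment_len = 128
num_longterm_mem_tokens = 16
total_segment_len = segment_len + num_longterm_mem_tokens        # 144

seq_len = 1023
next_seq_len = round_up_multiple(seq_len, total_segment_len)     # 1152, i.e. 8 windows of 144
padding = next_seq_len - seq_len                                  # 129 padded positions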
@@ -139,7 +145,8 @@ class MemoryAsContextTransformer(Module):
         dim,
         depth,
         segment_len,
-        num_persist_mem_tokens,
+        num_longterm_mem_tokens = 0,
+        num_persist_mem_tokens = 0,
         dim_head = 64,
         heads = 8,
         ff_mult = 4,
@@ -147,10 +154,18 @@ class MemoryAsContextTransformer(Module):
     ):
         super().__init__()
 
-        self.segment_len = segment_len
+        self.token_emb = nn.Embedding(num_tokens, dim)
+
         self.axial_pos_emb = ContinuousAxialPositionalEmbedding(dim = dim, num_axial_dims = 2)
 
-        self.token_emb = nn.Embedding(num_tokens, dim)
+        # long term mem tokens
+
+        self.segment_len = segment_len
+        self.num_longterm_mem_tokens = num_longterm_mem_tokens
+
+        self.longterm_mems = nn.Parameter(torch.randn(num_longterm_mem_tokens, dim) * 0.02)
+
+        # hyper conection
 
         init_hyper_conn, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(num_residual_streams, disable = num_residual_streams == 1)
 
@@ -162,6 +177,7 @@ class MemoryAsContextTransformer(Module):
                 dim_head = dim_head,
                 heads = heads,
                 segment_len = segment_len,
+                num_longterm_mem_tokens = num_longterm_mem_tokens,
                 num_persist_mem_tokens = num_persist_mem_tokens
             )
 
@@ -177,15 +193,42 @@ class MemoryAsContextTransformer(Module):
         self.to_logits = LinearNoBias(dim, num_tokens)
 
     def forward(self, x):
-        seq_len, segment_len = x.shape[-1], self.segment_len
+
+        # math
+
+        batch, seq_len, segment_len, num_longterm_mem_tokens= *x.shape, self.segment_len, self.num_longterm_mem_tokens
+
         windows = ceil(seq_len / segment_len)
+        total_segment_len = segment_len + num_longterm_mem_tokens
+
+        # token embedding
 
         x = self.token_emb(x)
 
+        # intersperse longterm memory
+
+        need_segment = seq_len >= segment_len
+
+        if need_segment:
+            next_seq_len = round_up_multiple(seq_len, segment_len)
+            padding = next_seq_len - seq_len
+
+            if padding > 0:
+                x = F.pad(x, (0, 0, 0, padding))
+
+            x = rearrange(x, 'b (w n) d -> (b w) n d', n = segment_len)
+
+        mems = repeat(self.longterm_mems, 'n d -> b n d', b = x.shape[0])
+        x = torch.cat((mems, x), dim = -2)
+
+        if need_segment:
+            x = rearrange(x, '(b w) n d -> b (w n) d', b = batch)
+            x = x[:, :seq_len]
+
         # apply axial positional embedding
         # so intra and inter segment can be more easily discerned by the network
 
-        pos_emb = self.axial_pos_emb((windows, segment_len), flatten = True)
+        pos_emb = self.axial_pos_emb((windows, total_segment_len), flatten = True)
         x = x + pos_emb[:seq_len]
 
         # expand and reduce streams for hyper connections
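To make the "intersperse longterm memory" block above concrete, here is a standalone toy version of the same einops pattern (illustrative shapes only, not package code): each segment_len window gets the shared learned memory tokens prepended, and the windows are then merged back into a single sequence.

import torch
from einops import rearrange, repeat

batch, seq_len, dim = 2, 256, 8
segment_len, num_longterm_mem_tokens = 128, 4

x = torch.randn(batch, seq_len, dim)
longterm_mems = torch.randn(num_longterm_mem_tokens, dim)

windows = rearrange(x, 'b (w n) d -> (b w) n d', n = segment_len)      # (4, 128, 8)
mems = repeat(longterm_mems, 'n d -> b n d', b = windows.shape[0])     # (4, 4, 8)
windows = torch.cat((mems, windows), dim = -2)                         # (4, 132, 8)
out = rearrange(windows, '(b w) n d -> b (w n) d', b = batch)          # (2, 264, 8)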
@@ -198,21 +241,8 @@ class MemoryAsContextTransformer(Module):
 
         x = self.reduce_streams(x)
 
+        # to logits
+
         x = self.norm(x)
 
         return self.to_logits(x)
-
-# main
-
-if __name__ == '__main__':
-    transformer = MemoryAsContextTransformer(
-        num_tokens = 256,
-        dim = 256,
-        depth = 2,
-        num_persist_mem_tokens = 16,
-        segment_len = 128,
-    )
-
-    x = torch.randint(0, 256, (1, 1023))
-
-    logits = transformer(x)
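With the module-level __main__ smoke test removed above, a roughly equivalent usage sketch for 0.0.30 would look like the following (values mirror the removed block and the updated test; both memory token counts now default to 0):

import torch
from titans_pytorch.mac_transformer import MemoryAsContextTransformer

transformer = MemoryAsContextTransformer(
    num_tokens = 256,
    dim = 256,
    depth = 2,
    segment_len = 128,
    num_persist_mem_tokens = 16,
    num_longterm_mem_tokens = 16    # new in 0.0.30, defaults to 0
)

x = torch.randint(0, 256, (1, 1023))    # batch of token ids
logits = transformer(x)                 # per-position logits over the 256-token vocabulary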