titans-pytorch 0.0.29.tar.gz → 0.0.31.tar.gz


@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.0.29
+Version: 0.0.31
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -1,6 +1,6 @@
 [project]
 name = "titans-pytorch"
-version = "0.0.29"
+version = "0.0.31"
 description = "Titans"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -34,14 +34,20 @@ def test_titans_attn_memory():
 
     assert seq.shape == retrieved.shape
 
-def test_mac():
+@pytest.mark.parametrize('num_persist_mem_tokens', (0, 16))
+@pytest.mark.parametrize('num_longterm_mem_tokens', (0, 16))
+def test_mac(
+    num_persist_mem_tokens,
+    num_longterm_mem_tokens
+):
     from titans_pytorch.mac_transformer import MemoryAsContextTransformer
 
     transformer = MemoryAsContextTransformer(
         num_tokens = 256,
         dim = 256,
         depth = 2,
-        num_persist_mem_tokens = 16,
+        num_persist_mem_tokens = num_persist_mem_tokens,
+        num_longterm_mem_tokens = num_longterm_mem_tokens,
         segment_len = 128,
     )
 
@@ -7,10 +7,9 @@ from torch import nn, cat
 import torch.nn.functional as F
 from torch.nn import Module, ModuleList, Linear
 
-from einops import repeat
+from einops import repeat, rearrange
 from einops.layers.torch import Rearrange
 
-
 from hyper_connections import get_init_and_expand_reduce_stream_functions
 
 # absolute and relative positions
@@ -30,9 +29,33 @@ def exists(v):
 def default(v, d):
     return v if exists(v) else d
 
+def identity(t):
+    return t
+
 def round_up_multiple(seq, mult):
     return ceil(seq / mult) * mult
 
+def pad_and_segment_with_inverse(seq, segment_len):
+    batch, seq_len = seq.shape[:2]
+
+    need_segment = seq_len >= segment_len
+
+    if not need_segment:
+        return seq, identity
+
+    next_seq_len_mult = round_up_multiple(seq_len, segment_len)
+
+    padding = next_seq_len_mult - seq_len
+    seq = F.pad(seq, (0, 0, 0, padding))
+
+    seq = rearrange(seq, 'b (w n) d -> (b w) n d', n = segment_len)
+
+    def inverse(out):
+        out = rearrange(out, '(b w) n d -> b (w n) d', b = batch)
+        return out[:, :-padding]
+
+    return seq, inverse
+
 # feedforward and attention
 
 class GEGLU(Module):
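
For reference, a toy shape check of what the new pad_and_segment_with_inverse helper in the hunk above does: the sequence is right-padded up to a multiple of the segment length, each window is folded into the batch dimension, and the returned closure undoes both steps. This is a minimal sketch; the concrete sizes (and the 144 = 128 + 16 split) are illustrative assumptions, not values from the package.

    # toy shape check for the segment / inverse round trip (illustrative sizes)
    import torch
    import torch.nn.functional as F
    from math import ceil
    from einops import rearrange

    batch, seq_len, dim, segment_len = 2, 300, 8, 144   # e.g. 128 segment tokens + 16 long-term mem tokens

    seq = torch.randn(batch, seq_len, dim)

    # pad up to the next multiple of the total segment length
    padding = ceil(seq_len / segment_len) * segment_len - seq_len           # 432 - 300 = 132
    padded = F.pad(seq, (0, 0, 0, padding))                                 # (2, 432, 8)

    # fold each window into the batch dimension, as the helper does before attention
    windows = rearrange(padded, 'b (w n) d -> (b w) n d', n = segment_len)  # (6, 144, 8)

    # the returned inverse closure merges windows back and strips the padding
    restored = rearrange(windows, '(b w) n d -> b (w n) d', b = batch)[:, :-padding]
    assert restored.shape == seq.shape                                      # back to (2, 300, 8)
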
@@ -55,7 +78,8 @@ class SegmentedAttention(Module):
         self,
         dim,
         segment_len,
-        num_persist_mem_tokens,
+        num_persist_mem_tokens = 0,
+        num_longterm_mem_tokens = 0,
         dim_head = 64,
         heads = 8,
     ):
@@ -70,31 +94,25 @@ class SegmentedAttention(Module):
         self.to_out = LinearNoBias(dim_inner, dim)
 
         self.segment_len = segment_len
+        self.num_longterm_mem_tokens = num_longterm_mem_tokens
+
+        total_segment_len = segment_len + num_longterm_mem_tokens
 
         self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
         self.merge_heads = Rearrange('b h n d -> b n (h d)')
 
-        self.segment_seq = Rearrange('b (n w) d -> (b n) w d', n = segment_len)
-        self.merge_seq_back = Rearrange('(b n) w d -> b (n w) d', n = segment_len)
-
         self.persistent_memory = nn.Parameter(torch.zeros(2, heads, num_persist_mem_tokens, dim_head))
 
     def forward(self, seq):
+        segment_len, num_longterm_mem_tokens = self.segment_len, self.num_longterm_mem_tokens
+        total_segment_len = segment_len + num_longterm_mem_tokens
+
         batch, seq_len = seq.shape[:2]
 
         # auto pad to multiple
         # todo - get rid of logic with flex attention
 
-        need_segment = seq_len >= self.segment_len
-
-        if need_segment:
-            next_seq_len = round_up_multiple(seq_len, self.segment_len)
-            padding = next_seq_len - seq_len
-
-            if padding > 0:
-                seq = F.pad(seq, (0, 0, 0, padding))
-
-            seq = self.segment_seq(seq)
+        seq, inverse_segment = pad_and_segment_with_inverse(seq, total_segment_len)
 
         # attention
 
@@ -124,10 +142,9 @@ class SegmentedAttention(Module):
 
         out = self.to_out(out)
 
-        if need_segment:
-            out = self.merge_seq_back(out)
+        out = inverse_segment(out)
 
-        return out[:, :seq_len]
+        return out
 
 # MAC transformer
 
@@ -139,7 +156,8 @@ class MemoryAsContextTransformer(Module):
         dim,
         depth,
         segment_len,
-        num_persist_mem_tokens,
+        num_longterm_mem_tokens = 0,
+        num_persist_mem_tokens = 0,
         dim_head = 64,
         heads = 8,
         ff_mult = 4,
@@ -147,10 +165,18 @@ class MemoryAsContextTransformer(Module):
     ):
         super().__init__()
 
-        self.segment_len = segment_len
+        self.token_emb = nn.Embedding(num_tokens, dim)
+
         self.axial_pos_emb = ContinuousAxialPositionalEmbedding(dim = dim, num_axial_dims = 2)
 
-        self.token_emb = nn.Embedding(num_tokens, dim)
+        # long term mem tokens
+
+        self.segment_len = segment_len
+        self.num_longterm_mem_tokens = num_longterm_mem_tokens
+
+        self.longterm_mems = nn.Parameter(torch.randn(num_longterm_mem_tokens, dim) * 0.02)
+
+        # hyper conection
 
         init_hyper_conn, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(num_residual_streams, disable = num_residual_streams == 1)
 
@@ -162,6 +188,7 @@ class MemoryAsContextTransformer(Module):
                 dim_head = dim_head,
                 heads = heads,
                 segment_len = segment_len,
+                num_longterm_mem_tokens = num_longterm_mem_tokens,
                 num_persist_mem_tokens = num_persist_mem_tokens
             )
 
 
@@ -177,16 +204,32 @@ class MemoryAsContextTransformer(Module):
177
204
  self.to_logits = LinearNoBias(dim, num_tokens)
178
205
 
179
206
  def forward(self, x):
180
- seq_len, segment_len = x.shape[-1], self.segment_len
207
+
208
+ # math
209
+
210
+ batch, seq_len, segment_len, num_longterm_mem_tokens= *x.shape, self.segment_len, self.num_longterm_mem_tokens
211
+
181
212
  windows = ceil(seq_len / segment_len)
213
+ total_segment_len = segment_len + num_longterm_mem_tokens
214
+
215
+ # token embedding
182
216
 
183
217
  x = self.token_emb(x)
184
218
 
219
+ # intersperse longterm memory
220
+
221
+ x, inverse_segment = pad_and_segment_with_inverse(x, segment_len)
222
+
223
+ mems = repeat(self.longterm_mems, 'n d -> b n d', b = x.shape[0])
224
+ x = torch.cat((mems, x), dim = -2)
225
+
226
+ x = inverse_segment(x)
227
+
185
228
  # apply axial positional embedding
186
229
  # so intra and inter segment can be more easily discerned by the network
187
230
 
188
- pos_emb = self.axial_pos_emb((windows, segment_len), flatten = True)
189
- x = x + pos_emb[:seq_len]
231
+ pos_emb = self.axial_pos_emb((windows, total_segment_len), flatten = True)
232
+ x = x + pos_emb[:x.shape[-2]]
190
233
 
191
234
  # expand and reduce streams for hyper connections
192
235
 
@@ -198,21 +241,16 @@ class MemoryAsContextTransformer(Module):
 
         x = self.reduce_streams(x)
 
-        x = self.norm(x)
+        # excise out the memories
 
-        return self.to_logits(x)
+        x, inverse_segment = pad_and_segment_with_inverse(x, total_segment_len)
 
-# main
+        x = x[:, self.num_longterm_mem_tokens:]
 
-if __name__ == '__main__':
-    transformer = MemoryAsContextTransformer(
-        num_tokens = 256,
-        dim = 256,
-        depth = 2,
-        num_persist_mem_tokens = 16,
-        segment_len = 128,
-    )
+        x = inverse_segment(x)
+
+        # to logits
 
-    x = torch.randint(0, 256, (1, 1023))
+        x = self.norm(x)
 
-    logits = transformer(x)
+        return self.to_logits(x)
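
The release also drops the module-level __main__ demo in favour of the parametrized test shown earlier. A minimal usage sketch of the new option, assuming only the constructor arguments and the 1023-token input that appear in this diff:

    # minimal usage sketch assembled from the test and the removed __main__ block above
    import torch
    from titans_pytorch.mac_transformer import MemoryAsContextTransformer

    transformer = MemoryAsContextTransformer(
        num_tokens = 256,
        dim = 256,
        depth = 2,
        segment_len = 128,
        num_persist_mem_tokens = 16,      # persistent memory tokens (now defaults to 0)
        num_longterm_mem_tokens = 16      # added in this release, interspersed per segment (defaults to 0)
    )

    x = torch.randint(0, 256, (1, 1023))  # (batch, seq_len) of token ids
    logits = transformer(x)               # (1, 1023, 256) - long-term memory tokens are excised before the head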