titans-pytorch 0.0.29-py3-none-any.whl → 0.0.31-py3-none-any.whl

--- titans_pytorch/mac_transformer.py (0.0.29)
+++ titans_pytorch/mac_transformer.py (0.0.31)
@@ -7,10 +7,9 @@ from torch import nn, cat
 import torch.nn.functional as F
 from torch.nn import Module, ModuleList, Linear
 
-from einops import repeat
+from einops import repeat, rearrange
 from einops.layers.torch import Rearrange
 
-
 from hyper_connections import get_init_and_expand_reduce_stream_functions
 
 # absolute and relative positions
@@ -30,9 +29,33 @@ def exists(v):
 def default(v, d):
     return v if exists(v) else d
 
+def identity(t):
+    return t
+
 def round_up_multiple(seq, mult):
     return ceil(seq / mult) * mult
 
+def pad_and_segment_with_inverse(seq, segment_len):
+    batch, seq_len = seq.shape[:2]
+
+    need_segment = seq_len >= segment_len
+
+    if not need_segment:
+        return seq, identity
+
+    next_seq_len_mult = round_up_multiple(seq_len, segment_len)
+
+    padding = next_seq_len_mult - seq_len
+    seq = F.pad(seq, (0, 0, 0, padding))
+
+    seq = rearrange(seq, 'b (w n) d -> (b w) n d', n = segment_len)
+
+    def inverse(out):
+        out = rearrange(out, '(b w) n d -> b (w n) d', b = batch)
+        return out[:, :-padding]
+
+    return seq, inverse
+
 # feedforward and attention
 
 class GEGLU(Module):
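
The `pad_and_segment_with_inverse` helper added above replaces the ad-hoc padding and segmenting logic that used to live inside `SegmentedAttention.forward`: it pads a `(batch, seq, dim)` tensor up to a multiple of `segment_len`, folds the windows into the batch dimension, and returns a closure that undoes both steps. A minimal roundtrip sketch, assuming titans-pytorch 0.0.31 is installed; the import path and tensor sizes are illustrative assumptions, not part of the diff:

```python
import torch
# module-level helper per the hunk above; the import path is assumed from the file name
from titans_pytorch.mac_transformer import pad_and_segment_with_inverse

x = torch.randn(2, 10, 4)                # (batch, seq, dim); 10 is deliberately not a multiple of 8
segments, inverse_segment = pad_and_segment_with_inverse(x, 8)

print(segments.shape)                    # torch.Size([4, 8, 4]): padded 10 -> 16, two windows folded into batch
print(inverse_segment(segments).shape)   # torch.Size([2, 10, 4]): windows unfolded, padding sliced back off
```

One behavioural detail: when `seq_len` is already an exact multiple of `segment_len`, `padding` is 0 and the closure's `out[:, :-padding]` slices the sequence down to nothing, so the roundtrip above uses a length that actually needs padding.
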
@@ -55,7 +78,8 @@ class SegmentedAttention(Module):
         self,
         dim,
         segment_len,
-        num_persist_mem_tokens,
+        num_persist_mem_tokens = 0,
+        num_longterm_mem_tokens = 0,
         dim_head = 64,
         heads = 8,
     ):
@@ -70,31 +94,25 @@ class SegmentedAttention(Module):
         self.to_out = LinearNoBias(dim_inner, dim)
 
         self.segment_len = segment_len
+        self.num_longterm_mem_tokens = num_longterm_mem_tokens
+
+        total_segment_len = segment_len + num_longterm_mem_tokens
 
         self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
         self.merge_heads = Rearrange('b h n d -> b n (h d)')
 
-        self.segment_seq = Rearrange('b (n w) d -> (b n) w d', n = segment_len)
-        self.merge_seq_back = Rearrange('(b n) w d -> b (n w) d', n = segment_len)
-
         self.persistent_memory = nn.Parameter(torch.zeros(2, heads, num_persist_mem_tokens, dim_head))
 
     def forward(self, seq):
+        segment_len, num_longterm_mem_tokens = self.segment_len, self.num_longterm_mem_tokens
+        total_segment_len = segment_len + num_longterm_mem_tokens
+
         batch, seq_len = seq.shape[:2]
 
         # auto pad to multiple
         # todo - get rid of logic with flex attention
 
-        need_segment = seq_len >= self.segment_len
-
-        if need_segment:
-            next_seq_len = round_up_multiple(seq_len, self.segment_len)
-            padding = next_seq_len - seq_len
-
-            if padding > 0:
-                seq = F.pad(seq, (0, 0, 0, padding))
-
-            seq = self.segment_seq(seq)
+        seq, inverse_segment = pad_and_segment_with_inverse(seq, total_segment_len)
 
         # attention
 
@@ -124,10 +142,9 @@ class SegmentedAttention(Module):
 
         out = self.to_out(out)
 
-        if need_segment:
-            out = self.merge_seq_back(out)
+        out = inverse_segment(out)
 
-        return out[:, :seq_len]
+        return out
 
 # MAC transformer
 
@@ -139,7 +156,8 @@ class MemoryAsContextTransformer(Module):
         dim,
         depth,
         segment_len,
-        num_persist_mem_tokens,
+        num_longterm_mem_tokens = 0,
+        num_persist_mem_tokens = 0,
         dim_head = 64,
         heads = 8,
         ff_mult = 4,
@@ -147,10 +165,18 @@ class MemoryAsContextTransformer(Module):
     ):
         super().__init__()
 
-        self.segment_len = segment_len
+        self.token_emb = nn.Embedding(num_tokens, dim)
+
         self.axial_pos_emb = ContinuousAxialPositionalEmbedding(dim = dim, num_axial_dims = 2)
 
-        self.token_emb = nn.Embedding(num_tokens, dim)
+        # long term mem tokens
+
+        self.segment_len = segment_len
+        self.num_longterm_mem_tokens = num_longterm_mem_tokens
+
+        self.longterm_mems = nn.Parameter(torch.randn(num_longterm_mem_tokens, dim) * 0.02)
+
+        # hyper conection
 
         init_hyper_conn, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(num_residual_streams, disable = num_residual_streams == 1)
 
@@ -162,6 +188,7 @@ class MemoryAsContextTransformer(Module):
                 dim_head = dim_head,
                 heads = heads,
                 segment_len = segment_len,
+                num_longterm_mem_tokens = num_longterm_mem_tokens,
                 num_persist_mem_tokens = num_persist_mem_tokens
             )
 
@@ -177,16 +204,32 @@ class MemoryAsContextTransformer(Module):
         self.to_logits = LinearNoBias(dim, num_tokens)
 
     def forward(self, x):
-        seq_len, segment_len = x.shape[-1], self.segment_len
+
+        # math
+
+        batch, seq_len, segment_len, num_longterm_mem_tokens = *x.shape, self.segment_len, self.num_longterm_mem_tokens
+
         windows = ceil(seq_len / segment_len)
+        total_segment_len = segment_len + num_longterm_mem_tokens
+
+        # token embedding
 
         x = self.token_emb(x)
 
+        # intersperse longterm memory
+
+        x, inverse_segment = pad_and_segment_with_inverse(x, segment_len)
+
+        mems = repeat(self.longterm_mems, 'n d -> b n d', b = x.shape[0])
+        x = torch.cat((mems, x), dim = -2)
+
+        x = inverse_segment(x)
+
         # apply axial positional embedding
         # so intra and inter segment can be more easily discerned by the network
 
-        pos_emb = self.axial_pos_emb((windows, segment_len), flatten = True)
-        x = x + pos_emb[:seq_len]
+        pos_emb = self.axial_pos_emb((windows, total_segment_len), flatten = True)
+        x = x + pos_emb[:x.shape[-2]]
 
         # expand and reduce streams for hyper connections
 
@@ -198,21 +241,16 @@ class MemoryAsContextTransformer(Module):
 
         x = self.reduce_streams(x)
 
-        x = self.norm(x)
+        # excise out the memories
 
-        return self.to_logits(x)
+        x, inverse_segment = pad_and_segment_with_inverse(x, total_segment_len)
 
-# main
+        x = x[:, self.num_longterm_mem_tokens:]
 
-if __name__ == '__main__':
-    transformer = MemoryAsContextTransformer(
-        num_tokens = 256,
-        dim = 256,
-        depth = 2,
-        num_persist_mem_tokens = 16,
-        segment_len = 128,
-    )
+        x = inverse_segment(x)
+
+        # to logits
 
-    x = torch.randint(0, 256, (1, 1023))
+        x = self.norm(x)
 
-    logits = transformer(x)
+        return self.to_logits(x)
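
This release also drops the `if __name__ == '__main__'` smoke test that used to sit at the bottom of `mac_transformer.py`. For reference, a usage sketch in the same spirit, updated for the new `num_longterm_mem_tokens` argument; the hyperparameters mirror the removed demo, the memory-token count is an illustrative value, and the import path is an assumption based on the file layout:

```python
import torch
from titans_pytorch.mac_transformer import MemoryAsContextTransformer

transformer = MemoryAsContextTransformer(
    num_tokens = 256,
    dim = 256,
    depth = 2,
    segment_len = 128,
    num_persist_mem_tokens = 16,
    num_longterm_mem_tokens = 4,    # new in 0.0.31; defaults to 0
)

# 1023 is intentionally not a multiple of segment_len, matching the removed demo
x = torch.randint(0, 256, (1, 1023))

logits = transformer(x)             # (1, 1023, 256): memory tokens are excised before the logits
print(logits.shape)
```

Per window, the model prepends the learned `longterm_mems` vectors, attends over `segment_len + num_longterm_mem_tokens` positions, and strips the memory positions back out before the final norm and output projection, as the diff above shows.
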
--- titans_pytorch-0.0.29.dist-info/METADATA
+++ titans_pytorch-0.0.31.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.0.29
+Version: 0.0.31
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
--- titans_pytorch-0.0.29.dist-info/RECORD
+++ titans_pytorch-0.0.31.dist-info/RECORD
@@ -1,9 +1,9 @@
 titans_pytorch/__init__.py,sha256=nB0873FZ_OyCda3qFeWTdpO4LbbrXDEvtAefVLzh6g0,71
 titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/mac_transformer.py,sha256=JBrJah7gfQDPizYRcBvpUKinrd2I9KMB997f3RIR8TA,5568
+titans_pytorch/mac_transformer.py,sha256=bq5RbCgA0GWLFHTrDTIKUSQhkkuCkdjEykOwjfHDs0M,6747
 titans_pytorch/titans.py,sha256=Kx_tl_QkeDccvkMwPZ0xQ_saYjZfbKzDNPTTSHNWYcA,14304
 titans_pytorch/titans_attn_memory.py,sha256=Rwx_-riGeISBefZg5Kjic8jzmmRRys-u93D2Kgb7Mos,12691
-titans_pytorch-0.0.29.dist-info/METADATA,sha256=EhS4E9SAoqzDa0PIjZpQmUSYAo5IS-XePofWlZZnIS0,3938
-titans_pytorch-0.0.29.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.0.29.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.0.29.dist-info/RECORD,,
+titans_pytorch-0.0.31.dist-info/METADATA,sha256=bN1fVL2S_vML1oqLIA92tvBhkVvnpQN11fU4e1QVI4s,3938
+titans_pytorch-0.0.31.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.0.31.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.0.31.dist-info/RECORD,,