titans-pytorch 0.0.30__py3-none-any.whl → 0.0.31__py3-none-any.whl
- titans_pytorch/mac_transformer.py +38 -30
- {titans_pytorch-0.0.30.dist-info → titans_pytorch-0.0.31.dist-info}/METADATA +1 -1
- {titans_pytorch-0.0.30.dist-info → titans_pytorch-0.0.31.dist-info}/RECORD +5 -5
- {titans_pytorch-0.0.30.dist-info → titans_pytorch-0.0.31.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.0.30.dist-info → titans_pytorch-0.0.31.dist-info}/licenses/LICENSE +0 -0
titans_pytorch/mac_transformer.py

@@ -29,9 +29,33 @@ def exists(v):
 def default(v, d):
     return v if exists(v) else d
 
+def identity(t):
+    return t
+
 def round_up_multiple(seq, mult):
     return ceil(seq / mult) * mult
 
+def pad_and_segment_with_inverse(seq, segment_len):
+    batch, seq_len = seq.shape[:2]
+
+    need_segment = seq_len >= segment_len
+
+    if not need_segment:
+        return seq, identity
+
+    next_seq_len_mult = round_up_multiple(seq_len, segment_len)
+
+    padding = next_seq_len_mult - seq_len
+    seq = F.pad(seq, (0, 0, 0, padding))
+
+    seq = rearrange(seq, 'b (w n) d -> (b w) n d', n = segment_len)
+
+    def inverse(out):
+        out = rearrange(out, '(b w) n d -> b (w n) d', b = batch)
+        return out[:, :-padding]
+
+    return seq, inverse
+
 # feedforward and attention
 
 class GEGLU(Module):
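The hunk above factors the inline pad-and-segment logic into a single `pad_and_segment_with_inverse` helper, which returns the windowed sequence together with a closure that undoes the windowing and padding. A minimal sketch of that round trip, assuming titans-pytorch 0.0.31 is installed and importing the module-level helper directly (it is an internal function, not part of the documented API):

```python
import torch
from titans_pytorch.mac_transformer import pad_and_segment_with_inverse

# length 10 is deliberately not a multiple of the segment length 4,
# since the returned inverse trims the padding with out[:, :-padding]
x = torch.randn(2, 10, 16)

segments, inverse_segment = pad_and_segment_with_inverse(x, 4)
print(segments.shape)   # torch.Size([6, 4, 16]) - 10 padded to 12, i.e. 3 windows per batch

restored = inverse_segment(segments)
assert restored.shape == x.shape and torch.allclose(restored, x)
```

Returning the inverse as a closure keeps the padding amount and original batch size out of the callers, which is what lets the same helper replace the duplicated logic in both SegmentedAttention and MemoryAsContextTransformer in the hunks below.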
@@ -77,9 +101,6 @@ class SegmentedAttention(Module):
         self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
         self.merge_heads = Rearrange('b h n d -> b n (h d)')
 
-        self.segment_seq = Rearrange('b (n w) d -> (b n) w d', n = total_segment_len)
-        self.merge_seq_back = Rearrange('(b n) w d -> b (n w) d', n = total_segment_len)
-
         self.persistent_memory = nn.Parameter(torch.zeros(2, heads, num_persist_mem_tokens, dim_head))
 
     def forward(self, seq):
@@ -91,16 +112,7 @@ class SegmentedAttention(Module):
         # auto pad to multiple
         # todo - get rid of logic with flex attention
 
-
-
-        if need_segment:
-            next_seq_len = round_up_multiple(seq_len, total_segment_len)
-            padding = next_seq_len - seq_len
-
-            if padding > 0:
-                seq = F.pad(seq, (0, 0, 0, padding))
-
-            seq = self.segment_seq(seq)
+        seq, inverse_segment = pad_and_segment_with_inverse(seq, total_segment_len)
 
         # attention
 
@@ -130,10 +142,9 @@ class SegmentedAttention(Module):
 
         out = self.to_out(out)
 
-
-        out = self.merge_seq_back(out)
+        out = inverse_segment(out)
 
-        return out
+        return out
 
 # MAC transformer
 
@@ -207,29 +218,18 @@ class MemoryAsContextTransformer(Module):
 
         # intersperse longterm memory
 
-
-
-        if need_segment:
-            next_seq_len = round_up_multiple(seq_len, segment_len)
-            padding = next_seq_len - seq_len
-
-            if padding > 0:
-                x = F.pad(x, (0, 0, 0, padding))
-
-            x = rearrange(x, 'b (w n) d -> (b w) n d', n = segment_len)
+        x, inverse_segment = pad_and_segment_with_inverse(x, segment_len)
 
         mems = repeat(self.longterm_mems, 'n d -> b n d', b = x.shape[0])
         x = torch.cat((mems, x), dim = -2)
 
-
-        x = rearrange(x, '(b w) n d -> b (w n) d', b = batch)
-        x = x[:, :seq_len]
+        x = inverse_segment(x)
 
         # apply axial positional embedding
         # so intra and inter segment can be more easily discerned by the network
 
         pos_emb = self.axial_pos_emb((windows, total_segment_len), flatten = True)
-        x = x + pos_emb[:seq_len]
+        x = x + pos_emb[:x.shape[-2]]
 
         # expand and reduce streams for hyper connections
 
@@ -241,6 +241,14 @@ class MemoryAsContextTransformer(Module):
 
         x = self.reduce_streams(x)
 
+        # excise out the memories
+
+        x, inverse_segment = pad_and_segment_with_inverse(x, total_segment_len)
+
+        x = x[:, self.num_longterm_mem_tokens:]
+
+        x = inverse_segment(x)
+
         # to logits
 
         x = self.norm(x)
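The block added above segments the post-transformer stream, drops the first `num_longterm_mem_tokens` tokens of every window, and then undoes the segmentation. A small shape walk-through of that excision pattern using plain einops, with illustrative sizes and the simplifying assumption that the sequence length is already an exact multiple of `total_segment_len` (so padding can be ignored):

```python
import torch
from einops import rearrange

batch, windows, segment_len, num_longterm_mem_tokens, dim = 2, 3, 4, 2, 8
total_segment_len = segment_len + num_longterm_mem_tokens

# a stream in which every window carries its long-term memory tokens up front
x = torch.randn(batch, windows * total_segment_len, dim)

# segment into windows, excise the leading memory tokens of each window, merge back
x = rearrange(x, 'b (w n) d -> (b w) n d', n = total_segment_len)
x = x[:, num_longterm_mem_tokens:]
x = rearrange(x, '(b w) n d -> b (w n) d', b = batch)

assert x.shape == (batch, windows * segment_len, dim)
```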
{titans_pytorch-0.0.30.dist-info → titans_pytorch-0.0.31.dist-info}/RECORD

@@ -1,9 +1,9 @@
 titans_pytorch/__init__.py,sha256=nB0873FZ_OyCda3qFeWTdpO4LbbrXDEvtAefVLzh6g0,71
 titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/mac_transformer.py,sha256=
+titans_pytorch/mac_transformer.py,sha256=bq5RbCgA0GWLFHTrDTIKUSQhkkuCkdjEykOwjfHDs0M,6747
 titans_pytorch/titans.py,sha256=Kx_tl_QkeDccvkMwPZ0xQ_saYjZfbKzDNPTTSHNWYcA,14304
 titans_pytorch/titans_attn_memory.py,sha256=Rwx_-riGeISBefZg5Kjic8jzmmRRys-u93D2Kgb7Mos,12691
-titans_pytorch-0.0.
-titans_pytorch-0.0.
-titans_pytorch-0.0.
-titans_pytorch-0.0.
+titans_pytorch-0.0.31.dist-info/METADATA,sha256=bN1fVL2S_vML1oqLIA92tvBhkVvnpQN11fU4e1QVI4s,3938
+titans_pytorch-0.0.31.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.0.31.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.0.31.dist-info/RECORD,,