titans-pytorch 0.0.30__py3-none-any.whl → 0.0.31__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- titans_pytorch/mac_transformer.py +38 -30
- {titans_pytorch-0.0.30.dist-info → titans_pytorch-0.0.31.dist-info}/METADATA +1 -1
- {titans_pytorch-0.0.30.dist-info → titans_pytorch-0.0.31.dist-info}/RECORD +5 -5
- {titans_pytorch-0.0.30.dist-info → titans_pytorch-0.0.31.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.0.30.dist-info → titans_pytorch-0.0.31.dist-info}/licenses/LICENSE +0 -0
titans_pytorch/mac_transformer.py

@@ -29,9 +29,33 @@ def exists(v):
 def default(v, d):
     return v if exists(v) else d
 
+def identity(t):
+    return t
+
 def round_up_multiple(seq, mult):
     return ceil(seq / mult) * mult
 
+def pad_and_segment_with_inverse(seq, segment_len):
+    batch, seq_len = seq.shape[:2]
+
+    need_segment = seq_len >= segment_len
+
+    if not need_segment:
+        return seq, identity
+
+    next_seq_len_mult = round_up_multiple(seq_len, segment_len)
+
+    padding = next_seq_len_mult - seq_len
+    seq = F.pad(seq, (0, 0, 0, padding))
+
+    seq = rearrange(seq, 'b (w n) d -> (b w) n d', n = segment_len)
+
+    def inverse(out):
+        out = rearrange(out, '(b w) n d -> b (w n) d', b = batch)
+        return out[:, :-padding]
+
+    return seq, inverse
+
 # feedforward and attention
 
 class GEGLU(Module):
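For context, the sketch below replicates the helper introduced in this hunk and walks a toy tensor through the pad, segment, and inverse round trip. The demo shapes at the bottom are made up for illustration; the imports mirror the ones the diff itself relies on (`ceil`, `F.pad`, `einops.rearrange`).

```python
# Illustrative sketch only: copies the helper added in this diff and shows
# the pad -> segment -> inverse round trip on a toy tensor.
from math import ceil

import torch
import torch.nn.functional as F
from einops import rearrange

def identity(t):
    return t

def round_up_multiple(seq, mult):
    return ceil(seq / mult) * mult

def pad_and_segment_with_inverse(seq, segment_len):
    batch, seq_len = seq.shape[:2]

    # short sequences are returned untouched, with a no-op inverse
    if seq_len < segment_len:
        return seq, identity

    # pad the sequence dimension up to a multiple of the segment length
    padding = round_up_multiple(seq_len, segment_len) - seq_len
    seq = F.pad(seq, (0, 0, 0, padding))

    # (batch, windows * segment_len, dim) -> (batch * windows, segment_len, dim)
    seq = rearrange(seq, 'b (w n) d -> (b w) n d', n = segment_len)

    def inverse(out):
        out = rearrange(out, '(b w) n d -> b (w n) d', b = batch)
        return out[:, :-padding]

    return seq, inverse

x = torch.randn(2, 10, 8)                              # batch 2, 10 tokens, dim 8
segments, inverse_segment = pad_and_segment_with_inverse(x, 4)
print(segments.shape)                                  # torch.Size([6, 4, 8]) - 3 windows per batch
restored = inverse_segment(segments)
print(restored.shape)                                  # torch.Size([2, 10, 8])
print(torch.allclose(restored, x))                     # True - the padding is stripped again
```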
@@ -77,9 +101,6 @@ class SegmentedAttention(Module):
         self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
         self.merge_heads = Rearrange('b h n d -> b n (h d)')
 
-        self.segment_seq = Rearrange('b (n w) d -> (b n) w d', n = total_segment_len)
-        self.merge_seq_back = Rearrange('(b n) w d -> b (n w) d', n = total_segment_len)
-
         self.persistent_memory = nn.Parameter(torch.zeros(2, heads, num_persist_mem_tokens, dim_head))
 
     def forward(self, seq):
@@ -91,16 +112,7 @@ class SegmentedAttention(Module):
         # auto pad to multiple
         # todo - get rid of logic with flex attention
 
-
-
-        if need_segment:
-            next_seq_len = round_up_multiple(seq_len, total_segment_len)
-            padding = next_seq_len - seq_len
-
-            if padding > 0:
-                seq = F.pad(seq, (0, 0, 0, padding))
-
-            seq = self.segment_seq(seq)
+        seq, inverse_segment = pad_and_segment_with_inverse(seq, total_segment_len)
 
         # attention
 
@@ -130,10 +142,9 @@ class SegmentedAttention(Module):
 
         out = self.to_out(out)
 
-
-        out = self.merge_seq_back(out)
+        out = inverse_segment(out)
 
-        return out
+        return out
 
 # MAC transformer
 
@@ -207,29 +218,18 @@ class MemoryAsContextTransformer(Module):
 
         # intersperse longterm memory
 
-
-
-        if need_segment:
-            next_seq_len = round_up_multiple(seq_len, segment_len)
-            padding = next_seq_len - seq_len
-
-            if padding > 0:
-                x = F.pad(x, (0, 0, 0, padding))
-
-            x = rearrange(x, 'b (w n) d -> (b w) n d', n = segment_len)
+        x, inverse_segment = pad_and_segment_with_inverse(x, segment_len)
 
         mems = repeat(self.longterm_mems, 'n d -> b n d', b = x.shape[0])
         x = torch.cat((mems, x), dim = -2)
 
-
-        x = rearrange(x, '(b w) n d -> b (w n) d', b = batch)
-        x = x[:, :seq_len]
+        x = inverse_segment(x)
 
         # apply axial positional embedding
         # so intra and inter segment can be more easily discerned by the network
 
         pos_emb = self.axial_pos_emb((windows, total_segment_len), flatten = True)
-        x = x + pos_emb[:
+        x = x + pos_emb[:x.shape[-2]]
 
         # expand and reduce streams for hyper connections
 
@@ -241,6 +241,14 @@ class MemoryAsContextTransformer(Module):
 
         x = self.reduce_streams(x)
 
+        # excise out the memories
+
+        x, inverse_segment = pad_and_segment_with_inverse(x, total_segment_len)
+
+        x = x[:, self.num_longterm_mem_tokens:]
+
+        x = inverse_segment(x)
+
         # to logits
 
         x = self.norm(x)
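To make the shape bookkeeping of the last two hunks easier to follow, here is a toy walk-through of the intersperse-then-excise pattern: memory tokens are prepended to every window after segmentation, and sliced back off per window near the end of the forward pass. All sizes below are hypothetical, and padding plus positional embeddings are deliberately left out; in the package itself the segmentation goes through `pad_and_segment_with_inverse`.

```python
# Toy illustration of the intersperse / excise pattern, with made-up sizes.
import torch
from einops import rearrange, repeat

batch, seq_len, dim = 2, 8, 16
segment_len, num_longterm_mem_tokens = 4, 2
total_segment_len = segment_len + num_longterm_mem_tokens

x = torch.randn(batch, seq_len, dim)
longterm_mems = torch.zeros(num_longterm_mem_tokens, dim)

# segment into windows: (b, w * n, d) -> (b * w, n, d)
x = rearrange(x, 'b (w n) d -> (b w) n d', n = segment_len)

# intersperse: prepend the same memory tokens to every window
mems = repeat(longterm_mems, 'n d -> b n d', b = x.shape[0])
x = torch.cat((mems, x), dim = -2)
print(x.shape)   # torch.Size([4, 6, 16]) - each window now holds mem + segment tokens

# merge back so the transformer layers see one long sequence per batch
x = rearrange(x, '(b w) n d -> b (w n) d', b = batch)
print(x.shape)   # torch.Size([2, 12, 16])

# ... attention / feedforward layers would run here ...

# excise the memories: re-segment by the longer window, drop the mem tokens per window
x = rearrange(x, 'b (w n) d -> (b w) n d', n = total_segment_len)
x = x[:, num_longterm_mem_tokens:]
x = rearrange(x, '(b w) n d -> b (w n) d', b = batch)
print(x.shape)   # torch.Size([2, 8, 16]) - back to the original token count
```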
{titans_pytorch-0.0.30.dist-info → titans_pytorch-0.0.31.dist-info}/RECORD

@@ -1,9 +1,9 @@
 titans_pytorch/__init__.py,sha256=nB0873FZ_OyCda3qFeWTdpO4LbbrXDEvtAefVLzh6g0,71
 titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/mac_transformer.py,sha256=
+titans_pytorch/mac_transformer.py,sha256=bq5RbCgA0GWLFHTrDTIKUSQhkkuCkdjEykOwjfHDs0M,6747
 titans_pytorch/titans.py,sha256=Kx_tl_QkeDccvkMwPZ0xQ_saYjZfbKzDNPTTSHNWYcA,14304
 titans_pytorch/titans_attn_memory.py,sha256=Rwx_-riGeISBefZg5Kjic8jzmmRRys-u93D2Kgb7Mos,12691
-titans_pytorch-0.0.
-titans_pytorch-0.0.
-titans_pytorch-0.0.
-titans_pytorch-0.0.
+titans_pytorch-0.0.31.dist-info/METADATA,sha256=bN1fVL2S_vML1oqLIA92tvBhkVvnpQN11fU4e1QVI4s,3938
+titans_pytorch-0.0.31.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.0.31.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.0.31.dist-info/RECORD,,
{titans_pytorch-0.0.30.dist-info → titans_pytorch-0.0.31.dist-info}/WHEEL: File without changes
{titans_pytorch-0.0.30.dist-info → titans_pytorch-0.0.31.dist-info}/licenses/LICENSE: File without changes