titans-pytorch 0.0.32__py3-none-any.whl → 0.0.35__py3-none-any.whl
Potentially problematic release: this version of titans-pytorch might be problematic.
- titans_pytorch/__init__.py +2 -0
- titans_pytorch/mac_transformer.py +25 -4
- {titans_pytorch-0.0.32.dist-info → titans_pytorch-0.0.35.dist-info}/METADATA +1 -1
- titans_pytorch-0.0.35.dist-info/RECORD +9 -0
- titans_pytorch-0.0.32.dist-info/RECORD +0 -9
- {titans_pytorch-0.0.32.dist-info → titans_pytorch-0.0.35.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.0.32.dist-info → titans_pytorch-0.0.35.dist-info}/licenses/LICENSE +0 -0
titans_pytorch/__init__.py
CHANGED
titans_pytorch/mac_transformer.py
CHANGED
@@ -50,13 +50,20 @@ def pad_and_segment_with_inverse(seq, segment_len):
     next_seq_len_mult = round_up_multiple(seq_len, segment_len)
 
     padding = next_seq_len_mult - seq_len
-    seq = F.pad(seq, (0, 0, 0, padding))
+    needs_pad = padding > 0
+
+    if needs_pad:
+        seq = F.pad(seq, (0, 0, 0, padding))
 
     seq = rearrange(seq, 'b (w n) d -> (b w) n d', n = segment_len)
 
     def inverse(out):
         out = rearrange(out, '(b w) n d -> b (w n) d', b = batch)
-        return out[:, :-padding]
+
+        if needs_pad:
+            out = out[:, :-padding]
+
+        return out
 
     return seq, inverse
 
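The needs_pad guard matters because a slice with a negative zero bound is an empty slice: when the sequence length is already a multiple of segment_len, padding is 0, and an unguarded out[:, :-padding] drops the whole sequence dimension. A minimal standalone sketch of that pitfall (illustrative only, not code from the package):

    import torch

    out = torch.randn(2, 8, 4)       # (batch, seq, dim) with seq already a multiple of segment_len
    padding = 0

    print(out[:, :-padding].shape)   # torch.Size([2, 0, 4]) -- everything sliced away

    needs_pad = padding > 0          # the guard introduced in 0.0.35
    if needs_pad:
        out = out[:, :-padding]
    print(out.shape)                 # torch.Size([2, 8, 4]) -- sequence preserved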
@@ -226,7 +233,14 @@ class MemoryAsContextTransformer(Module):
 
         self.to_logits = LinearNoBias(dim, num_tokens)
 
-    def forward(self, x):
+    def forward(
+        self,
+        x,
+        return_loss = False
+    ):
+
+        if return_loss:
+            x, labels = x[:, :-1], x[:, 1:]
 
         # math
 
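With return_loss = True, the incoming token ids are split into model inputs and next-token targets by shifting one position. A small standalone illustration of that split, using made-up token ids (not code from the package):

    import torch

    ids = torch.tensor([[5, 2, 9, 7, 1]])    # hypothetical token ids

    x, labels = ids[:, :-1], ids[:, 1:]
    print(x)        # tensor([[5, 2, 9, 7]]) -- fed through the transformer
    print(labels)   # tensor([[2, 9, 7, 1]]) -- target token at each position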
@@ -262,6 +276,7 @@ class MemoryAsContextTransformer(Module):
 
             if exists(maybe_neural_mem):
                 batch_streams = x.shape[0]
+
                 x, inverse_segment = pad_and_segment_with_inverse(x, total_segment_len)
 
                 longterm_mems, x = x[:, :num_longterm_mem_tokens], x[:, num_longterm_mem_tokens:]
@@ -277,6 +292,7 @@ class MemoryAsContextTransformer(Module):
                 x = inverse_segment(x)
 
             x = attn(x)
+
             x = ff(x)
 
         x = self.reduce_streams(x)
@@ -293,4 +309,9 @@ class MemoryAsContextTransformer(Module):
 
         x = self.norm(x)
 
-        return self.to_logits(x)
+        logits = self.to_logits(x)
+
+        if not return_loss:
+            return logits
+
+        return F.cross_entropy(rearrange(logits, 'b n l -> b l n'), labels)
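Taken together, forward now returns raw logits by default and a cross entropy loss over the shifted targets when return_loss = True. A hedged usage sketch follows; the constructor arguments (num_tokens, dim, depth, segment_len) are assumptions about the class's init signature, not something shown in this diff:

    import torch
    from titans_pytorch.mac_transformer import MemoryAsContextTransformer

    # assumed constructor arguments -- check the class definition for the actual signature
    transformer = MemoryAsContextTransformer(
        num_tokens = 256,
        dim = 64,
        depth = 2,
        segment_len = 32,
    )

    ids = torch.randint(0, 256, (1, 512))

    loss = transformer(ids, return_loss = True)   # scalar cross entropy loss, new in this version
    loss.backward()

    logits = transformer(ids)                     # default path returns per-token logits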
titans_pytorch-0.0.35.dist-info/RECORD
ADDED
@@ -0,0 +1,9 @@
+titans_pytorch/__init__.py,sha256=F6pV8BamKCsbJFVo5x2hw69vzfJNLy54SwIKIueMdp4,142
+titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+titans_pytorch/mac_transformer.py,sha256=FGShQHD-dQQdQKKzvNS_jTC_FcikdqO_s3ZKOKfr_9E,8502
+titans_pytorch/titans.py,sha256=Kx_tl_QkeDccvkMwPZ0xQ_saYjZfbKzDNPTTSHNWYcA,14304
+titans_pytorch/titans_attn_memory.py,sha256=Rwx_-riGeISBefZg5Kjic8jzmmRRys-u93D2Kgb7Mos,12691
+titans_pytorch-0.0.35.dist-info/METADATA,sha256=jrhx-Bp1LqlOAV3jl4M70WTqq29ciz5lYWJvo2aoPE4,3938
+titans_pytorch-0.0.35.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.0.35.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.0.35.dist-info/RECORD,,
titans_pytorch-0.0.32.dist-info/RECORD
DELETED
@@ -1,9 +0,0 @@
-titans_pytorch/__init__.py,sha256=nB0873FZ_OyCda3qFeWTdpO4LbbrXDEvtAefVLzh6g0,71
-titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/mac_transformer.py,sha256=ohZWVhMBtpm0Iz3w5g7pD3WXSXrvhwzZvfRplwhe1Qo,8149
-titans_pytorch/titans.py,sha256=Kx_tl_QkeDccvkMwPZ0xQ_saYjZfbKzDNPTTSHNWYcA,14304
-titans_pytorch/titans_attn_memory.py,sha256=Rwx_-riGeISBefZg5Kjic8jzmmRRys-u93D2Kgb7Mos,12691
-titans_pytorch-0.0.32.dist-info/METADATA,sha256=9X9nWfgIVS-9XIeLHQY53HXSMA6rMemPfyVC2bRrJOQ,3938
-titans_pytorch-0.0.32.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.0.32.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.0.32.dist-info/RECORD,,
{titans_pytorch-0.0.32.dist-info → titans_pytorch-0.0.35.dist-info}/WHEEL
File without changes
{titans_pytorch-0.0.32.dist-info → titans_pytorch-0.0.35.dist-info}/licenses/LICENSE
File without changes