titans-pytorch 0.0.32__py3-none-any.whl → 0.0.35__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- titans_pytorch/__init__.py +2 -0
- titans_pytorch/mac_transformer.py +25 -4
- {titans_pytorch-0.0.32.dist-info → titans_pytorch-0.0.35.dist-info}/METADATA +1 -1
- titans_pytorch-0.0.35.dist-info/RECORD +9 -0
- titans_pytorch-0.0.32.dist-info/RECORD +0 -9
- {titans_pytorch-0.0.32.dist-info → titans_pytorch-0.0.35.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.0.32.dist-info → titans_pytorch-0.0.35.dist-info}/licenses/LICENSE +0 -0
titans_pytorch/mac_transformer.py
CHANGED
@@ -50,13 +50,20 @@ def pad_and_segment_with_inverse(seq, segment_len):
|
|
50
50
|
next_seq_len_mult = round_up_multiple(seq_len, segment_len)
|
51
51
|
|
52
52
|
padding = next_seq_len_mult - seq_len
|
53
|
-
|
53
|
+
needs_pad = padding > 0
|
54
|
+
|
55
|
+
if needs_pad:
|
56
|
+
seq = F.pad(seq, (0, 0, 0, padding))
|
54
57
|
|
55
58
|
seq = rearrange(seq, 'b (w n) d -> (b w) n d', n = segment_len)
|
56
59
|
|
57
60
|
def inverse(out):
|
58
61
|
out = rearrange(out, '(b w) n d -> b (w n) d', b = batch)
|
59
|
-
|
62
|
+
|
63
|
+
if needs_pad:
|
64
|
+
out = out[:, :-padding]
|
65
|
+
|
66
|
+
return out
|
60
67
|
|
61
68
|
return seq, inverse
|
62
69
|
|
@@ -226,7 +233,14 @@ class MemoryAsContextTransformer(Module):
|
|
226
233
|
|
227
234
|
self.to_logits = LinearNoBias(dim, num_tokens)
|
228
235
|
|
229
|
-
def forward(
|
236
|
+
def forward(
|
237
|
+
self,
|
238
|
+
x,
|
239
|
+
return_loss = False
|
240
|
+
):
|
241
|
+
|
242
|
+
if return_loss:
|
243
|
+
x, labels = x[:, :-1], x[:, 1:]
|
230
244
|
|
231
245
|
# math
|
232
246
|
|
@@ -262,6 +276,7 @@ class MemoryAsContextTransformer(Module):
|
|
262
276
|
|
263
277
|
if exists(maybe_neural_mem):
|
264
278
|
batch_streams = x.shape[0]
|
279
|
+
|
265
280
|
x, inverse_segment = pad_and_segment_with_inverse(x, total_segment_len)
|
266
281
|
|
267
282
|
longterm_mems, x = x[:, :num_longterm_mem_tokens], x[:, num_longterm_mem_tokens:]
|
@@ -277,6 +292,7 @@ class MemoryAsContextTransformer(Module):
|
|
277
292
|
x = inverse_segment(x)
|
278
293
|
|
279
294
|
x = attn(x)
|
295
|
+
|
280
296
|
x = ff(x)
|
281
297
|
|
282
298
|
x = self.reduce_streams(x)
|
@@ -293,4 +309,9 @@ class MemoryAsContextTransformer(Module):
|
|
293
309
|
|
294
310
|
x = self.norm(x)
|
295
311
|
|
296
|
-
|
312
|
+
logits = self.to_logits(x)
|
313
|
+
|
314
|
+
if not return_loss:
|
315
|
+
return logits
|
316
|
+
|
317
|
+
return F.cross_entropy(rearrange(logits, 'b n l -> b l n'), labels)
|
@@ -0,0 +1,9 @@
|
|
1
|
+
titans_pytorch/__init__.py,sha256=F6pV8BamKCsbJFVo5x2hw69vzfJNLy54SwIKIueMdp4,142
|
2
|
+
titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
|
3
|
+
titans_pytorch/mac_transformer.py,sha256=FGShQHD-dQQdQKKzvNS_jTC_FcikdqO_s3ZKOKfr_9E,8502
|
4
|
+
titans_pytorch/titans.py,sha256=Kx_tl_QkeDccvkMwPZ0xQ_saYjZfbKzDNPTTSHNWYcA,14304
|
5
|
+
titans_pytorch/titans_attn_memory.py,sha256=Rwx_-riGeISBefZg5Kjic8jzmmRRys-u93D2Kgb7Mos,12691
|
6
|
+
titans_pytorch-0.0.35.dist-info/METADATA,sha256=jrhx-Bp1LqlOAV3jl4M70WTqq29ciz5lYWJvo2aoPE4,3938
|
7
|
+
titans_pytorch-0.0.35.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
8
|
+
titans_pytorch-0.0.35.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
9
|
+
titans_pytorch-0.0.35.dist-info/RECORD,,
|
@@ -1,9 +0,0 @@
|
|
1
|
-
titans_pytorch/__init__.py,sha256=nB0873FZ_OyCda3qFeWTdpO4LbbrXDEvtAefVLzh6g0,71
|
2
|
-
titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
|
3
|
-
titans_pytorch/mac_transformer.py,sha256=ohZWVhMBtpm0Iz3w5g7pD3WXSXrvhwzZvfRplwhe1Qo,8149
|
4
|
-
titans_pytorch/titans.py,sha256=Kx_tl_QkeDccvkMwPZ0xQ_saYjZfbKzDNPTTSHNWYcA,14304
|
5
|
-
titans_pytorch/titans_attn_memory.py,sha256=Rwx_-riGeISBefZg5Kjic8jzmmRRys-u93D2Kgb7Mos,12691
|
6
|
-
titans_pytorch-0.0.32.dist-info/METADATA,sha256=9X9nWfgIVS-9XIeLHQY53HXSMA6rMemPfyVC2bRrJOQ,3938
|
7
|
-
titans_pytorch-0.0.32.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
8
|
-
titans_pytorch-0.0.32.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
9
|
-
titans_pytorch-0.0.32.dist-info/RECORD,,
|
File without changes
|
File without changes
|