titans-pytorch 0.0.32__py3-none-any.whl → 0.0.35__py3-none-any.whl

Sign up for free protection for your applications and access to all features.

Potentially problematic release.


This version of titans-pytorch might be problematic. Click here for more details.

@@ -2,3 +2,5 @@ from titans_pytorch.titans import (
2
2
  NeuralMemory,
3
3
  MemoryMLP,
4
4
  )
5
+
6
+ from titans_pytorch.mac_transformer import MemoryAsContextTransformer
@@ -50,13 +50,20 @@ def pad_and_segment_with_inverse(seq, segment_len):
50
50
  next_seq_len_mult = round_up_multiple(seq_len, segment_len)
51
51
 
52
52
  padding = next_seq_len_mult - seq_len
53
- seq = F.pad(seq, (0, 0, 0, padding))
53
+ needs_pad = padding > 0
54
+
55
+ if needs_pad:
56
+ seq = F.pad(seq, (0, 0, 0, padding))
54
57
 
55
58
  seq = rearrange(seq, 'b (w n) d -> (b w) n d', n = segment_len)
56
59
 
57
60
  def inverse(out):
58
61
  out = rearrange(out, '(b w) n d -> b (w n) d', b = batch)
59
- return out[:, :-padding]
62
+
63
+ if needs_pad:
64
+ out = out[:, :-padding]
65
+
66
+ return out
60
67
 
61
68
  return seq, inverse
62
69
 
@@ -226,7 +233,14 @@ class MemoryAsContextTransformer(Module):
226
233
 
227
234
  self.to_logits = LinearNoBias(dim, num_tokens)
228
235
 
229
- def forward(self, x):
236
+ def forward(
237
+ self,
238
+ x,
239
+ return_loss = False
240
+ ):
241
+
242
+ if return_loss:
243
+ x, labels = x[:, :-1], x[:, 1:]
230
244
 
231
245
  # math
232
246
 
@@ -262,6 +276,7 @@ class MemoryAsContextTransformer(Module):
262
276
 
263
277
  if exists(maybe_neural_mem):
264
278
  batch_streams = x.shape[0]
279
+
265
280
  x, inverse_segment = pad_and_segment_with_inverse(x, total_segment_len)
266
281
 
267
282
  longterm_mems, x = x[:, :num_longterm_mem_tokens], x[:, num_longterm_mem_tokens:]
@@ -277,6 +292,7 @@ class MemoryAsContextTransformer(Module):
277
292
  x = inverse_segment(x)
278
293
 
279
294
  x = attn(x)
295
+
280
296
  x = ff(x)
281
297
 
282
298
  x = self.reduce_streams(x)
@@ -293,4 +309,9 @@ class MemoryAsContextTransformer(Module):
293
309
 
294
310
  x = self.norm(x)
295
311
 
296
- return self.to_logits(x)
312
+ logits = self.to_logits(x)
313
+
314
+ if not return_loss:
315
+ return logits
316
+
317
+ return F.cross_entropy(rearrange(logits, 'b n l -> b l n'), labels)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: titans-pytorch
3
- Version: 0.0.32
3
+ Version: 0.0.35
4
4
  Summary: Titans
5
5
  Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
6
6
  Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -0,0 +1,9 @@
1
+ titans_pytorch/__init__.py,sha256=F6pV8BamKCsbJFVo5x2hw69vzfJNLy54SwIKIueMdp4,142
2
+ titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
3
+ titans_pytorch/mac_transformer.py,sha256=FGShQHD-dQQdQKKzvNS_jTC_FcikdqO_s3ZKOKfr_9E,8502
4
+ titans_pytorch/titans.py,sha256=Kx_tl_QkeDccvkMwPZ0xQ_saYjZfbKzDNPTTSHNWYcA,14304
5
+ titans_pytorch/titans_attn_memory.py,sha256=Rwx_-riGeISBefZg5Kjic8jzmmRRys-u93D2Kgb7Mos,12691
6
+ titans_pytorch-0.0.35.dist-info/METADATA,sha256=jrhx-Bp1LqlOAV3jl4M70WTqq29ciz5lYWJvo2aoPE4,3938
7
+ titans_pytorch-0.0.35.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
8
+ titans_pytorch-0.0.35.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
9
+ titans_pytorch-0.0.35.dist-info/RECORD,,
@@ -1,9 +0,0 @@
1
- titans_pytorch/__init__.py,sha256=nB0873FZ_OyCda3qFeWTdpO4LbbrXDEvtAefVLzh6g0,71
2
- titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
3
- titans_pytorch/mac_transformer.py,sha256=ohZWVhMBtpm0Iz3w5g7pD3WXSXrvhwzZvfRplwhe1Qo,8149
4
- titans_pytorch/titans.py,sha256=Kx_tl_QkeDccvkMwPZ0xQ_saYjZfbKzDNPTTSHNWYcA,14304
5
- titans_pytorch/titans_attn_memory.py,sha256=Rwx_-riGeISBefZg5Kjic8jzmmRRys-u93D2Kgb7Mos,12691
6
- titans_pytorch-0.0.32.dist-info/METADATA,sha256=9X9nWfgIVS-9XIeLHQY53HXSMA6rMemPfyVC2bRrJOQ,3938
7
- titans_pytorch-0.0.32.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
8
- titans_pytorch-0.0.32.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
9
- titans_pytorch-0.0.32.dist-info/RECORD,,