titans-pytorch 0.1.15__py3-none-any.whl → 0.1.17__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only. The substantive change in 0.1.17 is in titans_pytorch/mac_transformer.py: MemoryAsContextTransformer.sample() gains an opt-in use_cache flag, backed by new cache, return_cache, and factorized_pos_emb arguments on forward().
titans_pytorch/mac_transformer.py

@@ -537,7 +537,8 @@ class MemoryAsContextTransformer(Module):
         filter_kwargs: dict = dict(
             min_p = 0.1,
         ),
-        show_progress = True
+        show_progress = True,
+        use_cache = False
     ):
         was_training = self.training
         self.eval()
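
The change above gives sample() an opt-in use_cache flag; it defaults to False, so 0.1.15 call sites behave identically. A minimal usage sketch follows; the constructor keyword arguments and the sample(prompt, seq_len) calling convention are taken from the project README rather than from this diff, so treat them as illustrative:

    import torch
    from titans_pytorch import MemoryAsContextTransformer

    transformer = MemoryAsContextTransformer(
        num_tokens = 256,
        dim = 256,
        depth = 2,
        segment_len = 128,
        num_persist_mem_tokens = 4,
        num_longterm_mem_tokens = 16,
    )

    prompt = torch.randint(0, 256, (1, 4))

    sampled = transformer.sample(prompt, 512)                    # old behavior: recompute the prefix each step
    sampled = transformer.sample(prompt, 512, use_cache = True)  # new in 0.1.17: cached decoding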
@@ -547,8 +548,37 @@ class MemoryAsContextTransformer(Module):
 
         iter_wrap = tqdm.tqdm if show_progress else identity
 
+        # cache for axial pos, attention, and neural memory
+
+        cache = None
+        factorized_pos_emb = None
+
+        # precompute factorized pos emb
+
+        if use_cache:
+            round_up_seq_len = round_up_multiple(seq_len, self.segment_len)
+            longterm_mem_lens = (round_up_seq_len // self.segment_len) * self.num_longterm_mem_tokens
+            seq_len_with_mem = round_up_seq_len + longterm_mem_lens
+
+            axial_dims = self.axial_pos_emb.maybe_derive_outer_dim(seq_len_with_mem, (self.neural_memory_segment_len,))
+
+            factorized_pos_emb = self.axial_pos_emb(axial_dims, return_factorized = True)
+
+        # sample
+
         for _ in iter_wrap(range(sample_num_times)):
-            logits = self.forward(out, disable_flex_attn = True)
+
+            logits, next_cache = self.forward(
+                out,
+                disable_flex_attn = True,
+                cache = cache,
+                return_cache = True,
+                factorized_pos_emb = factorized_pos_emb
+            )
+
+            if use_cache:
+                cache = next_cache
+
             logits = logits[:, -1]
 
             logits = filter_fn(logits, **filter_kwargs)
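
The precompute block above is plain integer bookkeeping: round the requested length up to whole segments, then add the long-term memory tokens that get interleaved per segment. A worked sketch of that arithmetic, with illustrative values and an assumed (but conventional) round_up_multiple body:

    def round_up_multiple(n, mult):
        # round n up to the next multiple of mult, e.g. 500 -> 512 for mult = 128
        return ((n + mult - 1) // mult) * mult

    segment_len = 128               # illustrative
    num_longterm_mem_tokens = 16    # illustrative
    seq_len = 500

    round_up_seq_len = round_up_multiple(seq_len, segment_len)                       # 512
    longterm_mem_lens = (round_up_seq_len // segment_len) * num_longterm_mem_tokens  # 4 segments * 16 = 64
    seq_len_with_mem = round_up_seq_len + longterm_mem_lens                          # 576

Fixing the total length up front is what makes it safe to build the factorized positional embedding once, before the loop, instead of on every decoding step.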
@@ -565,7 +595,10 @@ class MemoryAsContextTransformer(Module):
         x,
         return_loss = False,
         return_loss_breakdown = False,
-        disable_flex_attn = False
+        disable_flex_attn = False,
+        cache = None,
+        return_cache = False,
+        factorized_pos_emb = None
     ):
 
         if return_loss:
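
forward() now accepts the cache produced by a previous call and can hand back the updated one, which is the contract the new sampling loop depends on. The pattern in isolation, with greedy decoding standing in for the filtered sampling above (transformer as in the earlier sketch):

    token_ids = torch.randint(0, 256, (1, 4))

    # first call primes the cache of attention / neural memory state
    logits, cache = transformer(token_ids, return_cache = True)

    # later calls pass the cache back in, exactly as the sampling loop does
    next_token = logits[:, -1].argmax(dim = -1, keepdim = True)
    token_ids = torch.cat((token_ids, next_token), dim = -1)
    logits, cache = transformer(token_ids, cache = cache, return_cache = True)

Because return_cache defaults to False and cache to None, a plain logits = transformer(token_ids) call from 0.1.15 is untouched.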
@@ -593,7 +626,7 @@ class MemoryAsContextTransformer(Module):
         # apply axial positional embedding
         # so intra and inter segment can be more easily discerned by the network
 
-        pos_emb = self.axial_pos_emb.forward_with_seq_len(seq_len_with_mem, (neural_mem_segment_len,))
+        pos_emb = self.axial_pos_emb.forward_with_seq_len(seq_len_with_mem, (neural_mem_segment_len,), factorized = factorized_pos_emb)
 
         x = x + pos_emb
 
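Why the factorization can be precomputed: an axial positional embedding indexes each position as a (segment, offset) pair, so two small tables cover seq_len_with_mem positions and only their cheap combination depends on the current step. A toy illustration of the idea, not the actual implementation behind self.axial_pos_emb:

    import torch

    dim, inner_len = 64, 16
    seq_len_with_mem = 576
    outer_len = seq_len_with_mem // inner_len   # 36 segments

    outer = torch.randn(outer_len, dim)         # one embedding per segment
    inner = torch.randn(inner_len, dim)         # one embedding per offset within a segment

    # the 'factorized' pieces are outer and inner; they are built once,
    # and combining them per step is a broadcast add plus a reshape
    pos_emb = (outer[:, None, :] + inner[None, :, :]).reshape(-1, dim)[:seq_len_with_mem]

Passing the precomputed pieces in as factorized_pos_emb lets forward_with_seq_len skip rebuilding them on each of the sample_num_times forward passes.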
@@ -651,7 +684,10 @@ class MemoryAsContextTransformer(Module):
         logits = self.to_logits(x)
 
         if not return_loss:
-            return logits
+            if not return_cache:
+                return logits
+
+            return logits, cache
 
         ar_loss = F.cross_entropy(rearrange(logits, 'b n l -> b l n'), labels)
 
titans_pytorch-0.1.15.dist-info/METADATA → titans_pytorch-0.1.17.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.1.15
+Version: 0.1.17
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
titans_pytorch-0.1.17.dist-info/RECORD (added)

@@ -0,0 +1,8 @@
+titans_pytorch/__init__.py,sha256=u0tta_KqhOdfzCEDWT9P4_jejJEK2q1XxhsEzB5MnQU,223
+titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+titans_pytorch/mac_transformer.py,sha256=ajsX-djEUzstig5n99yF_NimRzKNfv0MSz-EIV-Fe1A,20393
+titans_pytorch/titans.py,sha256=R0e25ly2uTHkHSZEb-9Eqb0DqtFq8wFBB8iH1T6bYVg,22240
+titans_pytorch-0.1.17.dist-info/METADATA,sha256=E9nwWCKZLSqT9Mr85nrJQzinYpKZnkLeexeaYyOIqrU,6340
+titans_pytorch-0.1.17.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.1.17.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.1.17.dist-info/RECORD,,
titans_pytorch-0.1.15.dist-info/RECORD (removed)

@@ -1,8 +0,0 @@
-titans_pytorch/__init__.py,sha256=u0tta_KqhOdfzCEDWT9P4_jejJEK2q1XxhsEzB5MnQU,223
-titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/mac_transformer.py,sha256=zxknstaI_Uz47Y8WvZ3S7geJ-TNdqKV5Rvj0Jlw8njs,19271
-titans_pytorch/titans.py,sha256=R0e25ly2uTHkHSZEb-9Eqb0DqtFq8wFBB8iH1T6bYVg,22240
-titans_pytorch-0.1.15.dist-info/METADATA,sha256=SnNsoK4obeOAWFPhQypYJfJWZ_abXKr7WCvLMqFdyg0,6340
-titans_pytorch-0.1.15.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.1.15.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.1.15.dist-info/RECORD,,