titans-pytorch 0.1.15__py3-none-any.whl → 0.1.17__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- titans_pytorch/mac_transformer.py +41 -5
- {titans_pytorch-0.1.15.dist-info → titans_pytorch-0.1.17.dist-info}/METADATA +1 -1
- titans_pytorch-0.1.17.dist-info/RECORD +8 -0
- titans_pytorch-0.1.15.dist-info/RECORD +0 -8
- {titans_pytorch-0.1.15.dist-info → titans_pytorch-0.1.17.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.1.15.dist-info → titans_pytorch-0.1.17.dist-info}/licenses/LICENSE +0 -0
```diff
--- a/titans_pytorch/mac_transformer.py
+++ b/titans_pytorch/mac_transformer.py
@@ -537,7 +537,8 @@ class MemoryAsContextTransformer(Module):
         filter_kwargs: dict = dict(
             min_p = 0.1,
         ),
-        show_progress = True
+        show_progress = True,
+        use_cache = False
     ):
         was_training = self.training
         self.eval()
```
```diff
@@ -547,8 +548,37 @@ class MemoryAsContextTransformer(Module):
 
         iter_wrap = tqdm.tqdm if show_progress else identity
 
+        # cache for axial pos, attention, and neural memory
+
+        cache = None
+        factorized_pos_emb = None
+
+        # precompute factorized pos emb
+
+        if use_cache:
+            round_up_seq_len = round_up_multiple(seq_len, self.segment_len)
+            longterm_mem_lens = (round_up_seq_len // self.segment_len) * self.num_longterm_mem_tokens
+            seq_len_with_mem = round_up_seq_len + longterm_mem_lens
+
+            axial_dims = self.axial_pos_emb.maybe_derive_outer_dim(seq_len_with_mem, (self.neural_memory_segment_len,))
+
+            factorized_pos_emb = self.axial_pos_emb(axial_dims, return_factorized = True)
+
+        # sample
+
         for _ in iter_wrap(range(sample_num_times)):
-            logits = self.forward(out, disable_flex_attn = True)
+
+            logits, next_cache = self.forward(
+                out,
+                disable_flex_attn = True,
+                cache = cache,
+                return_cache = True,
+                factorized_pos_emb = factorized_pos_emb
+            )
+
+            if use_cache:
+                cache = next_cache
+
             logits = logits[:, -1]
 
             logits = filter_fn(logits, **filter_kwargs)
```
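Taken together, the hunks above give `sample()` a `use_cache` flag: the factorized axial positional embedding is precomputed once, and the cache returned by `forward` is carried between decoding steps. A minimal usage sketch follows; the constructor arguments and the top-level import are illustrative assumptions and are not part of this diff — only the `use_cache` and `show_progress` kwargs are.

```python
# Minimal sketch of sampling with the new `use_cache` flag (0.1.17).
# Constructor arguments below are assumptions for illustration, not from this diff.
import torch
from titans_pytorch import MemoryAsContextTransformer  # assuming the top-level export

transformer = MemoryAsContextTransformer(
    num_tokens = 256,               # assumed vocab size
    dim = 256,                      # assumed model width
    depth = 2,
    segment_len = 32,
    num_persist_mem_tokens = 4,
    num_longterm_mem_tokens = 4,
)

prompt = torch.randint(0, 256, (1, 16))

# with use_cache = True, sample() precomputes the factorized axial positional
# embedding once and threads the attention / neural memory cache between steps
sampled = transformer.sample(prompt, 64, use_cache = True, show_progress = True)
```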
```diff
@@ -565,7 +595,10 @@ class MemoryAsContextTransformer(Module):
         x,
         return_loss = False,
         return_loss_breakdown = False,
-        disable_flex_attn = False
+        disable_flex_attn = False,
+        cache = None,
+        return_cache = False,
+        factorized_pos_emb = None
     ):
 
         if return_loss:
```
```diff
@@ -593,7 +626,7 @@ class MemoryAsContextTransformer(Module):
         # apply axial positional embedding
         # so intra and inter segment can be more easily discerned by the network
 
-        pos_emb = self.axial_pos_emb.forward_with_seq_len(seq_len_with_mem, (neural_mem_segment_len,))
+        pos_emb = self.axial_pos_emb.forward_with_seq_len(seq_len_with_mem, (neural_mem_segment_len,), factorized = factorized_pos_emb)
 
         x = x + pos_emb
 
```
```diff
@@ -651,7 +684,10 @@ class MemoryAsContextTransformer(Module):
         logits = self.to_logits(x)
 
         if not return_loss:
-            return logits
+            if not return_cache:
+                return logits
+
+            return logits, cache
 
         ar_loss = F.cross_entropy(rearrange(logits, 'b n l -> b l n'), labels)
 
```
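The widened `forward` signature (`cache`, `return_cache`, `factorized_pos_emb`) and the new `(logits, cache)` return path also allow stepping the model manually outside of `sample()`. A rough sketch, reusing the `transformer` from the earlier example; the greedy decoding loop and tensor shapes are illustrative assumptions — only the keyword arguments come from this diff.

```python
# Hedged sketch: manual incremental decoding with the new forward kwargs.
import torch

out = torch.randint(0, 256, (1, 16))    # running token ids, `transformer` as above
cache = None

for _ in range(32):
    logits, next_cache = transformer.forward(
        out,
        disable_flex_attn = True,
        cache = cache,           # None on the first step, then the returned cache
        return_cache = True,     # forward now returns (logits, cache)
    )
    cache = next_cache

    # greedy pick of the last position, then append and continue
    next_token = logits[:, -1].argmax(dim = -1, keepdim = True)
    out = torch.cat((out, next_token), dim = -1)
```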
```diff
--- /dev/null
+++ b/titans_pytorch-0.1.17.dist-info/RECORD
@@ -0,0 +1,8 @@
+titans_pytorch/__init__.py,sha256=u0tta_KqhOdfzCEDWT9P4_jejJEK2q1XxhsEzB5MnQU,223
+titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+titans_pytorch/mac_transformer.py,sha256=ajsX-djEUzstig5n99yF_NimRzKNfv0MSz-EIV-Fe1A,20393
+titans_pytorch/titans.py,sha256=R0e25ly2uTHkHSZEb-9Eqb0DqtFq8wFBB8iH1T6bYVg,22240
+titans_pytorch-0.1.17.dist-info/METADATA,sha256=E9nwWCKZLSqT9Mr85nrJQzinYpKZnkLeexeaYyOIqrU,6340
+titans_pytorch-0.1.17.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.1.17.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.1.17.dist-info/RECORD,,
```
```diff
--- a/titans_pytorch-0.1.15.dist-info/RECORD
+++ /dev/null
@@ -1,8 +0,0 @@
-titans_pytorch/__init__.py,sha256=u0tta_KqhOdfzCEDWT9P4_jejJEK2q1XxhsEzB5MnQU,223
-titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/mac_transformer.py,sha256=zxknstaI_Uz47Y8WvZ3S7geJ-TNdqKV5Rvj0Jlw8njs,19271
-titans_pytorch/titans.py,sha256=R0e25ly2uTHkHSZEb-9Eqb0DqtFq8wFBB8iH1T6bYVg,22240
-titans_pytorch-0.1.15.dist-info/METADATA,sha256=SnNsoK4obeOAWFPhQypYJfJWZ_abXKr7WCvLMqFdyg0,6340
-titans_pytorch-0.1.15.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.1.15.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.1.15.dist-info/RECORD,,
```
- {titans_pytorch-0.1.15.dist-info → titans_pytorch-0.1.17.dist-info}/WHEEL: file without changes
- {titans_pytorch-0.1.15.dist-info → titans_pytorch-0.1.17.dist-info}/licenses/LICENSE: file without changes