titans-pytorch 0.1.23__py3-none-any.whl → 0.1.27__py3-none-any.whl
This diff compares the contents of two publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in their respective public registries.
- titans_pytorch/mac_transformer.py +42 -13
- titans_pytorch/titans.py +83 -54
- {titans_pytorch-0.1.23.dist-info → titans_pytorch-0.1.27.dist-info}/METADATA +2 -2
- titans_pytorch-0.1.27.dist-info/RECORD +8 -0
- titans_pytorch-0.1.23.dist-info/RECORD +0 -8
- {titans_pytorch-0.1.23.dist-info → titans_pytorch-0.1.27.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.1.23.dist-info → titans_pytorch-0.1.27.dist-info}/licenses/LICENSE +0 -0
titans_pytorch/mac_transformer.py CHANGED
@@ -510,10 +510,7 @@ class MemoryAsContextTransformer(Module):

         layers = tuple(range(1, depth + 1))

-
-        neural_memory_layers = layers if has_longterm_mems else ()
-
-        assert not (num_longterm_mem_tokens > 0 and len(neural_memory_layers) == 0), 'empty `neural_memory_layers` when longterm memory tokens are present'
+        neural_memory_layers = default(neural_memory_layers, layers)

         # mem, attn, and feedforward layers

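With this change, leaving `neural_memory_layers` unset places a neural memory module in every layer (i.e. `tuple(range(1, depth + 1))`) rather than only when long-term memory tokens are present. A minimal construction sketch, assuming the constructor arguments shown in the project README; the exact values are illustrative only:

import torch
from titans_pytorch import MemoryAsContextTransformer

transformer = MemoryAsContextTransformer(
    num_tokens = 256,
    dim = 256,
    depth = 2,
    segment_len = 128,
    num_persist_mem_tokens = 4,
    num_longterm_mem_tokens = 16,
    # neural_memory_layers is omitted: it now defaults to all layers
)

ids = torch.randint(0, 256, (1, 1023))

loss = transformer(ids, return_loss = True)
loss.backward()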
@@ -535,9 +532,10 @@ class MemoryAsContextTransformer(Module):
             )

             mem = None
+            mem_hyper_conn = None

             if layer in neural_memory_layers:
-
+                mem_hyper_conn = init_hyper_conn(dim = dim, add_branch_out_to_residual = not neural_mem_gate_attn_output)

                 mem = NeuralMemory(
                     dim = dim,
@@ -545,10 +543,12 @@ class MemoryAsContextTransformer(Module):
                     **neural_memory_kwargs
                 )

+
             ff = FeedForward(dim = dim, mult = ff_mult)

             self.layers.append(ModuleList([
-                mem,
+                mem_hyper_conn,
+                mem,
                 init_hyper_conn(dim = dim, branch = attn),
                 init_hyper_conn(dim = dim, branch = ff)
             ]))
@@ -691,8 +691,18 @@ class MemoryAsContextTransformer(Module):
         # kv caching

         is_inferencing = exists(cache)
-
+        assert not (is_inferencing and self.num_longterm_mem_tokens > 0)
+
+        if not exists(cache):
+            cache = (None, None)
+
+        kv_caches, neural_mem_caches = cache
+
+        kv_caches = iter(default(kv_caches, []))
+        neural_mem_caches = iter(default(neural_mem_caches, []))
+
         next_kv_caches = []
+        next_neural_mem_caches = []

         # value residual

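The cache handed back to the caller is now a pair, and each half is turned into an iterator so that every layer can pull its own entry with `next(..., None)`. A standalone sketch of that pattern, using local stand-ins for the repo's `exists`/`default` helpers:

def exists(v):
    return v is not None

def default(v, d):
    return v if exists(v) else d

cache = None  # what a caller passes on the very first step

if not exists(cache):
    cache = (None, None)

kv_caches, neural_mem_caches = cache

# per-layer consumption: missing caches simply yield None
kv_caches = iter(default(kv_caches, []))
neural_mem_caches = iter(default(neural_mem_caches, []))

for layer_index in range(4):
    kv_cache = next(kv_caches, None)              # None on the first step
    neural_mem_cache = next(neural_mem_caches, None)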
@@ -711,21 +721,37 @@ class MemoryAsContextTransformer(Module):

         x = self.expand_streams(x)

-        for mem, attn, ff in self.layers:
+        for mem_hyper_conn, mem, attn, ff in self.layers:

             retrieved = None
             attn_out_gates = None
+            next_neural_mem_cache = None

             # maybe neural memory

             if exists(mem):
-
-
+
+                mem_input, add_residual = mem_hyper_conn(x)
+
+                if not is_inferencing:
+                    retrieved, mem_kv_aux_loss = mem(
+                        mem_input,
+                        return_aux_kv_loss = True
+                    )
+
+                    kv_recon_losses = kv_recon_losses + mem_kv_aux_loss
+
+                    next_neural_mem_cache = (seq_len, None, None, None)
+                else:
+                    retrieved, next_neural_mem_cache = mem.forward_inference(
+                        mem_input,
+                        state = next(neural_mem_caches, None)
+                    )

                 if self.gate_attn_output:
                     attn_out_gates = retrieved.sigmoid()
                 else:
-
+                    x = add_residual(retrieved)

             # attention

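Depending on `neural_mem_gate_attn_output`, the retrieved memories either gate the attention output through a sigmoid or are added back onto the residual stream via the hyper-connection's `add_residual`. A toy tensor sketch (shapes illustrative; the plain addition is a simplified stand-in for `add_residual`):

import torch

retrieved = torch.randn(2, 16, 64)   # output of the neural memory branch
attn_out = torch.randn(2, 16, 64)    # output of the attention branch

# gate path: memories modulate the attention output elementwise, gates in (0, 1)
gated_attn_out = attn_out * retrieved.sigmoid()

# residual path: memories are added to the stream instead (simplified stand-in)
x = torch.randn(2, 16, 64)
x = x + retrieved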
@@ -735,12 +761,15 @@ class MemoryAsContextTransformer(Module):
                 disable_flex_attn = disable_flex_attn,
                 flex_attn_fn = flex_attn_fn,
                 output_gating = attn_out_gates,
-                cache = next(
+                cache = next(kv_caches, None)
             )

             value_residual = default(value_residual, values)

+            # caches
+
             next_kv_caches.append(next_kv_cache)
+            next_neural_mem_caches.append(next_neural_mem_cache)

             # feedforward

@@ -775,7 +804,7 @@ class MemoryAsContextTransformer(Module):
         if not self.sliding_window_attn and divisible_by(seq_len_with_mem, attn_window_size):
             next_kv_caches = next_kv_caches[..., 0:0, :]

-        return logits, next_kv_caches
+        return logits, (next_kv_caches, next_neural_mem_caches)

         ar_loss = F.cross_entropy(rearrange(logits, 'b n l -> b l n'), labels)

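The forward pass now returns `(logits, (next_kv_caches, next_neural_mem_caches))` instead of `(logits, next_kv_caches)`, so a sampling loop has to thread the whole tuple back in. A schematic sketch with a hypothetical `step` stand-in (only the cache plumbing mirrors the diff):

import torch

def step(token_ids, cache = None):
    # hypothetical stand-in for a call like transformer(token_ids, cache = cache);
    # only the shape of the return value mirrors the change above
    logits = torch.randn(token_ids.shape[0], token_ids.shape[1], 256)
    next_kv_caches = []          # placeholder per-layer kv caches
    next_neural_mem_caches = []  # placeholder per-layer neural memory caches
    return logits, (next_kv_caches, next_neural_mem_caches)

ids = torch.randint(0, 256, (1, 1))
cache = None

for _ in range(4):
    logits, cache = step(ids, cache = cache)               # cache is now a 2-tuple
    kv_caches, neural_mem_caches = cache                   # unpack both halves
    ids = logits[:, -1].argmax(dim = -1, keepdim = True)   # greedy next token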
titans_pytorch/titans.py CHANGED
@@ -324,15 +324,26 @@ class AssocScan(Module):
         super().__init__()
         self.use_accelerated = use_accelerated

-    def forward(
+    def forward(
+        self,
+        gates,
+        inputs,
+        prev = None,
+        remove_prev = None
+    ):
+        remove_prev = default(remove_prev, exists(prev))

         if exists(prev):
             inputs, _ = pack([prev, inputs], 'b * d')
             gates = pad_at_dim(gates, (1, 0), value = 1., dim = -2)

         if not self.use_accelerated:
-            _,
-
+            _, out = associative_scan(binary_operator, (gates, inputs))
+
+            if remove_prev:
+                out = out[:, 1:]
+
+            return out

         from accelerated_scan.triton import scan as triton_scan
         from accelerated_scan.warp import scan as warp_scan
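`AssocScan.forward` now accepts a `prev` element, which is prepended to the inputs (with a gate of 1) so that a scan can be resumed across chunks, and `remove_prev` strips that seed from the output again. A sequential reference loop sketching the same recurrence h_t = g_t * h_{t-1} + x_t (names and shapes are illustrative, not the library API):

import torch

def reference_scan(gates, inputs, prev = None, remove_prev = None):
    # gates, inputs: (batch, seq, dim); prev: (batch, dim) carried from a previous chunk
    remove_prev = (prev is not None) if remove_prev is None else remove_prev

    if prev is not None:
        inputs = torch.cat((prev.unsqueeze(1), inputs), dim = 1)
        gates = torch.nn.functional.pad(gates, (0, 0, 1, 0), value = 1.)

    h = torch.zeros_like(inputs[:, 0])
    out = []
    for t in range(inputs.shape[1]):
        h = gates[:, t] * h + inputs[:, t]
        out.append(h)
    out = torch.stack(out, dim = 1)

    if remove_prev:
        out = out[:, 1:]          # drop the seeded position, as in the diff

    return out

# resuming chunk by chunk gives the same result as one long scan
gates, inputs = torch.rand(2, 8, 4), torch.randn(2, 8, 4)
full = reference_scan(gates, inputs)
first = reference_scan(gates[:, :4], inputs[:, :4])
second = reference_scan(gates[:, 4:], inputs[:, 4:], prev = first[:, -1])
assert torch.allclose(full, torch.cat((first, second), dim = 1), atol = 1e-6)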
@@ -355,7 +366,12 @@ class AssocScan(Module):
             outputs = rearrange(outputs, 'b d n -> b n d')
             return outputs

-
+        out = accelerate_scan_fn(gates, inputs)
+
+        if remove_prev:
+            out = out[:, 1:]
+
+        return out

 # main neural memory

@@ -384,7 +400,6 @@ class NeuralMemory(Module):
         post_rmsnorm = True,
         qk_rmsnorm = False,
         accept_value_residual = False,
-        learned_mem_model_weights = True,
         max_grad_norm: float | None = None,
         use_accelerated_scan = False,
         activation: Module | None = None,
@@ -432,9 +447,6 @@ class NeuralMemory(Module):
         if not exists(model):
             model = MemoryMLP(dim_head, **default_model_kwargs)

-        if not learned_mem_model_weights:
-            model.requires_grad_(False)
-
         assert not exists(next(model.buffers(), None)), 'model cannot have buffers for now'

         # the memory is the weights of the model
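The `learned_mem_model_weights` flag is gone; if the memory network's initial weights should not be trained, the same effect can be had by freezing whatever module is passed in as `model`, as the removed branch did internally. A small sketch, assuming a custom `model` is passed in as the code above allows (the MLP below is an arbitrary stand-in, not `MemoryMLP`, and the commented call is hypothetical):

import torch.nn as nn

dim_head = 64

mlp = nn.Sequential(
    nn.Linear(dim_head, dim_head),
    nn.SiLU(),
    nn.Linear(dim_head, dim_head),
)

# what the removed flag used to do internally: freeze the memory model's initial weights
mlp.requires_grad_(False)

# mem = NeuralMemory(dim = 384, model = mlp, ...)   # hypothetical: then pass it in as `model`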
@@ -536,16 +548,9 @@ class NeuralMemory(Module):

         self.register_buffer('zero', torch.tensor(0.), persistent = False)

-    def
-
-
-        init_weights = params
-        init_momentum = params.clone().zero_()
-
-        if zero_weights:
-            init_weights = params.clone().zero_()
-
-        return init_weights, init_momentum
+    def init_weights(self):
+        weights = TensorDict(dict(self.memory_model.named_parameters()))
+        return weights

     def init_empty_memory_embed(self, batch, seq_len):
         return repeat(self.empty_memory_embed, 'd -> b n d', b = batch, n = seq_len)
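`init_weights` now just snapshots the memory model's parameters into a `TensorDict` keyed by parameter name, which is what the rest of the module treats as "the memory". A minimal sketch of that structure, assuming a tensordict version that, like the code above, allows omitting `batch_size`:

import torch.nn as nn
from tensordict import TensorDict

memory_model = nn.Linear(64, 64)

# the memory is the weights of the model, addressable by name
weights = TensorDict(dict(memory_model.named_parameters()))

# same structure, zeroed out - this is how per-chunk updates start
updates = weights.clone().zero_()

print(sorted(weights.keys()))   # ['bias', 'weight']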
@@ -553,7 +558,8 @@ class NeuralMemory(Module):
     def store_memories(
         self,
         seq,
-
+        weights: dict[str, Tensor],
+        past_state: tuple[dict[str, Tensor], dict[str, Tensor]] | None = None,
         return_aux_kv_loss = False,
         chunk_size = None,
         value_residual = None
@@ -565,8 +571,7 @@ class NeuralMemory(Module):
         # handle edge case

         if seq_len < chunk_size:
-
-            return TensorDict(past_weight).clone().zero_(), self.zero
+            return TensorDict(weights).clone().zero_(), self.zero

         seq = self.store_norm(seq)

@@ -577,10 +582,9 @@ class NeuralMemory(Module):

         seq = seq[:, :round_down_seq_len]

-        #
+        # weights of the memory network

-
-        curr_weights, past_momentum = past_state
+        weights = TensorDict(weights)

         # derive learned hparams for optimization of memory network

@@ -630,7 +634,7 @@ class NeuralMemory(Module):

         # get grads and extra auxiliary loss (for backwarding through qkv projection in base neural memory module)

-        grads, aux_kv_recon_loss = self.per_sample_grad_fn(dict(
+        grads, aux_kv_recon_loss = self.per_sample_grad_fn(dict(weights), keys, adaptive_lr, values)

         grads = TensorDict(grads)

@@ -652,12 +656,23 @@ class NeuralMemory(Module):

         surprises = grads.apply(lambda t: -t)

+        # past states
+
+        if not exists(past_state):
+            empty_dict = {key: None for key in weights.keys()}
+            past_state = (empty_dict, empty_dict)
+
+        past_last_update, past_last_momentum = past_state
+
         # momentum + weight decay - momentum is the new contribution, as most linear RNNs have learned forgetting gates

         next_momentum = TensorDict() if has_momentum else None
         updates = TensorDict()

-
+        next_last_update = TensorDict()
+        next_last_momentum = TensorDict()
+
+        for (param_name, surprise), (_, last_update), (_, last_momentum) in zip(surprises.items(), past_last_update.items(), past_last_momentum.items()):

             surprise, inverse_pack = pack_one_with_inverse(surprise, 'b n *')

@@ -666,23 +681,27 @@ class NeuralMemory(Module):
             # derive momentum with associative scan - eq (10)

             if has_momentum:
-                update = self.assoc_scan(adaptive_momentum, surprise) # momentum is S / surprise in the paper
+                update = self.assoc_scan(adaptive_momentum, surprise, prev = last_momentum) # momentum is S / surprise in the paper
                 momentum = update
+                next_last_momentum[param_name] = momentum[:, -1]

             # use associative scan again for learned forgetting (weight decay) - eq (13)

-            update = self.assoc_scan(1. - decay_factor, update)
+            update = self.assoc_scan(1. - decay_factor, update, prev = last_update)
+            next_last_update[param_name] = update[:, -1]

             updates[param_name] = inverse_pack(update)

             if has_momentum:
                 next_momentum[param_name] = inverse_pack(momentum)

-        # compute
+        # compute next states for inference, or titans-xl like training

-
+        next_state = (next_last_update, next_last_momentum)

-
+        # returns
+
+        output = (updates, next_state, orig_values)

         if not return_aux_kv_loss:
             return output
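Both scans now take a `prev` seed and also emit their last position, so the momentum state (eq. 10) and the decayed update state (eq. 13) can be carried from one chunk to the next at inference time. A self-contained toy sketch of that chaining (shapes and the local `scan` helper are illustrative only):

import torch

def scan(gates, x, prev = None):
    # sequential stand-in for the associative scan: h_t = g_t * h_{t-1} + x_t
    h = prev if prev is not None else torch.zeros_like(x[:, 0])
    outs = []
    for t in range(x.shape[1]):
        h = gates[:, t] * h + x[:, t]
        outs.append(h)
    return torch.stack(outs, dim = 1)

b, n, d = 2, 8, 4
surprise = torch.randn(b, n, d)              # negated grads for one parameter, flattened
adaptive_momentum = torch.rand(b, n, 1)
decay_factor = torch.rand(b, n, 1) * 0.1

last_momentum = torch.zeros(b, d)            # carried in via past_state
last_update = torch.zeros(b, d)

momentum = scan(adaptive_momentum, surprise, prev = last_momentum)   # eq (10)
update = scan(1. - decay_factor, momentum, prev = last_update)       # eq (13)

# what gets stashed in next_state for the following chunk
next_last_momentum, next_last_update = momentum[:, -1], update[:, -1]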
@@ -764,21 +783,25 @@ class NeuralMemory(Module):
     def forward_inference(
         self,
         token: Tensor,
-
-        mem_model_state = None,
-        cache_store_seq = None
+        state = None,
     ):
-
+
+        # unpack previous state
+
+        if not exists(state):
+            state = (0, None, None, None)
+
+        seq_index, cache_store_seq, past_states, updates = state
+
         curr_seq_len = seq_index + 1
         batch = token.shape[0]

         if token.ndim == 2:
             token = rearrange(token, 'b d -> b 1 d')

-        #
+        # get memory model weights

-
-        mem_model_state = self.init_weights_and_momentum()
+        weights = self.init_weights()

         # increment the sequence cache which is at most the chunk size

@@ -789,32 +812,43 @@ class NeuralMemory(Module):
         if curr_seq_len < self.chunk_size:
             empty_mem = self.init_empty_memory_embed(batch, 1)

-            return empty_mem, cache_store_seq,
+            return empty_mem, (curr_seq_len, cache_store_seq, past_states, updates)

         # store if storage sequence cache hits the chunk size

+        next_states = past_states
         store_seq_cache_len = cache_store_seq.shape[-2]

+        if not exists(updates):
+            updates = weights.clone().zero_()
+            updates = updates.apply(lambda t: repeat(t, '... -> b 1 ...', b = batch))
+
         if store_seq_cache_len == self.chunk_size:
-            updates, _ = self.store_memories(cache_store_seq, mem_model_state)

-
-
+            next_updates, next_states, _ = self.store_memories(
+                cache_store_seq,
+                weights,
+                past_state = past_states
+            )

+            updates = next_updates
             cache_store_seq = None

         # retrieve

-
+        retrieved = self.retrieve_memories(token, updates + weights, chunk_size = 1)
+
+        # next state tuple

-
+        next_state = (curr_seq_len, cache_store_seq, next_states, updates)

-        return retrieved,
+        return retrieved, next_state

     def forward(
         self,
         seq,
         store_seq = None,
+        mem_model_weights: dict[str, Tensor] | None = None,
         past_state: tuple[dict[str, Tensor], dict[str, Tensor]] | None = None,
         return_aux_kv_loss = False,
         chunk_size = None,
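`forward_inference` now takes and returns a single state tuple `(seq_index, cache_store_seq, past_states, updates)`, so token-by-token decoding just threads one object through. A usage sketch, assuming the `dim`/`chunk_size` constructor arguments from the project README:

import torch
from titans_pytorch import NeuralMemory

mem = NeuralMemory(dim = 384, chunk_size = 64)

state = None   # becomes (seq_index, cache_store_seq, past_states, updates)

for _ in range(8):
    token = torch.randn(2, 384)                       # one token per step, (batch, dim)
    retrieved, state = mem.forward_inference(token, state = state)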
@@ -831,20 +865,15 @@ class NeuralMemory(Module):

             return out, self.zero

-        if exists(
-
-
-        if not exists(past_state):
-            past_state = self.init_weights_and_momentum()
+        if not exists(mem_model_weights):
+            mem_model_weights = self.init_weights()

         store_seq = default(store_seq, seq)
         store_chunk_size = default(store_chunk_size, chunk_size)

-        (updates, values), aux_kv_recon_loss = self.store_memories(store_seq,
-
-        past_weights, _ = past_state
+        (updates, next_state, values), aux_kv_recon_loss = self.store_memories(store_seq, mem_model_weights, chunk_size = store_chunk_size, return_aux_kv_loss = True)

-        retrieved = self.retrieve_memories(seq,
+        retrieved = self.retrieve_memories(seq, mem_model_weights + updates, chunk_size = chunk_size)

         output = retrieved

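`forward` now resolves `mem_model_weights` to `init_weights()` when the caller passes nothing, so ordinary usage is unchanged; the new argument only matters when the memory-network weights are supplied explicitly. A sketch of the default path, again assuming the README-style constructor:

import torch
from titans_pytorch import NeuralMemory

mem = NeuralMemory(dim = 384, chunk_size = 64)

seq = torch.randn(2, 1024, 384)

# mem_model_weights is omitted, so the module's own parameters are used
retrieved = mem(seq)

assert retrieved.shape == seq.shape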
{titans_pytorch-0.1.23.dist-info → titans_pytorch-0.1.27.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.1.23
+Version: 0.1.27
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -35,7 +35,7 @@ Classifier: Programming Language :: Python :: 3.9
 Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
 Requires-Python: >=3.9
 Requires-Dist: accelerated-scan>=0.2.0
-Requires-Dist: axial-positional-embedding>=0.3.
+Requires-Dist: axial-positional-embedding>=0.3.10
 Requires-Dist: einops>=0.8.0
 Requires-Dist: einx>=0.3.0
 Requires-Dist: hyper-connections>=0.1.9
titans_pytorch-0.1.27.dist-info/RECORD ADDED
@@ -0,0 +1,8 @@
+titans_pytorch/__init__.py,sha256=u0tta_KqhOdfzCEDWT9P4_jejJEK2q1XxhsEzB5MnQU,223
+titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+titans_pytorch/mac_transformer.py,sha256=Staf9hRQ44QAL23bSGh4VSB8NeGtMri-JdiZdgirJiU,23587
+titans_pytorch/titans.py,sha256=gjoDcTsvw5X2d1I2xq4cM45YJIBqtLFuws8_jVylW_4,25746
+titans_pytorch-0.1.27.dist-info/METADATA,sha256=AZ5-_d9o_khm6jaky1zoKyXB1hDQNifbS061v_b4McQ,6815
+titans_pytorch-0.1.27.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.1.27.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.1.27.dist-info/RECORD,,
titans_pytorch-0.1.23.dist-info/RECORD DELETED
@@ -1,8 +0,0 @@
-titans_pytorch/__init__.py,sha256=u0tta_KqhOdfzCEDWT9P4_jejJEK2q1XxhsEzB5MnQU,223
-titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/mac_transformer.py,sha256=Ejq1r3GQQnlT1Fo4McaOOie19t1HjwVlYbD90GLQCYI,22859
-titans_pytorch/titans.py,sha256=WbagKMYDs-3NoW2j_pAyHEnvR9QzH3A9WntHuV_FKOo,25109
-titans_pytorch-0.1.23.dist-info/METADATA,sha256=H7QbLscawNObHGeoTbnKbf-NOqkMqWCu4yWeZJ0yKMA,6814
-titans_pytorch-0.1.23.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.1.23.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.1.23.dist-info/RECORD,,
{titans_pytorch-0.1.23.dist-info → titans_pytorch-0.1.27.dist-info}/WHEEL: file without changes
{titans_pytorch-0.1.23.dist-info → titans_pytorch-0.1.27.dist-info}/licenses/LICENSE: file without changes