titans-pytorch 0.2.1__py3-none-any.whl → 0.2.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- titans_pytorch/mac_transformer.py +2 -19
- titans_pytorch/neural_memory.py +42 -89
- {titans_pytorch-0.2.1.dist-info → titans_pytorch-0.2.4.dist-info}/METADATA +1 -1
- titans_pytorch-0.2.4.dist-info/RECORD +9 -0
- titans_pytorch-0.2.1.dist-info/RECORD +0 -9
- {titans_pytorch-0.2.1.dist-info → titans_pytorch-0.2.4.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.2.1.dist-info → titans_pytorch-0.2.4.dist-info}/licenses/LICENSE +0 -0
titans_pytorch/mac_transformer.py CHANGED
@@ -536,11 +536,6 @@ class MemoryAsContextTransformer(Module):
        self.weight_tie_memory_model = weight_tie_memory_model
        self.prev_neural_mem_update_for_weights = default(prev_neural_mem_update_for_weights, weight_tie_memory_model)

-        # value residual learning for neural memory
-
-        is_first_mem = True
-        self.mem_add_value_residual = neural_memory_add_value_residual
-
        # mem, attn, and feedforward layers

        for layer in layers:
@@ -570,12 +565,9 @@ class MemoryAsContextTransformer(Module):
                dim = dim,
                chunk_size = self.neural_memory_segment_len,
                model = maybe_copy(neural_memory_model),
-                accept_value_residual = not is_first_mem and neural_memory_add_value_residual,
                **neural_memory_kwargs
            )

-            is_first_mem = False
-
            ff = FeedForward(dim = dim, mult = ff_mult)

            self.layers.append(ModuleList([
@@ -765,8 +757,6 @@ class MemoryAsContextTransformer(Module):

        value_residual = None

-        mem_value_residual = None
-
        # aux losses

        kv_recon_losses = self.zero
@@ -794,28 +784,21 @@ class MemoryAsContextTransformer(Module):
            mem_input, add_residual = mem_hyper_conn(x)

            if not is_inferencing:
-                (retrieved, next_neural_mem_cache
+                (retrieved, next_neural_mem_cache), mem_kv_aux_loss = mem(
                    mem_input,
                    return_aux_kv_loss = True,
-                    return_values = True,
-                    value_residual = mem_value_residual,
                    prev_layer_updates = neural_memory_updates
                )

                kv_recon_losses = kv_recon_losses + mem_kv_aux_loss

            else:
-                (retrieved, next_neural_mem_cache
+                (retrieved, next_neural_mem_cache) = mem.forward_inference(
                    mem_input,
                    state = next(neural_mem_caches, None),
-                    return_values = True,
-                    value_residual = mem_value_residual,
                    prev_layer_updates = neural_memory_updates
                )

-            if self.mem_add_value_residual:
-                mem_value_residual = next_mem_value_residual
-
            if prev_neural_mem_update_for_weights:
                neural_memory_updates = next_neural_mem_cache.updates

titans_pytorch/neural_memory.py CHANGED
@@ -67,6 +67,9 @@ def safe_cat(inputs, dim = -2):
def identity(t):
    return t

+def dict_get_shape(td):
+    return {k: v.shape for k, v in td.items()}
+
def pair(v):
    return (v, v) if not isinstance(v, tuple) else v

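The new `dict_get_shape` helper simply maps a dict of tensors to a dict of their shapes; later hunks use it to record `init_weight_shape` at init and to detect weights that carry extra leading dimensions. A minimal sketch of what it returns, using a throwaway `nn.Linear` as a stand-in memory model (sizes are illustrative only):

```python
# Minimal sketch of dict_get_shape; the nn.Linear stand-in and its sizes are illustrative only.
from torch import nn

def dict_get_shape(td):
    return {k: v.shape for k, v in td.items()}

memory_model = nn.Linear(4, 4, bias = False)

print(dict_get_shape(dict(memory_model.named_parameters())))
# {'weight': torch.Size([4, 4])}
```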
@@ -258,7 +261,6 @@ class NeuralMemory(Module):
        pre_rmsnorm = True,
        post_rmsnorm = True,
        qk_rmsnorm = False,
-        accept_value_residual = False,
        max_grad_norm: float | None = None,
        use_accelerated_scan = False,
        activation: Module | None = None,
@@ -315,6 +317,8 @@ class NeuralMemory(Module):

        self.num_memory_parameter_tensors = len(set(model.parameters()))

+        self.init_weight_shape = dict_get_shape(dict(model.named_parameters()))
+
        # the chunk size within the paper where adaptive step, momentum, weight decay are shared

        self.chunk_size = chunk_size
@@ -343,19 +347,6 @@ class NeuralMemory(Module):
        self.to_keys_values = Sequential(LinearNoBias(dim, dim_inner * 2), activation)
        self.store_memory_loss_fn = store_memory_loss_fn

-        # value residual learning
-
-        self.learned_value_residual = Sequential(
-            LinearNoBias(dim, heads),
-            Rearrange('b n h -> b h n 1'),
-            nn.Sigmoid()
-        ) if accept_value_residual else None
-
-        # empty memory embed
-
-        self.empty_memory_embed = nn.Parameter(torch.zeros(dim))
-        nn.init.normal_(self.empty_memory_embed, std = 0.02)
-
        # `chunk_size` refers to chunk size used for storing to memory model weights

        chunk_size = self.store_chunk_size
@@ -417,9 +408,6 @@ class NeuralMemory(Module):
        weights = TensorDict(dict(self.memory_model.named_parameters()))
        return weights

-    def init_empty_memory_embed(self, batch, seq_len):
-        return repeat(self.empty_memory_embed, 'd -> b n d', b = batch, n = seq_len)
-
    def store_memories(
        self,
        seq,
@@ -428,10 +416,7 @@ class NeuralMemory(Module):
        prev_layer_updates: dict[str, Tensor] | None = None,
        return_aux_kv_loss = False,
        chunk_size = None,
-        value_residual = None
    ):
-        assert xnor(exists(value_residual), exists(self.learned_value_residual))
-
        seq_len, heads, chunk_size = seq.shape[-2], self.heads, default(chunk_size, self.store_chunk_size)

        # handle edge case
@@ -446,7 +431,7 @@ class NeuralMemory(Module):

        round_down_seq_len = round_down_multiple(seq_len, chunk_size)

-        seq = seq[:, :round_down_seq_len]
+        seq, remainder = seq[:, :round_down_seq_len], seq[:, round_down_seq_len:]

        # per sample grad function

@@ -499,14 +484,6 @@ class NeuralMemory(Module):

        keys = self.k_norm(keys)

-        # maybe value residual learning
-
-        orig_values = values
-
-        if exists(self.learned_value_residual):
-            mix = self.learned_value_residual(seq)
-            values = values.lerp(value_residual, mix)
-
        # take care of chunking

        keys, values = tuple(rearrange(t, 'b h (n c) d -> (b h n) c d', c = chunk_size) for t in (keys, values))
@@ -581,13 +558,15 @@ class NeuralMemory(Module):
            if has_momentum:
                next_momentum[param_name] = inverse_pack(momentum)

-        #
+        # determine next state for the storing of memories

        next_state = (next_last_update, next_last_momentum)

+        next_store_state = NeuralMemCache(seq_len, remainder, next_state, updates)
+
        # returns

-        output = (updates,
+        output = (updates, next_store_state)

        if not return_aux_kv_loss:
            return output
@@ -606,16 +585,18 @@ class NeuralMemory(Module):

        seq = self.retrieve_norm(seq)

-
-
+        assert seq_len >= chunk_size, 'must be handled outside of retrieve'
+
+        needs_pad = chunk_size > 1

-
-
+        if needs_pad:
+            seq = pad_at_dim(seq, (1, 0), dim = 1)
+            seq_len_plus_one = seq.shape[-2]

-
+            next_seq_len = round_up_multiple(seq_len_plus_one, chunk_size)

-
-
+            padding = next_seq_len - seq_len_plus_one
+            seq = pad_at_dim(seq, (0, padding), dim = 1)

        # the parameters of the memory model stores the memories of the key / values
        # when the MLP has only 1 weight matrix, it is equivalent to `kv` fast weight memories from linear attention literature (recall fetching of memories is q @ (kv)) / schmidhuber's paper
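The rewritten padding logic only kicks in when `chunk_size > 1`: one position is padded at the front, then the sequence is rounded up to the next multiple of the chunk size at the back. A small worked example of that arithmetic (`round_up_multiple` is re-implemented here for illustration; the values are arbitrary):

```python
# Worked example of the padding arithmetic above; chunk_size / seq_len are illustrative,
# and round_up_multiple is re-implemented locally rather than imported from the library.
def round_up_multiple(n, mult):
    return ((n + mult - 1) // mult) * mult

chunk_size = 4
seq_len = 10

needs_pad = chunk_size > 1                                       # True
seq_len_plus_one = seq_len + 1                                   # 11, after one position padded at the front
next_seq_len = round_up_multiple(seq_len_plus_one, chunk_size)   # 12
padding = next_seq_len - seq_len_plus_one                        # 1 position padded at the back

print(needs_pad, seq_len_plus_one, next_seq_len, padding)        # True 11 12 1
```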
@@ -639,7 +620,9 @@ class NeuralMemory(Module):

        # fetch values from memory model

-        curr_weights
+        if dict_get_shape(curr_weights) != self.init_weight_shape:
+            curr_weights = curr_weights.apply(lambda t: rearrange(t, 'b n ... -> (b n) ...'))
+
        queries = rearrange(queries, 'b h (n c) d -> (b h n) c d', c = chunk_size)

        # forward functional call
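The new shape check guards against retrieved weights that arrive with extra leading dimensions (for example, per-chunk updates stacked along a second axis): when the shapes no longer match the ones recorded at init, the leading axes are flattened. A self-contained sketch of that flattening, where the key name and tensor sizes are stand-ins rather than the library's actual memory weights:

```python
# Sketch of the shape-mismatch flattening with einops; 'weights.0' and the sizes are stand-ins.
import torch
from einops import rearrange

def dict_get_shape(td):
    return {k: v.shape for k, v in td.items()}

init_weight_shape = {'weights.0': torch.Size([64, 64])}

# weights that picked up leading (batch, num_chunks) dims, e.g. stacked per-chunk updates
curr_weights = {'weights.0': torch.randn(2, 3, 64, 64)}

if dict_get_shape(curr_weights) != init_weight_shape:
    curr_weights = {k: rearrange(t, 'b n ... -> (b n) ...') for k, t in curr_weights.items()}

print(curr_weights['weights.0'].shape)  # torch.Size([6, 64, 64])
```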
@@ -665,10 +648,10 @@ class NeuralMemory(Module):

        # restore, pad with empty memory embed

-
-
+        if needs_pad:
+            values = values[:, 1:(seq_len + 1)]

-        return values
+        return values

    @torch.no_grad()
    def forward_inference(
@@ -676,8 +659,6 @@ class NeuralMemory(Module):
        token: Tensor,
        state = None,
        prev_layer_updates: dict[str, Tensor] | None = None,
-        return_values = False,
-        value_residual = None,
    ):

        # unpack previous state
@@ -704,12 +685,9 @@ class NeuralMemory(Module):
        # early return empty memory, when no memories are stored for steps < first chunk size

        if curr_seq_len < self.chunk_size:
-
-
-            output = empty_mem, NeuralMemCache(curr_seq_len, cache_store_seq, past_states, updates)
+            retrieve = self.retrieve_memories(token, weights, chunk_size = 1)

-
-            output = (*output, self.zero)
+            output = retrieve, NeuralMemCache(curr_seq_len, cache_store_seq, past_states, updates)

            return output

@@ -728,20 +706,18 @@ class NeuralMemory(Module):
            prev_layer_updates = TensorDict(prev_layer_updates)
            prev_layer_updates = prev_layer_updates.apply(lambda t: t[:, -1:])

-        values = None
-
        if store_seq_cache_len == self.chunk_size:

-            next_updates,
+            next_updates, store_state = self.store_memories(
                cache_store_seq,
                weights,
                past_state = past_states,
                prev_layer_updates = prev_layer_updates,
-                value_residual = value_residual
            )

            updates = next_updates
            cache_store_seq = None
+            next_states = store_state.states

        # retrieve

@@ -749,14 +725,9 @@ class NeuralMemory(Module):

        # next state tuple

-
-
-        output = (retrieved, next_state)
-
-        if return_values:
-            output = (*output, values)
+        next_store_state = NeuralMemCache(curr_seq_len, cache_store_seq, next_states, updates)

-        return
+        return retrieved, next_store_state

    def forward(
        self,
@@ -767,50 +738,45 @@ class NeuralMemory(Module):
        return_aux_kv_loss = False,
        chunk_size = None,
        store_chunk_size = None,
-        return_values = False,
-        value_residual = None,
        return_next_state = False,
        prev_layer_updates: dict[str, Tensor] | None = None
    ):
        batch, seq_len = seq.shape[:2]

+        if not exists(mem_model_weights):
+            mem_model_weights = self.init_weights()
+
        if seq_len < self.retrieve_chunk_size:
-
+            retrieved = self.retrieve_memories(seq, mem_model_weights, chunk_size = 1)

            next_store_state = NeuralMemCache(seq_len, seq, None, None)

-            out = (
-
-            if return_values:
-                out = (*out, self.zero)
+            out = (retrieved, next_store_state)

            if not return_aux_kv_loss:
                return out

            return out, self.zero

-        if not exists(mem_model_weights):
-            mem_model_weights = self.init_weights()
-
        # store

        store_seq = default(store_seq, seq)

-
-        store_chunk_size = default(store_chunk_size, chunk_size, self.store_chunk_size)
-        remainder = store_seq_len % store_chunk_size
-
-        (updates, next_state, values), aux_kv_recon_loss = self.store_memories(
+        (updates, next_store_state), aux_kv_recon_loss = self.store_memories(
            store_seq,
            mem_model_weights,
            chunk_size = store_chunk_size,
            prev_layer_updates = prev_layer_updates,
-            value_residual = value_residual,
            return_aux_kv_loss = True
        )

        # retrieve

+        if exists(prev_layer_updates):
+            prev_layer_updates = prev_layer_updates.apply(lambda t: pad_at_dim(t, (1, 0), dim = 1))
+
+        updates = updates.apply(lambda t: pad_at_dim(t, (1, 0), dim = 1))
+
        retrieved = self.retrieve_memories(
            seq,
            mem_model_weights + updates,
@@ -818,21 +784,8 @@ class NeuralMemory(Module):
            prev_layer_updates = prev_layer_updates
        )

-        # determine state for the storing of memories
-        # for transformer-xl like training with neural memory as well as inferencing with initial prompt
-
-        cache_store_seq = None
-
-        if remainder > 0:
-            cache_store_seq = store_seq[:, -remainder:]
-
-        next_store_state = NeuralMemCache(seq_len, cache_store_seq, next_state, updates)
-
        output = (retrieved, next_store_state)

-        if return_values:
-            output = (*output, values)
-
        if not return_aux_kv_loss:
            return output

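Taken together, `NeuralMemory.forward` in 0.2.4 drops the `return_values` / `value_residual` plumbing and returns `(retrieved, next_store_state)`, plus the auxiliary KV reconstruction loss when requested. A minimal usage sketch, assuming the package is installed, the default memory model, and illustrative sizes; only behaviour visible in this diff is relied on:

```python
# Hedged usage sketch of the post-change calling convention; dim, chunk_size and the
# tensor sizes are illustrative, and the default memory model is assumed.
import torch
from titans_pytorch import NeuralMemory

mem = NeuralMemory(dim = 384, chunk_size = 64)

seq = torch.randn(2, 1024, 384)

# the value-residual kwargs are gone; forward returns (retrieved, next_store_state),
# and additionally the aux KV reconstruction loss when return_aux_kv_loss = True
(retrieved, next_store_state), aux_kv_loss = mem(seq, return_aux_kv_loss = True)

# retrieved should have the same (batch, seq, dim) shape as seq
print(retrieved.shape)
```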
titans_pytorch-0.2.4.dist-info/RECORD ADDED
@@ -0,0 +1,9 @@
+titans_pytorch/__init__.py,sha256=Y3m_ZlpEqYwp-Md1ARhNGJxq8bQp8ty1o039nZOOJo0,276
+titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+titans_pytorch/mac_transformer.py,sha256=g-Rx8zwTUbMv-XBYWPe9abFVVSUFLxOn_yVQ-wWvG5M,26039
+titans_pytorch/memory_models.py,sha256=LI9T36XB6YXIvvGWRw0ZMDlGpRC6KIv03OPzME2VAaU,3772
+titans_pytorch/neural_memory.py,sha256=3ykFukUDp3dW1QwDmS3jZ2wFysiZE2ippcOoMFall34,24143
+titans_pytorch-0.2.4.dist-info/METADATA,sha256=2yY3d58zPQ1uyvnTX4Dml7a2dd2jRu3TR5NhBpPNmdY,6819
+titans_pytorch-0.2.4.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.2.4.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.2.4.dist-info/RECORD,,
titans_pytorch-0.2.1.dist-info/RECORD DELETED
@@ -1,9 +0,0 @@
-titans_pytorch/__init__.py,sha256=Y3m_ZlpEqYwp-Md1ARhNGJxq8bQp8ty1o039nZOOJo0,276
-titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/mac_transformer.py,sha256=kqW90mpbFf1ZJ_mMkd6v9EQ5J__TwKMPy5cjHJF_26A,26742
-titans_pytorch/memory_models.py,sha256=LI9T36XB6YXIvvGWRw0ZMDlGpRC6KIv03OPzME2VAaU,3772
-titans_pytorch/neural_memory.py,sha256=vmKPOAlXBPXBnYPODrg_reWaIcr1xwtfQmuptGS6e5A,25559
-titans_pytorch-0.2.1.dist-info/METADATA,sha256=HPdcQb4SlT-eLFzOYLMwGInEKegL4M4yIpKWt1a6DTs,6819
-titans_pytorch-0.2.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.2.1.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.2.1.dist-info/RECORD,,
{titans_pytorch-0.2.1.dist-info → titans_pytorch-0.2.4.dist-info}/WHEEL: file without changes
{titans_pytorch-0.2.1.dist-info → titans_pytorch-0.2.4.dist-info}/licenses/LICENSE: file without changes