titans-pytorch 0.2.0__tar.gz → 0.2.4__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {titans_pytorch-0.2.0 → titans_pytorch-0.2.4}/PKG-INFO +1 -1
- {titans_pytorch-0.2.0 → titans_pytorch-0.2.4}/pyproject.toml +1 -1
- {titans_pytorch-0.2.0 → titans_pytorch-0.2.4}/titans_pytorch/mac_transformer.py +7 -22
- {titans_pytorch-0.2.0 → titans_pytorch-0.2.4}/titans_pytorch/neural_memory.py +42 -89
- {titans_pytorch-0.2.0 → titans_pytorch-0.2.4}/train_mac.py +1 -3
- {titans_pytorch-0.2.0 → titans_pytorch-0.2.4}/.github/workflows/python-publish.yml +0 -0
- {titans_pytorch-0.2.0 → titans_pytorch-0.2.4}/.github/workflows/test.yaml +0 -0
- {titans_pytorch-0.2.0 → titans_pytorch-0.2.4}/.gitignore +0 -0
- {titans_pytorch-0.2.0 → titans_pytorch-0.2.4}/LICENSE +0 -0
- {titans_pytorch-0.2.0 → titans_pytorch-0.2.4}/README.md +0 -0
- {titans_pytorch-0.2.0 → titans_pytorch-0.2.4}/data/README.md +0 -0
- {titans_pytorch-0.2.0 → titans_pytorch-0.2.4}/data/enwik8.gz +0 -0
- {titans_pytorch-0.2.0 → titans_pytorch-0.2.4}/fig1.png +0 -0
- {titans_pytorch-0.2.0 → titans_pytorch-0.2.4}/fig2.png +0 -0
- {titans_pytorch-0.2.0 → titans_pytorch-0.2.4}/tests/test_titans.py +0 -0
- {titans_pytorch-0.2.0 → titans_pytorch-0.2.4}/titans_pytorch/__init__.py +0 -0
- {titans_pytorch-0.2.0 → titans_pytorch-0.2.4}/titans_pytorch/associative_scan.py +0 -0
- {titans_pytorch-0.2.0 → titans_pytorch-0.2.4}/titans_pytorch/memory_models.py +0 -0
titans_pytorch/mac_transformer.py

@@ -491,7 +491,8 @@ class MemoryAsContextTransformer(Module):
         aux_kv_recon_loss_weight = 0.,
         use_flex_attn = False,
         sliding_window_attn = False,
-        weight_tie_memory_model = False
+        weight_tie_memory_model = False,
+        prev_neural_mem_update_for_weights = None
     ):
         super().__init__()

@@ -533,11 +534,7 @@ class MemoryAsContextTransformer(Module):
         assert exists(neural_memory_model), '`neural_memory_model` must be explicitly set'

         self.weight_tie_memory_model = weight_tie_memory_model
-
-        # value residual learning for neural memory
-
-        is_first_mem = True
-        self.mem_add_value_residual = neural_memory_add_value_residual
+        self.prev_neural_mem_update_for_weights = default(prev_neural_mem_update_for_weights, weight_tie_memory_model)

         # mem, attn, and feedforward layers

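The new `prev_neural_mem_update_for_weights` flag falls back to `weight_tie_memory_model` through `default(...)`. A minimal sketch of that defaulting behaviour; `exists` and `default` are re-implemented here for illustration only, the real helpers live inside the package:

```python
def exists(v):
    return v is not None

def default(*values):
    # return the first argument that exists, mirroring the helper used in titans_pytorch
    return next((v for v in values if exists(v)), None)

weight_tie_memory_model = True
prev_neural_mem_update_for_weights = None  # left unset by the caller

# as in the __init__ hunk above: unset means "follow weight_tie_memory_model"
prev_neural_mem_update_for_weights = default(prev_neural_mem_update_for_weights, weight_tie_memory_model)
assert prev_neural_mem_update_for_weights is True
```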
@@ -568,12 +565,9 @@ class MemoryAsContextTransformer(Module):
                     dim = dim,
                     chunk_size = self.neural_memory_segment_len,
                     model = maybe_copy(neural_memory_model),
-                    accept_value_residual = not is_first_mem and neural_memory_add_value_residual,
                     **neural_memory_kwargs
                 )

-            is_first_mem = False
-
             ff = FeedForward(dim = dim, mult = ff_mult)

             self.layers.append(ModuleList([

@@ -702,7 +696,7 @@ class MemoryAsContextTransformer(Module):

         # math

-        batch, seq_len, neural_mem_segment_len, segment_len, num_longterm_mem_tokens, attn_window_size,
+        batch, seq_len, neural_mem_segment_len, segment_len, num_longterm_mem_tokens, attn_window_size, prev_neural_mem_update_for_weights = *x.shape, self.neural_memory_segment_len, self.segment_len, self.num_longterm_mem_tokens, self.attn_window_size, self.prev_neural_mem_update_for_weights

         seq_len_with_mem = self.seq_len_with_longterm_mem(seq_len)

@@ -763,8 +757,6 @@ class MemoryAsContextTransformer(Module):

         value_residual = None

-        mem_value_residual = None
-
         # aux losses

         kv_recon_losses = self.zero

@@ -792,29 +784,22 @@ class MemoryAsContextTransformer(Module):
             mem_input, add_residual = mem_hyper_conn(x)

             if not is_inferencing:
-                (retrieved, next_neural_mem_cache
+                (retrieved, next_neural_mem_cache), mem_kv_aux_loss = mem(
                     mem_input,
                     return_aux_kv_loss = True,
-                    return_values = True,
-                    value_residual = mem_value_residual,
                     prev_layer_updates = neural_memory_updates
                 )

                 kv_recon_losses = kv_recon_losses + mem_kv_aux_loss

             else:
-                (retrieved, next_neural_mem_cache
+                (retrieved, next_neural_mem_cache) = mem.forward_inference(
                     mem_input,
                     state = next(neural_mem_caches, None),
-                    return_values = True,
-                    value_residual = mem_value_residual,
                     prev_layer_updates = neural_memory_updates
                 )

-            if
-                mem_value_residual = next_mem_value_residual
-
-            if weight_tie_memory_model:
+            if prev_neural_mem_update_for_weights:
                 neural_memory_updates = next_neural_mem_cache.updates

             if self.gate_attn_output:
titans_pytorch/neural_memory.py

@@ -67,6 +67,9 @@ def safe_cat(inputs, dim = -2):
 def identity(t):
     return t

+def dict_get_shape(td):
+    return {k: v.shape for k, v in td.items()}
+
 def pair(v):
     return (v, v) if not isinstance(v, tuple) else v

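The new `dict_get_shape` helper simply maps a dict of tensors to their shapes; a quick standalone illustration (the parameter names here are made up):

```python
import torch

def dict_get_shape(td):
    return {k: v.shape for k, v in td.items()}

# hypothetical parameter dict, mirroring how the diff applies it to
# dict(model.named_parameters())
params = {
    'weights.0': torch.randn(64, 64),
    'weights.1': torch.randn(64, 64),
}

print(dict_get_shape(params))
# {'weights.0': torch.Size([64, 64]), 'weights.1': torch.Size([64, 64])}
```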
@@ -258,7 +261,6 @@ class NeuralMemory(Module):
         pre_rmsnorm = True,
         post_rmsnorm = True,
         qk_rmsnorm = False,
-        accept_value_residual = False,
         max_grad_norm: float | None = None,
         use_accelerated_scan = False,
         activation: Module | None = None,

@@ -315,6 +317,8 @@ class NeuralMemory(Module):

         self.num_memory_parameter_tensors = len(set(model.parameters()))

+        self.init_weight_shape = dict_get_shape(dict(model.named_parameters()))
+
         # the chunk size within the paper where adaptive step, momentum, weight decay are shared

         self.chunk_size = chunk_size

@@ -343,19 +347,6 @@ class NeuralMemory(Module):
         self.to_keys_values = Sequential(LinearNoBias(dim, dim_inner * 2), activation)
         self.store_memory_loss_fn = store_memory_loss_fn

-        # value residual learning
-
-        self.learned_value_residual = Sequential(
-            LinearNoBias(dim, heads),
-            Rearrange('b n h -> b h n 1'),
-            nn.Sigmoid()
-        ) if accept_value_residual else None
-
-        # empty memory embed
-
-        self.empty_memory_embed = nn.Parameter(torch.zeros(dim))
-        nn.init.normal_(self.empty_memory_embed, std = 0.02)
-
         # `chunk_size` refers to chunk size used for storing to memory model weights

         chunk_size = self.store_chunk_size
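With `accept_value_residual`, the learned value-residual gate and the empty-memory embedding removed, a standalone `NeuralMemory` only needs the keyword arguments the `MemoryAsContextTransformer` hunk above is shown passing. A hedged construction sketch; the concrete values are assumptions, and the dimensions are kept equal here purely so the example stands alone (`train_mac.py` pairs a wider transformer with a `dim = 64` memory MLP):

```python
from titans_pytorch import NeuralMemory, MemoryMLP

# sketch only: `dim`, `chunk_size` and `model` are the kwargs visible in this diff;
# the values are assumed
mem = NeuralMemory(
    dim = 64,
    chunk_size = 16,
    model = MemoryMLP(dim = 64, depth = 2),
)
# 0.2.0 additionally accepted `accept_value_residual = ...`; that kwarg is gone in 0.2.4
```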
@@ -417,9 +408,6 @@ class NeuralMemory(Module):
         weights = TensorDict(dict(self.memory_model.named_parameters()))
         return weights

-    def init_empty_memory_embed(self, batch, seq_len):
-        return repeat(self.empty_memory_embed, 'd -> b n d', b = batch, n = seq_len)
-
     def store_memories(
         self,
         seq,

@@ -428,10 +416,7 @@ class NeuralMemory(Module):
         prev_layer_updates: dict[str, Tensor] | None = None,
         return_aux_kv_loss = False,
         chunk_size = None,
-        value_residual = None
     ):
-        assert xnor(exists(value_residual), exists(self.learned_value_residual))
-
         seq_len, heads, chunk_size = seq.shape[-2], self.heads, default(chunk_size, self.store_chunk_size)

         # handle edge case

@@ -446,7 +431,7 @@ class NeuralMemory(Module):

         round_down_seq_len = round_down_multiple(seq_len, chunk_size)

-        seq = seq[:, :round_down_seq_len]
+        seq, remainder = seq[:, :round_down_seq_len], seq[:, round_down_seq_len:]

         # per sample grad function

@@ -499,14 +484,6 @@ class NeuralMemory(Module):

         keys = self.k_norm(keys)

-        # maybe value residual learning
-
-        orig_values = values
-
-        if exists(self.learned_value_residual):
-            mix = self.learned_value_residual(seq)
-            values = values.lerp(value_residual, mix)
-
         # take care of chunking

         keys, values = tuple(rearrange(t, 'b h (n c) d -> (b h n) c d', c = chunk_size) for t in (keys, values))

@@ -581,13 +558,15 @@ class NeuralMemory(Module):
             if has_momentum:
                 next_momentum[param_name] = inverse_pack(momentum)

-        #
+        # determine next state for the storing of memories

         next_state = (next_last_update, next_last_momentum)

+        next_store_state = NeuralMemCache(seq_len, remainder, next_state, updates)
+
         # returns

-        output = (updates,
+        output = (updates, next_store_state)

         if not return_aux_kv_loss:
             return output
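`store_memories` now returns a `NeuralMemCache` alongside the raw updates. A hedged reconstruction of what that cache looks like, inferred only from the constructor calls and the `.states` / `.updates` accesses visible in this diff; the names of the first two fields are guesses:

```python
from collections import namedtuple

# Hypothetical reconstruction for illustration -- the real definition lives in
# titans_pytorch and is not shown in this diff. Field order follows the constructor
# call NeuralMemCache(seq_len, remainder, next_state, updates) seen above.
NeuralMemCache = namedtuple('NeuralMemCache', [
    'seq_index',        # guessed name: how many tokens have been stored so far
    'cache_store_seq',  # guessed name: leftover tokens shorter than one chunk
    'states',           # confirmed by `.states`: the (last_update, last_momentum) pair
    'updates',          # confirmed by `.updates`: per-parameter weight updates
])

cache = NeuralMemCache(0, None, None, None)
print(cache.states, cache.updates)
```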
@@ -606,16 +585,18 @@ class NeuralMemory(Module):

         seq = self.retrieve_norm(seq)

-
-
+        assert seq_len >= chunk_size, 'must be handled outside of retrieve'
+
+        needs_pad = chunk_size > 1

-
-
+        if needs_pad:
+            seq = pad_at_dim(seq, (1, 0), dim = 1)
+        seq_len_plus_one = seq.shape[-2]

-
+        next_seq_len = round_up_multiple(seq_len_plus_one, chunk_size)

-
-
+        padding = next_seq_len - seq_len_plus_one
+        seq = pad_at_dim(seq, (0, padding), dim = 1)

         # the parameters of the memory model stores the memories of the key / values
         # when the MLP has only 1 weight matrix, it is equivalent to `kv` fast weight memories from linear attention literature (recall fetching of memories is q @ (kv)) / schmidhuber's paper

@@ -639,7 +620,9 @@ class NeuralMemory(Module):

         # fetch values from memory model

-        curr_weights
+        if dict_get_shape(curr_weights) != self.init_weight_shape:
+            curr_weights = curr_weights.apply(lambda t: rearrange(t, 'b n ... -> (b n) ...'))
+
         queries = rearrange(queries, 'b h (n c) d -> (b h n) c d', c = chunk_size)

         # forward functional call

@@ -665,10 +648,10 @@ class NeuralMemory(Module):

         # restore, pad with empty memory embed

-
-
+        if needs_pad:
+            values = values[:, 1:(seq_len + 1)]

-        return values
+        return values

     @torch.no_grad()
     def forward_inference(
@@ -676,8 +659,6 @@ class NeuralMemory(Module):
         token: Tensor,
         state = None,
         prev_layer_updates: dict[str, Tensor] | None = None,
-        return_values = False,
-        value_residual = None,
     ):

         # unpack previous state

@@ -704,12 +685,9 @@ class NeuralMemory(Module):
         # early return empty memory, when no memories are stored for steps < first chunk size

         if curr_seq_len < self.chunk_size:
-
-
-            output = empty_mem, NeuralMemCache(curr_seq_len, cache_store_seq, past_states, updates)
+            retrieve = self.retrieve_memories(token, weights, chunk_size = 1)

-
-            output = (*output, self.zero)
+            output = retrieve, NeuralMemCache(curr_seq_len, cache_store_seq, past_states, updates)

             return output

@@ -728,20 +706,18 @@ class NeuralMemory(Module):
             prev_layer_updates = TensorDict(prev_layer_updates)
             prev_layer_updates = prev_layer_updates.apply(lambda t: t[:, -1:])

-        values = None
-
         if store_seq_cache_len == self.chunk_size:

-            next_updates,
+            next_updates, store_state = self.store_memories(
                 cache_store_seq,
                 weights,
                 past_state = past_states,
                 prev_layer_updates = prev_layer_updates,
-                value_residual = value_residual
             )

             updates = next_updates
             cache_store_seq = None
+            next_states = store_state.states

         # retrieve

@@ -749,14 +725,9 @@ class NeuralMemory(Module):

         # next state tuple

-
-
-        output = (retrieved, next_state)
-
-        if return_values:
-            output = (*output, values)
+        next_store_state = NeuralMemCache(curr_seq_len, cache_store_seq, next_states, updates)

-        return
+        return retrieved, next_store_state

     def forward(
         self,
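`forward_inference` now takes only the token, the previous cache and optional `prev_layer_updates`, and returns `(retrieved, next_store_state)`. A hedged sketch of the call pattern, mirroring the `mac_transformer.py` hunk above; the memory configuration and tensor shapes are assumptions:

```python
import torch
from titans_pytorch import NeuralMemory, MemoryMLP

# assumed configuration -- dimensions kept equal so the sketch stands alone
mem = NeuralMemory(dim = 64, chunk_size = 16, model = MemoryMLP(dim = 64, depth = 2))

token = torch.randn(1, 1, 64)  # one decoding step; shape is an assumption

# first step: no cache yet, mirroring `state = next(neural_mem_caches, None)` above
retrieved, cache = mem.forward_inference(token, state = None)

# later steps feed the returned NeuralMemCache back in
retrieved, cache = mem.forward_inference(torch.randn(1, 1, 64), state = cache)
```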
@@ -767,50 +738,45 @@ class NeuralMemory(Module):
         return_aux_kv_loss = False,
         chunk_size = None,
         store_chunk_size = None,
-        return_values = False,
-        value_residual = None,
         return_next_state = False,
         prev_layer_updates: dict[str, Tensor] | None = None
     ):
         batch, seq_len = seq.shape[:2]

+        if not exists(mem_model_weights):
+            mem_model_weights = self.init_weights()
+
         if seq_len < self.retrieve_chunk_size:
-
+            retrieved = self.retrieve_memories(seq, mem_model_weights, chunk_size = 1)

             next_store_state = NeuralMemCache(seq_len, seq, None, None)

-            out = (
-
-            if return_values:
-                out = (*out, self.zero)
+            out = (retrieved, next_store_state)

             if not return_aux_kv_loss:
                 return out

             return out, self.zero

-        if not exists(mem_model_weights):
-            mem_model_weights = self.init_weights()
-
         # store

         store_seq = default(store_seq, seq)

-
-        store_chunk_size = default(store_chunk_size, chunk_size, self.store_chunk_size)
-        remainder = store_seq_len % store_chunk_size
-
-        (updates, next_state, values), aux_kv_recon_loss = self.store_memories(
+        (updates, next_store_state), aux_kv_recon_loss = self.store_memories(
            store_seq,
            mem_model_weights,
            chunk_size = store_chunk_size,
            prev_layer_updates = prev_layer_updates,
-            value_residual = value_residual,
            return_aux_kv_loss = True
        )

        # retrieve

+        if exists(prev_layer_updates):
+            prev_layer_updates = prev_layer_updates.apply(lambda t: pad_at_dim(t, (1, 0), dim = 1))
+
+        updates = updates.apply(lambda t: pad_at_dim(t, (1, 0), dim = 1))
+
        retrieved = self.retrieve_memories(
            seq,
            mem_model_weights + updates,

@@ -818,21 +784,8 @@ class NeuralMemory(Module):
            prev_layer_updates = prev_layer_updates
        )

-        # determine state for the storing of memories
-        # for transformer-xl like training with neural memory as well as inferencing with initial prompt
-
-        cache_store_seq = None
-
-        if remainder > 0:
-            cache_store_seq = store_seq[:, -remainder:]
-
-        next_store_state = NeuralMemCache(seq_len, cache_store_seq, next_state, updates)
-
        output = (retrieved, next_store_state)

-        if return_values:
-            output = (*output, values)
-
        if not return_aux_kv_loss:
            return output

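The simplified `NeuralMemory.forward` returns `(retrieved, next_store_state)`, with the auxiliary key/value reconstruction loss added only when `return_aux_kv_loss = True`. A hedged end-to-end sketch under the same assumed configuration as above:

```python
import torch
from titans_pytorch import NeuralMemory, MemoryMLP

# same assumed configuration as in the construction sketch above
mem = NeuralMemory(dim = 64, chunk_size = 16, model = MemoryMLP(dim = 64, depth = 2))

seq = torch.randn(1, 64, 64)  # (batch, seq_len, dim); values assumed

# default: retrieved memories plus the NeuralMemCache for the next segment
retrieved, next_store_state = mem(seq)

# with the auxiliary loss, matching how MemoryAsContextTransformer unpacks the call
(retrieved, next_store_state), aux_kv_recon_loss = mem(seq, return_aux_kv_loss = True)
```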
train_mac.py

@@ -34,7 +34,6 @@ NEURAL_MEM_LAYERS = (2, 4, 6) # layers 2, 4, 6 have neural memory,
 NEURAL_MEM_GATE_ATTN_OUTPUT = False
 NEURAL_MEM_MOMENTUM = True
 NEURAL_MEM_QK_NORM = False
-NEURAL_MEM_ADD_VALUE_RESIDUAL = False
 WINDOW_SIZE = 32
 NEURAL_MEM_SEGMENT_LEN = WINDOW_SIZE // 2 # set smaller for more granularity for learning rate / momentum etc
 SLIDING_WINDOWS = True

@@ -91,7 +90,6 @@ model = MemoryAsContextTransformer(
     use_flex_attn = USE_FLEX_ATTN,
     sliding_window_attn = SLIDING_WINDOWS,
     weight_tie_memory_model = WEIGHT_TIE_MEMORY_MODEL,
-    neural_memory_add_value_residual = NEURAL_MEM_ADD_VALUE_RESIDUAL,
     neural_memory_model = MemoryMLP(
         dim = 64,
         depth = NEURAL_MEMORY_DEPTH

@@ -164,6 +162,6 @@ for i in tqdm.tqdm(range(NUM_BATCHES), mininterval = 10., desc = 'training'):
         prime = decode_tokens(inp)
         print(f'%s \n\n %s', (prime, '*' * 100))

-        sample = model.sample(inp[None, ...], GENERATE_LENGTH, use_cache =
+        sample = model.sample(inp[None, ...], GENERATE_LENGTH, use_cache = USE_FAST_INFERENCE)
         output_str = decode_tokens(sample[0])
         print(output_str)
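For reference, a hedged sketch of how a training script like this exercises the changed paths; only `use_flex_attn`, `sliding_window_attn`, `weight_tie_memory_model`, `neural_memory_model` and the `sample(..., use_cache = ...)` call appear in the hunks above, while `num_tokens`, `dim`, `depth`, `segment_len` and all concrete values are assumptions:

```python
import torch
from titans_pytorch import MemoryAsContextTransformer, MemoryMLP

# assumed hyperparameters; only the kwargs shown in this diff are certain to exist
model = MemoryAsContextTransformer(
    num_tokens = 256,
    dim = 384,
    depth = 8,
    segment_len = 32,
    neural_memory_model = MemoryMLP(dim = 64, depth = 2),
    weight_tie_memory_model = True,
    use_flex_attn = False,
    sliding_window_attn = True,
)

prompt = torch.randint(0, 256, (1, 64))  # token ids, shaped like inp[None, ...] above

# use_cache = True exercises the rewritten NeuralMemory.forward_inference path
sample = model.sample(prompt, 128, use_cache = True)
```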