titans-pytorch 0.4.0.tar.gz → 0.4.2.tar.gz
This diff compares the contents of two publicly released versions of the package, as published to a supported public registry. It is provided for informational purposes only and reflects the changes between those versions as they appear in that registry.
- {titans_pytorch-0.4.0 → titans_pytorch-0.4.2}/PKG-INFO +1 -1
- {titans_pytorch-0.4.0 → titans_pytorch-0.4.2}/pyproject.toml +1 -1
- {titans_pytorch-0.4.0 → titans_pytorch-0.4.2}/tests/test_titans.py +15 -0
- {titans_pytorch-0.4.0 → titans_pytorch-0.4.2}/titans_pytorch/neural_memory.py +31 -9
- {titans_pytorch-0.4.0 → titans_pytorch-0.4.2}/.github/workflows/python-publish.yml +0 -0
- {titans_pytorch-0.4.0 → titans_pytorch-0.4.2}/.github/workflows/test.yaml +0 -0
- {titans_pytorch-0.4.0 → titans_pytorch-0.4.2}/.gitignore +0 -0
- {titans_pytorch-0.4.0 → titans_pytorch-0.4.2}/LICENSE +0 -0
- {titans_pytorch-0.4.0 → titans_pytorch-0.4.2}/README.md +0 -0
- {titans_pytorch-0.4.0 → titans_pytorch-0.4.2}/data/README.md +0 -0
- {titans_pytorch-0.4.0 → titans_pytorch-0.4.2}/data/enwik8.gz +0 -0
- {titans_pytorch-0.4.0 → titans_pytorch-0.4.2}/fig1.png +0 -0
- {titans_pytorch-0.4.0 → titans_pytorch-0.4.2}/fig2.png +0 -0
- {titans_pytorch-0.4.0 → titans_pytorch-0.4.2}/titans_pytorch/__init__.py +0 -0
- {titans_pytorch-0.4.0 → titans_pytorch-0.4.2}/titans_pytorch/associative_scan.py +0 -0
- {titans_pytorch-0.4.0 → titans_pytorch-0.4.2}/titans_pytorch/mac_transformer.py +0 -0
- {titans_pytorch-0.4.0 → titans_pytorch-0.4.2}/titans_pytorch/memory_models.py +0 -0
- {titans_pytorch-0.4.0 → titans_pytorch-0.4.2}/train_mac.py +0 -0
tests/test_titans.py

```diff
@@ -74,6 +74,21 @@ def test_titans(
 
     assert seq.shape == retrieved.shape
 
+def test_return_surprises():
+
+    mem = NeuralMemory(
+        dim = 384,
+        chunk_size = 2,
+        dim_head = 64,
+        heads = 4,
+    )
+
+    seq = torch.randn(4, 64, 384)
+
+    _, _, (surprises, adaptive_lr) = mem(seq, return_surprises = True)
+
+    assert all([t.shape == (4, 4, 64) for t in (surprises, adaptive_lr)])
+
 @pytest.mark.parametrize('learned_momentum_combine', (False, True))
 @pytest.mark.parametrize('learned_combine_include_zeroth', (False, True))
 def test_titans_second_order_momentum(
```
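The new test pins down the public contract: with `return_surprises = True`, the memory returns a third element, a `(surprises, adaptive_lr)` tuple whose tensors are each shaped `(batch, heads, seq_len)`. A usage sketch based on the test above; the head-averaging at the end is illustrative only, not something the package itself does:

```python
import torch
from titans_pytorch import NeuralMemory

mem = NeuralMemory(
    dim = 384,
    chunk_size = 2,
    dim_head = 64,
    heads = 4,
)

seq = torch.randn(4, 64, 384)

# third return value is (surprises, adaptive_lr), each (batch, heads, seq_len)
retrieved, state, (surprises, adaptive_lr) = mem(seq, return_surprises = True)

# e.g. reduce over heads for a per-token surprise signal (illustrative)
per_token_surprise = surprises.mean(dim = 1)  # (4, 64)
```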
titans_pytorch/neural_memory.py

```diff
@@ -353,11 +353,11 @@ class NeuralMemory(Module):
             pred = functional_call(self.memory_model, params, inputs)
             loss = self.store_memory_loss_fn(pred, target) # simple mse loss in paper - eq (12) - |M(k) - v|²
             weighted_loss = loss * loss_weights
-            return weighted_loss.sum()
+            return weighted_loss.sum(), loss
 
         # two functions
 
-        grad_fn = grad(forward_and_loss)
+        grad_fn = grad(forward_and_loss, has_aux = True)
 
         self.per_sample_grad_fn = vmap(grad_fn, in_dims = (0, 0, 0, 0))
 
```
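The mechanism here is `torch.func.grad` with `has_aux = True`: the differentiated function returns a `(scalar_loss, aux)` pair, and the transformed function then returns `(grads, aux)`, which is how the unweighted per-element loss rides along with the gradients. A minimal standalone sketch of that behavior (toy function, not the package's code):

```python
import torch
from torch.func import grad

def loss_with_aux(w, x):
    pred = x @ w
    loss = (pred ** 2).sum()  # scalar that gets differentiated
    return loss, pred         # second element is passed through untouched as aux

w = torch.randn(3, 3)
x = torch.randn(2, 3)

# with has_aux = True, grad returns (gradient, aux) instead of just the gradient
dw, pred = grad(loss_with_aux, has_aux = True)(w, x)

assert dw.shape == w.shape and pred.shape == (2, 3)
```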
```diff
@@ -526,6 +526,7 @@ class NeuralMemory(Module):
         seq_index = 0,
         prev_weights = None,
         mask: Tensor | None = None,
+        return_surprises = True
     ):
         if self.qkv_receives_diff_views:
             _, batch, seq_len = seq.shape[:3]
```
```diff
@@ -645,10 +646,15 @@ class NeuralMemory(Module):
 
         # get grads and extra auxiliary loss (for backwarding through qkv projection in base neural memory module)
 
-        grads = self.per_sample_grad_fn(dict(weights_for_surprise), keys, adaptive_lr, values)
+        grads, unweighted_mem_model_loss = self.per_sample_grad_fn(dict(weights_for_surprise), keys, adaptive_lr, values)
 
         grads = TensorDict(grads)
 
+        # surprises
+
+        adaptive_lr = rearrange(adaptive_lr, '(b h n) c -> b h (n c)', b = batch, h = heads)
+        unweighted_mem_model_loss = rearrange(unweighted_mem_model_loss, '(b h n) c -> b h (n c)', b = batch, h = heads)
+
         # maybe softclamp grad norm
 
         if exists(self.max_grad_norm):
```
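The `rearrange` calls undo the flattening that `vmap` sees: per-sample losses arrive with batch, heads, and chunks fused into one leading axis of size `b * h * n`, with `c` positions per chunk, and are regrouped into `(batch, heads, seq_len)`. A toy einops demonstration of the same pattern (the concrete shapes are illustrative, chosen to match the new test):

```python
import torch
from einops import rearrange

b, h, n, c = 4, 4, 32, 2  # batch, heads, chunks, positions per chunk

flat = torch.randn(b * h * n, c)  # fused layout, as produced by the vmapped grad fn

# regroup the fused axis, then merge chunks with their within-chunk positions
grouped = rearrange(flat, '(b h n) c -> b h (n c)', b = b, h = h)

assert grouped.shape == (b, h, n * c)  # (4, 4, 64)
```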
```diff
@@ -687,7 +693,10 @@ class NeuralMemory(Module):
 
             output = (updates, next_store_state)
 
-            return output
+            if not return_surprises:
+                return output
+
+            return (*output, (unweighted_mem_model_loss, adaptive_lr))
 
         # momentum + weight decay - momentum is the new contribution, as most linear RNNs have learned forgetting gates
 
```
```diff
@@ -744,7 +753,10 @@ class NeuralMemory(Module):
 
         # return updates to neural memory at all chunked timesteps + neural mem cache / state to be fed back
 
-        return updates, next_store_state
+        if not return_surprises:
+            return updates, next_store_state
+
+        return updates, next_store_state, (unweighted_mem_model_loss, adaptive_lr)
 
     def retrieve_memories(
         self,
```
```diff
@@ -843,7 +855,8 @@ class NeuralMemory(Module):
         store_seq = None,
         state: NeuralMemState | None = None,
         prev_weights = None,
-        store_mask: Tensor | None = None
+        store_mask: Tensor | None = None,
+        return_surprises = False
     ):
         is_multi_input = self.qkv_receives_diff_views
 
```
```diff
@@ -927,6 +940,7 @@ class NeuralMemory(Module):
 
         # whether to allow network to slowly adjust from initial weight throughout (residual path) to fully updating weights every batch
 
+        surprises = None
         gate = None
 
         if exists(self.transition_gate):
```
```diff
@@ -937,13 +951,14 @@ class NeuralMemory(Module):
 
             # store
 
-            next_updates, next_neural_mem_state = self.store_memories(
+            next_updates, next_neural_mem_state, chunk_surprises = self.store_memories(
                 store_seq_chunk,
                 weights,
                 seq_index = seq_index,
                 past_state = past_state,
                 prev_weights = prev_weights,
-                mask = maybe_store_mask
+                mask = maybe_store_mask,
+                return_surprises = True
             )
 
             weights = next_neural_mem_state.weights
```
```diff
@@ -952,6 +967,8 @@ class NeuralMemory(Module):
 
             updates = accum_updates(updates, next_updates)
 
+            surprises = safe_cat((surprises, chunk_surprises), dim = -1)
+
             if is_last and not update_after_final_store:
                 continue
 
```
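`safe_cat` accumulates each chunk's `(loss, adaptive_lr)` pair across the chunked loop while tolerating the initial `None`. The package's own helper isn't shown in this diff; a hedged sketch consistent with the call site above might look like this (hypothetical reimplementation for illustration only):

```python
import torch

# hypothetical stand-in; the actual safe_cat in titans_pytorch may differ
def safe_cat(inputs, dim = -1):
    inputs = [t for t in inputs if t is not None]

    if len(inputs) == 0:
        return None

    if len(inputs) == 1:
        return inputs[0]

    if isinstance(inputs[0], tuple):
        # concatenate tuple members pairwise, e.g. (surprise, adaptive_lr) per chunk
        return tuple(torch.cat(pair, dim = dim) for pair in zip(*inputs))

    return torch.cat(inputs, dim = dim)
```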
```diff
@@ -986,4 +1003,9 @@ class NeuralMemory(Module):
             updates
         )
 
-        return retrieved, next_neural_mem_state
+        # returning
+
+        if not return_surprises:
+            return retrieved, next_neural_mem_state
+
+        return retrieved, next_neural_mem_state, surprises
```