titans-pytorch 0.4.0__tar.gz → 0.4.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {titans_pytorch-0.4.0 → titans_pytorch-0.4.1}/PKG-INFO +1 -1
- {titans_pytorch-0.4.0 → titans_pytorch-0.4.1}/pyproject.toml +1 -1
- {titans_pytorch-0.4.0 → titans_pytorch-0.4.1}/tests/test_titans.py +15 -0
- {titans_pytorch-0.4.0 → titans_pytorch-0.4.1}/titans_pytorch/neural_memory.py +30 -9
- {titans_pytorch-0.4.0 → titans_pytorch-0.4.1}/.github/workflows/python-publish.yml +0 -0
- {titans_pytorch-0.4.0 → titans_pytorch-0.4.1}/.github/workflows/test.yaml +0 -0
- {titans_pytorch-0.4.0 → titans_pytorch-0.4.1}/.gitignore +0 -0
- {titans_pytorch-0.4.0 → titans_pytorch-0.4.1}/LICENSE +0 -0
- {titans_pytorch-0.4.0 → titans_pytorch-0.4.1}/README.md +0 -0
- {titans_pytorch-0.4.0 → titans_pytorch-0.4.1}/data/README.md +0 -0
- {titans_pytorch-0.4.0 → titans_pytorch-0.4.1}/data/enwik8.gz +0 -0
- {titans_pytorch-0.4.0 → titans_pytorch-0.4.1}/fig1.png +0 -0
- {titans_pytorch-0.4.0 → titans_pytorch-0.4.1}/fig2.png +0 -0
- {titans_pytorch-0.4.0 → titans_pytorch-0.4.1}/titans_pytorch/__init__.py +0 -0
- {titans_pytorch-0.4.0 → titans_pytorch-0.4.1}/titans_pytorch/associative_scan.py +0 -0
- {titans_pytorch-0.4.0 → titans_pytorch-0.4.1}/titans_pytorch/mac_transformer.py +0 -0
- {titans_pytorch-0.4.0 → titans_pytorch-0.4.1}/titans_pytorch/memory_models.py +0 -0
- {titans_pytorch-0.4.0 → titans_pytorch-0.4.1}/train_mac.py +0 -0
tests/test_titans.py:

```diff
@@ -74,6 +74,21 @@ def test_titans(
 
     assert seq.shape == retrieved.shape
 
+def test_return_surprises():
+
+    mem = NeuralMemory(
+        dim = 384,
+        chunk_size = 2,
+        dim_head = 64,
+        heads = 4,
+    )
+
+    seq = torch.randn(4, 64, 384)
+
+    _, _, surprises = mem(seq, return_surprises = True)
+
+    assert surprises.shape == (4, 4, 64)
+
 @pytest.mark.parametrize('learned_momentum_combine', (False, True))
 @pytest.mark.parametrize('learned_combine_include_zeroth', (False, True))
 def test_titans_second_order_momentum(
```
titans_pytorch/neural_memory.py:

```diff
@@ -353,11 +353,11 @@ class NeuralMemory(Module):
             pred = functional_call(self.memory_model, params, inputs)
             loss = self.store_memory_loss_fn(pred, target) # simple mse loss in paper - eq (12) - |M(k) - v|²
             weighted_loss = loss * loss_weights
-            return weighted_loss.sum()
+            return weighted_loss.sum(), loss
 
         # two functions
 
-        grad_fn = grad(forward_and_loss)
+        grad_fn = grad(forward_and_loss, has_aux = True)
 
         self.per_sample_grad_fn = vmap(grad_fn, in_dims = (0, 0, 0, 0))
 
```
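The `has_aux = True` change above follows `torch.func.grad`'s auxiliary-output convention: the wrapped function returns `(scalar_loss, aux)`, and the resulting grad function returns `(grads, aux)` without differentiating through the auxiliary value. A minimal standalone sketch of that convention (illustrative names, not code from this repo):

```python
import torch
from torch.func import grad, vmap

def loss_fn(w, x, target):
    pred = x @ w
    loss = (pred - target).pow(2)   # unreduced loss, carried out as aux
    return loss.sum(), loss         # (scalar to differentiate, aux)

grad_fn = grad(loss_fn, has_aux = True)

# vmap over the batch gives per-sample grads, with aux stacked alongside
per_sample_grad_fn = vmap(grad_fn, in_dims = (None, 0, 0))

w = torch.randn(8, 4)
x = torch.randn(16, 8)
target = torch.randn(16, 4)

grads, unreduced_loss = per_sample_grad_fn(w, x, target)
print(grads.shape, unreduced_loss.shape)  # torch.Size([16, 8, 4]) torch.Size([16, 4])
```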
```diff
@@ -526,6 +526,7 @@ class NeuralMemory(Module):
         seq_index = 0,
         prev_weights = None,
         mask: Tensor | None = None,
+        return_surprises = True
     ):
         if self.qkv_receives_diff_views:
             _, batch, seq_len = seq.shape[:3]
```
```diff
@@ -645,10 +646,14 @@ class NeuralMemory(Module):
 
         # get grads and extra auxiliary loss (for backwarding through qkv projection in base neural memory module)
 
-        grads = self.per_sample_grad_fn(dict(weights_for_surprise), keys, adaptive_lr, values)
+        grads, unweighted_mem_model_loss = self.per_sample_grad_fn(dict(weights_for_surprise), keys, adaptive_lr, values)
 
         grads = TensorDict(grads)
 
+        # surprises
+
+        unweighted_mem_model_loss = rearrange(unweighted_mem_model_loss, '(b h n) c -> b h (n c)', b = batch, h = heads)
+
         # maybe softclamp grad norm
 
         if exists(self.max_grad_norm):
```
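The `rearrange` in this hunk unflattens the auxiliary losses from the layout the vmapped per-sample grad function produces — by the pattern string, `(batch * heads * num_chunks, chunk_size)` — into one surprise value per head per token. A quick einops sanity check using the shapes from the new test (the intermediate shape is inferred from the pattern, not taken from the source):

```python
import torch
from einops import rearrange

# shapes matching the new test: batch 4, heads 4, seq len 64, chunk size 2
batch, heads, num_chunks, chunk_size = 4, 4, 32, 2

loss = torch.randn(batch * heads * num_chunks, chunk_size)

surprises = rearrange(loss, '(b h n) c -> b h (n c)', b = batch, h = heads)

print(surprises.shape)  # torch.Size([4, 4, 64]) - the (4, 4, 64) the test asserts
```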
```diff
@@ -687,7 +692,10 @@ class NeuralMemory(Module):
 
             output = (updates, next_store_state)
 
-            return output
+            if not return_surprises:
+                return output
+
+            return (*output, unweighted_mem_model_loss)
 
         # momentum + weight decay - momentum is the new contribution, as most linear RNNs have learned forgetting gates
 
```
```diff
@@ -744,7 +752,10 @@ class NeuralMemory(Module):
 
         # return updates to neural memory at all chunked timesteps + neural mem cache / state to be fed back
 
-        return updates, next_store_state
+        if not return_surprises:
+            return updates, next_store_state
+
+        return updates, next_store_state, unweighted_mem_model_loss
 
     def retrieve_memories(
         self,
```
```diff
@@ -843,7 +854,8 @@ class NeuralMemory(Module):
         store_seq = None,
         state: NeuralMemState | None = None,
         prev_weights = None,
-        store_mask: Tensor | None = None
+        store_mask: Tensor | None = None,
+        return_surprises = False
     ):
         is_multi_input = self.qkv_receives_diff_views
 
```
```diff
@@ -927,6 +939,7 @@ class NeuralMemory(Module):
 
         # whether to allow network to slowly adjust from initial weight throughout (residual path) to fully updating weights every batch
 
+        surprises = None
         gate = None
 
         if exists(self.transition_gate):
```
```diff
@@ -937,13 +950,14 @@ class NeuralMemory(Module):
 
             # store
 
-            next_updates, next_neural_mem_state = self.store_memories(
+            next_updates, next_neural_mem_state, chunk_surprises = self.store_memories(
                 store_seq_chunk,
                 weights,
                 seq_index = seq_index,
                 past_state = past_state,
                 prev_weights = prev_weights,
-                mask = maybe_store_mask
+                mask = maybe_store_mask,
+                return_surprises = True
             )
 
             weights = next_neural_mem_state.weights
```
```diff
@@ -952,6 +966,8 @@ class NeuralMemory(Module):
 
             updates = accum_updates(updates, next_updates)
 
+            surprises = safe_cat((surprises, chunk_surprises), dim = -1)
+
             if is_last and not update_after_final_store:
                 continue
 
```
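`surprises` starts as `None` and is grown chunk by chunk, so `safe_cat` is presumably a None-tolerant concatenation helper defined elsewhere in the package. A hedged sketch of what such a helper plausibly does (not copied from the source):

```python
import torch

def safe_cat(tensors, dim = -1):
    # drop Nones, then behave like torch.cat - a guess at the helper's contract
    tensors = [t for t in tensors if t is not None]
    if len(tensors) == 0:
        return None
    if len(tensors) == 1:
        return tensors[0]
    return torch.cat(tensors, dim = dim)

# accumulating (batch, heads, chunk) surprise slices along the sequence dim
acc = None
for chunk_surprises in torch.randn(3, 4, 4, 2).unbind(0):
    acc = safe_cat((acc, chunk_surprises), dim = -1)

print(acc.shape)  # torch.Size([4, 4, 6])
```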
```diff
@@ -986,4 +1002,9 @@ class NeuralMemory(Module):
             updates
         )
 
-        return retrieved, next_neural_mem_state
+        # returning
+
+        if not return_surprises:
+            return retrieved, next_neural_mem_state
+
+        return retrieved, next_neural_mem_state, surprises
```
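Taken together, the user-facing surface of 0.4.1 is one opt-in flag on `NeuralMemory.forward` (defaulting to `False`, so existing call sites are unaffected). A usage sketch mirroring the added test:

```python
import torch
from titans_pytorch import NeuralMemory

mem = NeuralMemory(
    dim = 384,
    chunk_size = 2,
    dim_head = 64,
    heads = 4,
)

seq = torch.randn(4, 64, 384)

# default call is unchanged
retrieved, state = mem(seq)

# opting in returns the unweighted memory-model losses ("surprises")
# as a third value, shaped (batch, heads, seq)
retrieved, state, surprises = mem(seq, return_surprises = True)

assert surprises.shape == (4, 4, 64)
```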