titans-pytorch 0.2.6__py3-none-any.whl → 0.2.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- titans_pytorch/neural_memory.py +20 -12
- {titans_pytorch-0.2.6.dist-info → titans_pytorch-0.2.7.dist-info}/METADATA +2 -2
- {titans_pytorch-0.2.6.dist-info → titans_pytorch-0.2.7.dist-info}/RECORD +5 -5
- {titans_pytorch-0.2.6.dist-info → titans_pytorch-0.2.7.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.2.6.dist-info → titans_pytorch-0.2.7.dist-info}/licenses/LICENSE +0 -0
titans_pytorch/neural_memory.py
CHANGED
@@ -51,6 +51,9 @@ def default(*args):
|
|
51
51
|
return arg
|
52
52
|
return None
|
53
53
|
|
54
|
+
def identity(t):
|
55
|
+
return t
|
56
|
+
|
54
57
|
def xnor(x, y):
|
55
58
|
return not (x ^ y)
|
56
59
|
|
@@ -64,9 +67,6 @@ def safe_cat(inputs, dim = -2):
|
|
64
67
|
|
65
68
|
return cat(inputs, dim = dim)
|
66
69
|
|
67
|
-
def identity(t):
|
68
|
-
return t
|
69
|
-
|
70
70
|
def dict_get_shape(td):
|
71
71
|
return {k: v.shape for k, v in td.items()}
|
72
72
|
|
@@ -454,14 +454,14 @@ class NeuralMemory(Module):
|
|
454
454
|
|
455
455
|
weights = TensorDict(weights)
|
456
456
|
|
457
|
-
# allow for neural memory of a previous layer
|
458
|
-
|
459
|
-
|
457
|
+
# allow for neural memory of a previous layer to influence surprise of current layer
|
458
|
+
|
459
|
+
weights_for_surprise = weights
|
460
460
|
|
461
461
|
if exists(prev_layer_updates):
|
462
462
|
prev_layer_updates = TensorDict(prev_layer_updates)
|
463
463
|
|
464
|
-
|
464
|
+
weights_for_surprise = weights_for_surprise + prev_layer_updates
|
465
465
|
|
466
466
|
per_sample_grad_fn = self.per_sample_grad_fn_expanded_weights # the weights will now have a batch * chunk dimension
|
467
467
|
|
@@ -506,11 +506,11 @@ class NeuralMemory(Module):
|
|
506
506
|
# flatten batch and time if surprise depends on previous layer memory model
|
507
507
|
|
508
508
|
if exists(prev_layer_updates):
|
509
|
-
|
509
|
+
weights_for_surprise = weights_for_surprise.apply(lambda t: rearrange(t, 'b n ... -> (b n) ...'))
|
510
510
|
|
511
511
|
# get grads and extra auxiliary loss (for backwarding through qkv projection in base neural memory module)
|
512
512
|
|
513
|
-
grads, aux_kv_recon_loss = per_sample_grad_fn(dict(
|
513
|
+
grads, aux_kv_recon_loss = per_sample_grad_fn(dict(weights_for_surprise), keys, adaptive_lr, values)
|
514
514
|
|
515
515
|
grads = TensorDict(grads)
|
516
516
|
|
@@ -536,7 +536,15 @@ class NeuralMemory(Module):
|
|
536
536
|
|
537
537
|
if not exists(past_state):
|
538
538
|
empty_dict = {key: None for key in weights.keys()}
|
539
|
-
|
539
|
+
|
540
|
+
# minibatch_init_weight corresponds to W0 in figure 7 of TTT paper
|
541
|
+
|
542
|
+
minibatch_init_weight = weights
|
543
|
+
|
544
|
+
if dict_get_shape(weights) == self.init_weight_shape:
|
545
|
+
minibatch_init_weight = weights.apply(lambda t: repeat(t, '... -> b 1 (...)', b = batch * heads))
|
546
|
+
|
547
|
+
past_state = (minibatch_init_weight, empty_dict)
|
540
548
|
|
541
549
|
past_last_update, past_last_momentum = past_state
|
542
550
|
|
@@ -734,7 +742,7 @@ class NeuralMemory(Module):
|
|
734
742
|
|
735
743
|
# retrieve
|
736
744
|
|
737
|
-
retrieved = self.retrieve_memories(token,
|
745
|
+
retrieved = self.retrieve_memories(token, weights, chunk_size = 1)
|
738
746
|
|
739
747
|
# next state tuple
|
740
748
|
|
@@ -796,7 +804,7 @@ class NeuralMemory(Module):
|
|
796
804
|
|
797
805
|
retrieved = self.retrieve_memories(
|
798
806
|
seq,
|
799
|
-
mem_model_weights
|
807
|
+
mem_model_weights,
|
800
808
|
chunk_size = chunk_size,
|
801
809
|
prev_layer_updates = prev_layer_updates
|
802
810
|
)
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: titans-pytorch
|
3
|
-
Version: 0.2.6
|
3
|
+
Version: 0.2.7
|
4
4
|
Summary: Titans
|
5
5
|
Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
|
6
6
|
Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
|
@@ -27,7 +27,7 @@ License: MIT License
|
|
27
27
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
28
28
|
SOFTWARE.
|
29
29
|
License-File: LICENSE
|
30
|
-
Keywords: artificial intelligence,deep learning,linear attention,
|
30
|
+
Keywords: artificial intelligence,deep learning,linear attention,memory,test time training
|
31
31
|
Classifier: Development Status :: 4 - Beta
|
32
32
|
Classifier: Intended Audience :: Developers
|
33
33
|
Classifier: License :: OSI Approved :: MIT License
|
@@ -2,8 +2,8 @@ titans_pytorch/__init__.py,sha256=Y3m_ZlpEqYwp-Md1ARhNGJxq8bQp8ty1o039nZOOJo0,27
|
|
2
2
|
titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
|
3
3
|
titans_pytorch/mac_transformer.py,sha256=UOJAMv7nTgkefBB7M7K3U0NnFkz75tFRG5WLXRdfnLw,26039
|
4
4
|
titans_pytorch/memory_models.py,sha256=CD8pQ-IUfTDvPmekuPTsZHE3Vy265QtbiUn_siJhA78,4064
|
5
|
-
titans_pytorch/neural_memory.py,sha256=
|
6
|
-
titans_pytorch-0.2.
|
7
|
-
titans_pytorch-0.2.
|
8
|
-
titans_pytorch-0.2.
|
9
|
-
titans_pytorch-0.2.
|
5
|
+
titans_pytorch/neural_memory.py,sha256=WAeR-nOpy1XbBP590By1-tCgirulqPbFGut4H1B77-g,24910
|
6
|
+
titans_pytorch-0.2.7.dist-info/METADATA,sha256=ndFb28pAe8xWmNU6oncV8VJDDPImo3aCuBv0d0JylIs,6811
|
7
|
+
titans_pytorch-0.2.7.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
8
|
+
titans_pytorch-0.2.7.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
9
|
+
titans_pytorch-0.2.7.dist-info/RECORD,,
|
File without changes
|
File without changes
|