titans-pytorch 0.3.10__py3-none-any.whl → 0.3.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- titans_pytorch/mac_transformer.py +5 -4
- titans_pytorch/neural_memory.py +18 -10
- {titans_pytorch-0.3.10.dist-info → titans_pytorch-0.3.12.dist-info}/METADATA +2 -2
- titans_pytorch-0.3.12.dist-info/RECORD +9 -0
- titans_pytorch-0.3.10.dist-info/RECORD +0 -9
- {titans_pytorch-0.3.10.dist-info → titans_pytorch-0.3.12.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.3.10.dist-info → titans_pytorch-0.3.12.dist-info}/licenses/LICENSE +0 -0
titans_pytorch/mac_transformer.py
CHANGED
@@ -62,6 +62,7 @@ from rotary_embedding_torch import RotaryEmbedding
 # hyper connections / attend from x-transformers, which handles different queries and key lengths better
 
 from x_transformers.attend import Attend
+
 from hyper_connections import get_init_and_expand_reduce_stream_functions
 
 # proposed neural memory
@@ -515,7 +516,7 @@ class MemoryAsContextTransformer(Module):
 
         # hyper conection
 
-        init_hyper_conn, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(num_residual_streams, disable = num_residual_streams == 1)
+        init_hyper_conn, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(num_residual_streams, dim = dim, add_stream_embed = True, disable = num_residual_streams == 1)
 
         self.layers = ModuleList([])
 
@@ -553,7 +554,7 @@ class MemoryAsContextTransformer(Module):
             mem_hyper_conn = None
 
             if layer in neural_memory_layers:
-                mem_hyper_conn = init_hyper_conn(dim = dim, add_branch_out_to_residual = not neural_mem_gate_attn_output)
+                mem_hyper_conn = init_hyper_conn(add_branch_out_to_residual = not neural_mem_gate_attn_output)
 
                 mem = NeuralMemory(
                     dim = dim,
@@ -571,8 +572,8 @@ class MemoryAsContextTransformer(Module):
             self.layers.append(ModuleList([
                 mem_hyper_conn,
                 mem,
-                init_hyper_conn(dim = dim, branch = attn),
-                init_hyper_conn(dim = dim, branch = ff)
+                init_hyper_conn(branch = attn),
+                init_hyper_conn(branch = ff)
             ]))
 
         self.norm = nn.RMSNorm(dim)
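Taken together, the hunks above move `dim` (plus the new `add_stream_embed` flag) out of the per-layer `init_hyper_conn(...)` calls and into the factory, which is why the call sites now pass only `branch` or `add_branch_out_to_residual`. Below is a minimal sketch of the new wiring, assuming hyper-connections >= 0.1.10 behaves as these call sites use it; the `nn.Linear` branch and all sizes are placeholders:

import torch
from torch import nn
from hyper_connections import get_init_and_expand_reduce_stream_functions

dim, num_residual_streams = 64, 4

# dim and add_stream_embed now go to the factory (new in this release)
init_hyper_conn, expand_streams, reduce_streams = get_init_and_expand_reduce_stream_functions(
    num_residual_streams,
    dim = dim,
    add_stream_embed = True,
    disable = num_residual_streams == 1
)

# call sites only name the wrapped branch, as in the hunks above
block = init_hyper_conn(branch = nn.Linear(dim, dim))

x = torch.randn(2, 16, dim)
x = expand_streams(x)   # assumed usage: fan the input out into the residual streams
x = block(x)            # mix streams, run the branch, add back to the residual
x = reduce_streams(x)   # assumed usage: sum the streams back down
assert x.shape == (2, 16, dim)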
titans_pytorch/neural_memory.py
CHANGED
@@ -7,7 +7,7 @@ from itertools import zip_longest
 from collections import namedtuple
 
 import torch
-from torch import nn, cat, tensor, Tensor
+from torch import nn, stack, cat, tensor, Tensor
 import torch.nn.functional as F
 from torch.nn import Linear, Module, Parameter, ParameterList, ParameterDict
 from torch.func import functional_call, vmap, grad
@@ -39,7 +39,7 @@ o - momentum orders
 
 LinearNoBias = partial(Linear, bias = False)
 
-NeuralMemCache = namedtuple('NeuralMemCache', [
+NeuralMemState = namedtuple('NeuralMemState', [
     'seq_index',
     'weights',
     'cache_store_segment',
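For reference, this is the named tuple the hunk (re)defines, reconstructed from the three field names visible here plus the five-argument constructor calls later in the diff; the trailing `states` and `updates` field names are assumptions inferred from those calls:

from collections import namedtuple

NeuralMemState = namedtuple('NeuralMemState', [
    'seq_index',            # shown in the hunk
    'weights',              # shown in the hunk
    'cache_store_segment',  # shown in the hunk
    'states',               # assumed, from NeuralMemState(..., past_state, updates) below
    'updates',              # assumed, same source
])

# a fresh state before any tokens have been stored
state = NeuralMemState(0, None, None, None, None)
assert state.seq_index == 0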
@@ -230,6 +230,7 @@ class NeuralMemory(Module):
         momentum = True,
         momentum_order = 1,
         learned_momentum_combine = False,
+        learned_combine_include_zeroth = False,
         pre_rmsnorm = True,
         post_rmsnorm = False,
         qk_rmsnorm = False,
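The new flag sits beside the existing momentum options. A hypothetical construction exercising it (values are illustrative; `chunk_size` is an existing option not shown in this hunk, and the `momentum_order > 1` requirement comes from the assert in the next hunk):

import torch
from titans_pytorch import NeuralMemory

mem = NeuralMemory(
    dim = 384,
    chunk_size = 64,                         # existing option, not part of this diff
    momentum = True,
    momentum_order = 2,                      # must be > 1 when learning the combination
    learned_momentum_combine = True,
    learned_combine_include_zeroth = True    # new in this release
)

# assuming forward returns (retrieved, next_state) as in the project README
retrieved, state = mem(torch.randn(2, 256, 384))
assert retrieved.shape == (2, 256, 384)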
@@ -399,12 +400,17 @@ class NeuralMemory(Module):
             assert momentum
             assert momentum_order > 1, 'only second order momentum allowed for now, but may allow learned combination of zeroth'
 
+            if learned_combine_include_zeroth:
+                momentum_order += 1
+
             self.to_learned_momentum_combine = Sequential(
                 nn.Linear(dim, heads * momentum_order),
                 nn.Softmax(dim = -1),
                 Rearrange('b n (h o) -> o (b h) n', h = heads)
             )
 
+        self.learned_combine_include_zeroth = learned_combine_include_zeroth
+
         # per layer learning rate modulation
 
         self.to_layer_modulation = Sequential(
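Mechanically, including the zeroth order just widens the softmax over momentum orders by one slot. A standalone shape sketch of the combine head with hypothetical sizes (the real module is built with the library's own `Sequential`; `nn.Sequential` is used here for self-containment):

import torch
from torch import nn
from einops.layers.torch import Rearrange

batch, seq_len, dim, heads, momentum_order = 2, 16, 64, 4, 2

learned_combine_include_zeroth = True
if learned_combine_include_zeroth:
    momentum_order += 1  # extra slot for the raw (zeroth order) surprise

to_learned_momentum_combine = nn.Sequential(
    nn.Linear(dim, heads * momentum_order),
    nn.Softmax(dim = -1),
    Rearrange('b n (h o) -> o (b h) n', h = heads)
)

# one weight per momentum order, per (batch, head), per position
combine = to_learned_momentum_combine(torch.randn(batch, seq_len, dim))
assert combine.shape == (momentum_order, batch * heads, seq_len)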
@@ -629,7 +635,7 @@ class NeuralMemory(Module):
 
         if num_chunks == 0:
             updates = rearrange_dict_values(weights, 'bh ... -> bh 1 ...')
-            next_store_state = NeuralMemCache(next_seq_len_index, weights, remainder, past_state, updates)
+            next_store_state = NeuralMemState(next_seq_len_index, weights, remainder, past_state, updates)
 
             output = (updates, next_store_state)
 
@@ -662,10 +668,14 @@ class NeuralMemory(Module):
 
                 momentums.append(momentum)
 
-            momentums = torch.stack(momentums)
+            momentums = stack(momentums)
 
             next_last_momentum[param_name] = momentums[:, :, -1] # momentums shape is Float['o bh n 1']
 
+            if learned_combine and self.learned_combine_include_zeroth:
+                # add the original surprise if learned combination of momentums
+                momentums = cat((rearrange(surprise, '... -> 1 ...'), momentums), dim = 0)
+
             if not learned_combine:
                 update = momentums[-1]
             else:
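At combine time the per-order momentum tensors gain a leading order axis via `stack`, and with the new flag the raw surprise is prepended as an extra zeroth entry so the learned softmax can weigh it alongside the momentum orders. A toy-shape sketch (shapes follow the `Float['o bh n 1']` comment above; all names are placeholders):

import torch
from torch import stack, cat
from einops import rearrange

bh, n = 8, 16                                          # (batch * heads, num chunks)
surprise = torch.randn(bh, n, 1)                       # zeroth order: the raw gradient signal
momentums = [torch.randn(bh, n, 1) for _ in range(2)]  # first and second order momentum

momentums = stack(momentums)                           # -> (o, bh, n, 1)

# prepend the surprise as in the hunk above
momentums = cat((rearrange(surprise, '... -> 1 ...'), momentums), dim = 0)
assert momentums.shape == (3, bh, n, 1)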
@@ -682,13 +692,11 @@ class NeuralMemory(Module):
 
         next_state = (next_last_update, next_last_momentum)
 
-        next_store_state = NeuralMemCache(next_seq_len_index, weights, remainder, next_state, updates)
-
-        # returns
+        next_store_state = NeuralMemState(next_seq_len_index, weights, remainder, next_state, updates)
 
-
+        # return updates to neural memory at all chunked timesteps + neural mem cache / state to be fed back
 
-        return
+        return updates, next_store_state
 
     def retrieve_memories(
         self,
@@ -785,7 +793,7 @@ class NeuralMemory(Module):
         self,
         seq,
         store_seq = None,
-        state: NeuralMemCache | None = None,
+        state: NeuralMemState | None = None,
         prev_weights = None
     ):
         if seq.ndim == 2:
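With the state annotation in place, the intended calling pattern is to thread the returned state back in for the next segment. A usage sketch, assuming forward returns `(retrieved, next_state)` as in the project README:

import torch
from titans_pytorch import NeuralMemory

mem = NeuralMemory(dim = 384, chunk_size = 64)

seq1 = torch.randn(2, 128, 384)
seq2 = torch.randn(2, 128, 384)

retrieved1, state = mem(seq1)                  # state defaults to None on the first call
retrieved2, state = mem(seq2, state = state)   # resume from the returned NeuralMemState

assert retrieved2.shape == seq2.shape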
{titans_pytorch-0.3.10.dist-info → titans_pytorch-0.3.12.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.3.10
+Version: 0.3.12
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -38,7 +38,7 @@ Requires-Dist: accelerated-scan>=0.2.0
 Requires-Dist: axial-positional-embedding>=0.3.10
 Requires-Dist: einops>=0.8.0
 Requires-Dist: einx>=0.3.0
-Requires-Dist: hyper-connections>=0.1.
+Requires-Dist: hyper-connections>=0.1.10
 Requires-Dist: ninja
 Requires-Dist: rotary-embedding-torch
 Requires-Dist: tensordict
titans_pytorch-0.3.12.dist-info/RECORD
ADDED
@@ -0,0 +1,9 @@
+titans_pytorch/__init__.py,sha256=AyEUlcXWpnqrvyeihRAXWIfQlzLA4NhBjOqQU4edL-4,297
+titans_pytorch/associative_scan.py,sha256=esaLbukFlgvy2aqopsqBy6KEcZ64B3rsNhG8moKdPSc,5159
+titans_pytorch/mac_transformer.py,sha256=EyqA53HBqvAr4UNZUs37LR6IltyEfA7FKEV54YzVYlg,24945
+titans_pytorch/memory_models.py,sha256=wnH9i9kUSoVZhEWUlj8LpBSbB400L9kLt1zP8CO45QQ,5835
+titans_pytorch/neural_memory.py,sha256=VmUAS1xOM0ZfearWIzQrX_P7HI69viuwrg9M7BQByeE,29349
+titans_pytorch-0.3.12.dist-info/METADATA,sha256=02OsMYNITFLjnKJgis8eUHxwcdH2aVbA_D-QK24TYbg,6817
+titans_pytorch-0.3.12.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.3.12.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.3.12.dist-info/RECORD,,
titans_pytorch-0.3.10.dist-info/RECORD
REMOVED
@@ -1,9 +0,0 @@
-titans_pytorch/__init__.py,sha256=AyEUlcXWpnqrvyeihRAXWIfQlzLA4NhBjOqQU4edL-4,297
-titans_pytorch/associative_scan.py,sha256=esaLbukFlgvy2aqopsqBy6KEcZ64B3rsNhG8moKdPSc,5159
-titans_pytorch/mac_transformer.py,sha256=5rO4GQxSyFWWEc3pc3xNyG0sK5EXE7MmxKI-_kEMl2M,24941
-titans_pytorch/memory_models.py,sha256=wnH9i9kUSoVZhEWUlj8LpBSbB400L9kLt1zP8CO45QQ,5835
-titans_pytorch/neural_memory.py,sha256=BeOnq41gjZeq-XJFjkHE44F9dLzsg9mm36EBYZ4wHMA,28814
-titans_pytorch-0.3.10.dist-info/METADATA,sha256=sA_Dx_x5RMcpz5-vUPDHuz__tHYfKzs4W_BgY4CHPdk,6816
-titans_pytorch-0.3.10.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.3.10.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.3.10.dist-info/RECORD,,
{titans_pytorch-0.3.10.dist-info → titans_pytorch-0.3.12.dist-info}/WHEEL
File without changes
{titans_pytorch-0.3.10.dist-info → titans_pytorch-0.3.12.dist-info}/licenses/LICENSE
File without changes