titans-pytorch: titans_pytorch-0.3.10-py3-none-any.whl → titans_pytorch-0.3.11-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- titans_pytorch/mac_transformer.py +5 -4
- titans_pytorch/neural_memory.py +6 -8
- {titans_pytorch-0.3.10.dist-info → titans_pytorch-0.3.11.dist-info}/METADATA +2 -2
- titans_pytorch-0.3.11.dist-info/RECORD +9 -0
- titans_pytorch-0.3.10.dist-info/RECORD +0 -9
- {titans_pytorch-0.3.10.dist-info → titans_pytorch-0.3.11.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.3.10.dist-info → titans_pytorch-0.3.11.dist-info}/licenses/LICENSE +0 -0
@@ -62,6 +62,7 @@ from rotary_embedding_torch import RotaryEmbedding
|
|
62
62
|
# hyper connections / attend from x-transformers, which handles different queries and key lengths better
|
63
63
|
|
64
64
|
from x_transformers.attend import Attend
|
65
|
+
|
65
66
|
from hyper_connections import get_init_and_expand_reduce_stream_functions
|
66
67
|
|
67
68
|
# proposed neural memory
|
@@ -515,7 +516,7 @@ class MemoryAsContextTransformer(Module):
|
|
515
516
|
|
516
517
|
# hyper conection
|
517
518
|
|
518
|
-
init_hyper_conn, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(num_residual_streams, disable = num_residual_streams == 1)
|
519
|
+
init_hyper_conn, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(num_residual_streams, dim = dim, add_stream_embed = True, disable = num_residual_streams == 1)
|
519
520
|
|
520
521
|
self.layers = ModuleList([])
|
521
522
|
|
@@ -553,7 +554,7 @@ class MemoryAsContextTransformer(Module):
|
|
553
554
|
mem_hyper_conn = None
|
554
555
|
|
555
556
|
if layer in neural_memory_layers:
|
556
|
-
mem_hyper_conn = init_hyper_conn(
|
557
|
+
mem_hyper_conn = init_hyper_conn(add_branch_out_to_residual = not neural_mem_gate_attn_output)
|
557
558
|
|
558
559
|
mem = NeuralMemory(
|
559
560
|
dim = dim,
|
@@ -571,8 +572,8 @@ class MemoryAsContextTransformer(Module):
|
|
571
572
|
self.layers.append(ModuleList([
|
572
573
|
mem_hyper_conn,
|
573
574
|
mem,
|
574
|
-
init_hyper_conn(
|
575
|
-
init_hyper_conn(
|
575
|
+
init_hyper_conn(branch = attn),
|
576
|
+
init_hyper_conn(branch = ff)
|
576
577
|
]))
|
577
578
|
|
578
579
|
self.norm = nn.RMSNorm(dim)
|
titans_pytorch/neural_memory.py
CHANGED
@@ -39,7 +39,7 @@ o - momentum orders
|
|
39
39
|
|
40
40
|
LinearNoBias = partial(Linear, bias = False)
|
41
41
|
|
42
|
-
|
42
|
+
NeuralMemState = namedtuple('NeuralMemState', [
|
43
43
|
'seq_index',
|
44
44
|
'weights',
|
45
45
|
'cache_store_segment',
|
@@ -629,7 +629,7 @@ class NeuralMemory(Module):
|
|
629
629
|
|
630
630
|
if num_chunks == 0:
|
631
631
|
updates = rearrange_dict_values(weights, 'bh ... -> bh 1 ...')
|
632
|
-
next_store_state =
|
632
|
+
next_store_state = NeuralMemState(next_seq_len_index, weights, remainder, past_state, updates)
|
633
633
|
|
634
634
|
output = (updates, next_store_state)
|
635
635
|
|
@@ -682,13 +682,11 @@ class NeuralMemory(Module):
|
|
682
682
|
|
683
683
|
next_state = (next_last_update, next_last_momentum)
|
684
684
|
|
685
|
-
next_store_state =
|
685
|
+
next_store_state = NeuralMemState(next_seq_len_index, weights, remainder, next_state, updates)
|
686
686
|
|
687
|
-
#
|
687
|
+
# return updates to neural memory at all chunked timesteps + neural mem cache / state to be fed back
|
688
688
|
|
689
|
-
|
690
|
-
|
691
|
-
return output
|
689
|
+
return updates, next_store_state
|
692
690
|
|
693
691
|
def retrieve_memories(
|
694
692
|
self,
|
@@ -785,7 +783,7 @@ class NeuralMemory(Module):
|
|
785
783
|
self,
|
786
784
|
seq,
|
787
785
|
store_seq = None,
|
788
|
-
state:
|
786
|
+
state: NeuralMemState | None = None,
|
789
787
|
prev_weights = None
|
790
788
|
):
|
791
789
|
if seq.ndim == 2:
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: titans-pytorch
|
3
|
-
Version: 0.3.10
|
3
|
+
Version: 0.3.11
|
4
4
|
Summary: Titans
|
5
5
|
Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
|
6
6
|
Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
|
@@ -38,7 +38,7 @@ Requires-Dist: accelerated-scan>=0.2.0
|
|
38
38
|
Requires-Dist: axial-positional-embedding>=0.3.10
|
39
39
|
Requires-Dist: einops>=0.8.0
|
40
40
|
Requires-Dist: einx>=0.3.0
|
41
|
-
Requires-Dist: hyper-connections>=0.1.
|
41
|
+
Requires-Dist: hyper-connections>=0.1.10
|
42
42
|
Requires-Dist: ninja
|
43
43
|
Requires-Dist: rotary-embedding-torch
|
44
44
|
Requires-Dist: tensordict
|
@@ -0,0 +1,9 @@
|
|
1
|
+
titans_pytorch/__init__.py,sha256=AyEUlcXWpnqrvyeihRAXWIfQlzLA4NhBjOqQU4edL-4,297
|
2
|
+
titans_pytorch/associative_scan.py,sha256=esaLbukFlgvy2aqopsqBy6KEcZ64B3rsNhG8moKdPSc,5159
|
3
|
+
titans_pytorch/mac_transformer.py,sha256=EyqA53HBqvAr4UNZUs37LR6IltyEfA7FKEV54YzVYlg,24945
|
4
|
+
titans_pytorch/memory_models.py,sha256=wnH9i9kUSoVZhEWUlj8LpBSbB400L9kLt1zP8CO45QQ,5835
|
5
|
+
titans_pytorch/neural_memory.py,sha256=7YglrQaDpKS2hbpBBwx7PmqhJdjyvFEPZDt_QXmnUMM,28878
|
6
|
+
titans_pytorch-0.3.11.dist-info/METADATA,sha256=xAEvavDiCj__5Bl_5UXaG__BycdUB2DzHOud-nwsn1c,6817
|
7
|
+
titans_pytorch-0.3.11.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
8
|
+
titans_pytorch-0.3.11.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
9
|
+
titans_pytorch-0.3.11.dist-info/RECORD,,
|
@@ -1,9 +0,0 @@
|
|
1
|
-
titans_pytorch/__init__.py,sha256=AyEUlcXWpnqrvyeihRAXWIfQlzLA4NhBjOqQU4edL-4,297
|
2
|
-
titans_pytorch/associative_scan.py,sha256=esaLbukFlgvy2aqopsqBy6KEcZ64B3rsNhG8moKdPSc,5159
|
3
|
-
titans_pytorch/mac_transformer.py,sha256=5rO4GQxSyFWWEc3pc3xNyG0sK5EXE7MmxKI-_kEMl2M,24941
|
4
|
-
titans_pytorch/memory_models.py,sha256=wnH9i9kUSoVZhEWUlj8LpBSbB400L9kLt1zP8CO45QQ,5835
|
5
|
-
titans_pytorch/neural_memory.py,sha256=BeOnq41gjZeq-XJFjkHE44F9dLzsg9mm36EBYZ4wHMA,28814
|
6
|
-
titans_pytorch-0.3.10.dist-info/METADATA,sha256=sA_Dx_x5RMcpz5-vUPDHuz__tHYfKzs4W_BgY4CHPdk,6816
|
7
|
-
titans_pytorch-0.3.10.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
8
|
-
titans_pytorch-0.3.10.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
9
|
-
titans_pytorch-0.3.10.dist-info/RECORD,,
|
File without changes
|
File without changes
|