titans_pytorch-0.0.26-py3-none-any.whl → titans_pytorch-0.0.27-py3-none-any.whl
- titans_pytorch/mac_transformer.py +23 -2
- {titans_pytorch-0.0.26.dist-info → titans_pytorch-0.0.27.dist-info}/METADATA +1 -1
- {titans_pytorch-0.0.26.dist-info → titans_pytorch-0.0.27.dist-info}/RECORD +5 -5
- {titans_pytorch-0.0.26.dist-info → titans_pytorch-0.0.27.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.0.26.dist-info → titans_pytorch-0.0.27.dist-info}/licenses/LICENSE +0 -0
titans_pytorch/mac_transformer.py

@@ -3,10 +3,11 @@ import math
 from functools import partial

 import torch
-from torch import nn
+from torch import nn, cat
 import torch.nn.functional as F
 from torch.nn import Module, ModuleList, Linear

+from einops import repeat
 from einops.layers.torch import Rearrange

 from hyper_connections import get_init_and_expand_reduce_stream_functions
@@ -48,6 +49,7 @@ class SegmentedAttention(Module):
         self,
         dim,
         segment_len,
+        num_persist_mem_tokens,
         dim_head = 64,
         heads = 8,
     ):
@@ -67,6 +69,7 @@ class SegmentedAttention(Module):
         self.segment_seq = Rearrange('b (n w) d -> (b n) w d', n = segment_len)
         self.merge_seq_back = Rearrange('(b n) w d -> b (n w) d', n = segment_len)

+        self.persistent_memory = nn.Parameter(torch.zeros(2, heads, num_persist_mem_tokens, dim_head))

     def forward(self, seq):
         batch, seq_len = seq.shape[:2]
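The leading dimension of 2 in this new parameter packs the key-memories and value-memories into a single tensor; iterating over that dimension (as the forward pass below does) splits them back apart. A minimal sketch, with hypothetical sizes chosen only for illustration:

```python
import torch
from torch import nn

# hypothetical sizes, for illustration only
heads, num_persist_mem_tokens, dim_head = 8, 16, 64

# dim 0 of size 2 packs key-memories and value-memories into one learned parameter
persistent_memory = nn.Parameter(torch.zeros(2, heads, num_persist_mem_tokens, dim_head))

# iterating over dim 0 yields the two halves
pmk, pmv = persistent_memory
print(pmk.shape)  # torch.Size([8, 16, 64])
```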
@@ -92,6 +95,15 @@ class SegmentedAttention(Module):
         q, k, v = self.to_qkv(seq).chunk(3, dim = -1)
         q, k, v = map(self.split_heads, (q, k, v))

+        # take care of persistent memory key / values
+
+        pmk, pmv = tuple(repeat(t, 'h n d -> b h n d', b = seq.shape[0]) for t in self.persistent_memory)
+
+        k = cat((pmk, k), dim = -2)
+        v = cat((pmv, v), dim = -2)
+
+        # sdpa
+
         out = F.scaled_dot_product_attention(q, k, v, is_causal = True)

         out = self.merge_heads(out)
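What this hunk does, as a standalone sketch (the shapes below are hypothetical, not taken from the wheel): the learned memories are broadcast across the batch with einops `repeat`, then prepended to the projected keys and values along the sequence axis, so every query position can also attend to a small set of input-independent tokens:

```python
import torch
import torch.nn.functional as F
from torch import nn, cat
from einops import repeat

batch, heads, seq_len, dim_head = 2, 8, 16, 64  # hypothetical sizes
num_persist_mem_tokens = 4

persistent_memory = nn.Parameter(torch.zeros(2, heads, num_persist_mem_tokens, dim_head))

q, k, v = (torch.randn(batch, heads, seq_len, dim_head) for _ in range(3))

# broadcast the learned memories over the batch, then prepend along the key / value length
pmk, pmv = tuple(repeat(t, 'h n d -> b h n d', b = batch) for t in persistent_memory)
k = cat((pmk, k), dim = -2)  # (batch, heads, num_persist_mem_tokens + seq_len, dim_head)
v = cat((pmv, v), dim = -2)

out = F.scaled_dot_product_attention(q, k, v, is_causal = True)
print(out.shape)  # torch.Size([2, 8, 16, 64]) - queries keep their original length
```

One thing worth double-checking against the intended semantics: when `is_causal = True` and the keys are longer than the queries, PyTorch aligns the causal mask to the top-left corner, which affects how many of the prepended memory tokens early query positions can actually see.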
@@ -113,6 +125,7 @@ class MemoryAsContextTransformer(Module):
         dim,
         depth,
         segment_len,
+        num_persist_mem_tokens,
         dim_head = 64,
         heads = 8,
         ff_mult = 4,
@@ -127,7 +140,14 @@ class MemoryAsContextTransformer(Module):
         self.layers = ModuleList([])

         for _ in range(depth):
-            attn = SegmentedAttention(dim = dim, dim_head = dim_head, heads = heads, segment_len = segment_len)
+            attn = SegmentedAttention(
+                dim = dim,
+                dim_head = dim_head,
+                heads = heads,
+                segment_len = segment_len,
+                num_persist_mem_tokens = num_persist_mem_tokens
+            )
+
             ff = FeedForward(dim = dim, mult = ff_mult)

             self.layers.append(ModuleList([
@@ -162,6 +182,7 @@ if __name__ == '__main__':
         num_tokens = 256,
         dim = 256,
         depth = 2,
+        num_persist_mem_tokens = 16,
         segment_len = 128,
     )

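Putting the new argument together end to end, assuming the constructor shown above; the forward-pass interface (token ids in, logits out) is an assumption, since this diff does not show it:

```python
import torch
from titans_pytorch.mac_transformer import MemoryAsContextTransformer

transformer = MemoryAsContextTransformer(
    num_tokens = 256,
    dim = 256,
    depth = 2,
    num_persist_mem_tokens = 16,
    segment_len = 128,
)

ids = torch.randint(0, 256, (1, 512))  # sequence length assumed to be a multiple of segment_len
logits = transformer(ids)              # assumed to return per-token logits over num_tokens
```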
{titans_pytorch-0.0.26.dist-info → titans_pytorch-0.0.27.dist-info}/RECORD

@@ -1,9 +1,9 @@
 titans_pytorch/__init__.py,sha256=nB0873FZ_OyCda3qFeWTdpO4LbbrXDEvtAefVLzh6g0,71
 titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/mac_transformer.py,sha256=
+titans_pytorch/mac_transformer.py,sha256=4xUSfGDdVsR-WmeXX7yXoFfybROvNCjOxL_EHDJ_Wlk,4681
 titans_pytorch/titans.py,sha256=Kx_tl_QkeDccvkMwPZ0xQ_saYjZfbKzDNPTTSHNWYcA,14304
 titans_pytorch/titans_attn_memory.py,sha256=Rwx_-riGeISBefZg5Kjic8jzmmRRys-u93D2Kgb7Mos,12691
-titans_pytorch-0.0.
-titans_pytorch-0.0.
-titans_pytorch-0.0.
-titans_pytorch-0.0.
+titans_pytorch-0.0.27.dist-info/METADATA,sha256=qJp1IDbphEUfW7EyNvQ7RfmHuvB7SH5h_tlnCVwV4EY,3851
+titans_pytorch-0.0.27.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.0.27.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.0.27.dist-info/RECORD,,