titans-pytorch 0.0.39__py3-none-any.whl → 0.0.40__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- titans_pytorch/mac_transformer.py +32 -3
- {titans_pytorch-0.0.39.dist-info → titans_pytorch-0.0.40.dist-info}/METADATA +1 -1
- {titans_pytorch-0.0.39.dist-info → titans_pytorch-0.0.40.dist-info}/RECORD +5 -5
- {titans_pytorch-0.0.39.dist-info → titans_pytorch-0.0.40.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.0.39.dist-info → titans_pytorch-0.0.40.dist-info}/licenses/LICENSE +0 -0
titans_pytorch/mac_transformer.py

@@ -94,6 +94,7 @@ class SegmentedAttention(Module):
        num_longterm_mem_tokens = 0,
        dim_head = 64,
        heads = 8,
+        accept_value_residual = False,
        attend_kwargs: dict = dict()
    ):
        super().__init__()
@@ -108,6 +109,12 @@ class SegmentedAttention(Module):
        self.to_qkv = LinearNoBias(dim, dim_inner * 3)
        self.to_out = LinearNoBias(dim_inner, dim)

+        self.to_learned_v_mix = nn.Sequential(
+            nn.Linear(dim, heads),
+            Rearrange('b n h -> b h n 1'),
+            nn.Sigmoid()
+        ) if accept_value_residual else None
+
        self.segment_len = segment_len
        self.num_longterm_mem_tokens = num_longterm_mem_tokens
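For orientation, a minimal sketch (not taken from the package) of what the new to_learned_v_mix module produces: a per-token, per-head sigmoid gate whose trailing singleton axis broadcasts over the head dimension when values are mixed. The sizes below are assumptions chosen only for illustration.

    import torch
    from torch import nn
    from einops.layers.torch import Rearrange

    dim, heads, dim_head = 384, 8, 64           # assumed sizes, illustration only
    batch, seq_len = 2, 256

    to_learned_v_mix = nn.Sequential(
        nn.Linear(dim, heads),
        Rearrange('b n h -> b h n 1'),
        nn.Sigmoid()
    )

    seq = torch.randn(batch, seq_len, dim)
    mix = to_learned_v_mix(seq)                 # shape (2, 8, 256, 1), entries in (0, 1)

    v = torch.randn(batch, heads, seq_len, dim_head)
    value_residual = torch.randn_like(v)        # stand-in for another layer's values
    mixed = v.lerp(value_residual, mix)         # gate broadcasts over dim_head -> (2, 8, 256, 64)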
@@ -118,7 +125,13 @@ class SegmentedAttention(Module):

        self.persistent_memory = nn.Parameter(torch.zeros(2, heads, num_persist_mem_tokens, dim_head))

-    def forward(self, seq):
+    def forward(
+        self,
+        seq,
+        value_residual = None
+    ):
+        assert not (exists(value_residual) ^ exists(self.to_learned_v_mix))
+
        segment_len, num_longterm_mem_tokens = self.segment_len, self.num_longterm_mem_tokens
        total_segment_len = segment_len + num_longterm_mem_tokens
@@ -136,6 +149,14 @@ class SegmentedAttention(Module):
        q, k, v = self.to_qkv(seq).chunk(3, dim = -1)
        q, k, v = map(self.split_heads, (q, k, v))

+        # value residual
+
+        orig_v = v
+
+        if exists(self.to_learned_v_mix):
+            mix = self.to_learned_v_mix(seq)
+            v = v.lerp(value_residual, mix)
+
        # take care of persistent memory key / values

        pmk, pmv = tuple(repeat(t, 'h n d -> b h n d', b = seq.shape[0]) for t in self.persistent_memory)
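The mixing step above uses Tensor.lerp, which interpolates elementwise between this layer's freshly projected values and the value_residual passed in by the caller, gated by the learned mix. A small check of the closed form (shapes are illustrative assumptions):

    import torch

    v = torch.randn(2, 8, 16, 64)            # this layer's values
    value_residual = torch.randn_like(v)     # values cached from an earlier layer
    mix = torch.rand(2, 8, 16, 1)            # sigmoid gate, broadcast over dim_head

    out = v.lerp(value_residual, mix)
    # lerp(a, b, w) == a + w * (b - a): mix = 0 keeps v, mix = 1 copies value_residual
    assert torch.allclose(out, v + mix * (value_residual - v), atol = 1e-6)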
@@ -159,7 +180,7 @@ class SegmentedAttention(Module):

        out = inverse_segment(out)

-        return out
+        return out, orig_v

# MAC transformer
@@ -210,6 +231,7 @@ class MemoryAsContextTransformer(Module):
        assert not (num_longterm_mem_tokens > 0 and len(neural_memory_layers) == 0), 'empty `neural_memory_layers` when longterm memory tokens are present'

        for layer in layers:
+            is_first = layer == 1

            # neural memory
@@ -235,6 +257,7 @@ class MemoryAsContextTransformer(Module):
                dim_head = dim_head,
                heads = heads,
                segment_len = segment_len,
+                accept_value_residual = not is_first,
                num_longterm_mem_tokens = num_longterm_mem_tokens,
                num_persist_mem_tokens = num_persist_mem_tokens
            )
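With the 1-based layer index, only the very first attention block is built without the mixing branch; every deeper block gets accept_value_residual = True and therefore expects a value_residual at call time (the assert added to forward enforces that pairing). A trivial sketch of the flag per layer, assuming depth = 4:

    depth = 4

    for layer in range(1, depth + 1):        # 1-based, matching `layer == 1` in the diff
        is_first = layer == 1
        print(layer, 'accept_value_residual =', not is_first)

    # 1 accept_value_residual = False
    # 2 accept_value_residual = True
    # 3 accept_value_residual = True
    # 4 accept_value_residual = True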
@@ -285,6 +308,10 @@ class MemoryAsContextTransformer(Module):
        pos_emb = self.axial_pos_emb((windows, total_segment_len), flatten = True)
        x = x + pos_emb[:x.shape[-2]]

+        # value residual
+
+        value_residual = None
+
        # expand and reduce streams for hyper connections

        x = self.expand_streams(x)
@@ -295,7 +322,9 @@ class MemoryAsContextTransformer(Module):
                x = maybe_neural_mem(x)


-            x = attn(x)
+            x, values = attn(x, value_residual = value_residual)
+
+            value_residual = default(value_residual, values)

            x = ff(x)
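Putting the two sides together, the forward loop caches the values returned by the first attention block and feeds them, unchanged, to every later block. A self-contained sketch of that threading, with a toy module standing in for SegmentedAttention and with exists/default assumed to be the usual None-checking helpers used by the library:

    import torch
    from torch import nn

    def exists(v):
        return v is not None

    def default(v, d):
        # assumed behaviour of the helper: keep v if it is set, else fall back to d
        return v if exists(v) else d

    class ToyAttention(nn.Module):
        # stand-in for SegmentedAttention, reduced to the value-residual plumbing
        def __init__(self, dim, accept_value_residual = False):
            super().__init__()
            self.to_v = nn.Linear(dim, dim)
            self.to_mix = nn.Linear(dim, 1) if accept_value_residual else None

        def forward(self, x, value_residual = None):
            assert not (exists(value_residual) ^ exists(self.to_mix))
            v = self.to_v(x)
            orig_v = v                                  # returned so the caller can cache it
            if exists(self.to_mix):
                mix = self.to_mix(x).sigmoid()
                v = v.lerp(value_residual, mix)
            return x + v, orig_v

    dim = 32
    blocks = nn.ModuleList([ToyAttention(dim, accept_value_residual = i > 0) for i in range(3)])

    x = torch.randn(2, 16, dim)
    value_residual = None

    for attn in blocks:
        x, values = attn(x, value_residual = value_residual)
        value_residual = default(value_residual, values)  # stays pinned to the first block's values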
{titans_pytorch-0.0.39.dist-info → titans_pytorch-0.0.40.dist-info}/RECORD

@@ -1,9 +1,9 @@
titans_pytorch/__init__.py,sha256=F6pV8BamKCsbJFVo5x2hw69vzfJNLy54SwIKIueMdp4,142
titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/mac_transformer.py,sha256=
+titans_pytorch/mac_transformer.py,sha256=y9sruSvGCEL4flu_RW7bCdvIe-S9dEdGacbmPYL1kqA,9311
titans_pytorch/titans.py,sha256=bv2Ceq-_4nNb5FNx4hLd2inC93m5MmJxO2-Mbf6PKK4,14378
titans_pytorch/titans_attn_memory.py,sha256=Rwx_-riGeISBefZg5Kjic8jzmmRRys-u93D2Kgb7Mos,12691
-titans_pytorch-0.0.
-titans_pytorch-0.0.
-titans_pytorch-0.0.
-titans_pytorch-0.0.
+titans_pytorch-0.0.40.dist-info/METADATA,sha256=JCJ5aG9_-rVUErW6u-DXkJtVQ52Bf3XQDN3puirXAXo,3968
+titans_pytorch-0.0.40.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.0.40.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.0.40.dist-info/RECORD,,
{titans_pytorch-0.0.39.dist-info → titans_pytorch-0.0.40.dist-info}/WHEEL: File without changes

{titans_pytorch-0.0.39.dist-info → titans_pytorch-0.0.40.dist-info}/licenses/LICENSE: File without changes