titans_pytorch-0.0.37-py3-none-any.whl → titans_pytorch-0.0.38-py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
Potentially problematic release.
This version of titans-pytorch might be problematic. Click here for more details.
- titans_pytorch/mac_transformer.py +3 -2
- titans_pytorch/titans.py +8 -5
- {titans_pytorch-0.0.37.dist-info → titans_pytorch-0.0.38.dist-info}/METADATA +1 -1
- titans_pytorch-0.0.38.dist-info/RECORD +9 -0
- titans_pytorch-0.0.37.dist-info/RECORD +0 -9
- {titans_pytorch-0.0.37.dist-info → titans_pytorch-0.0.38.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.0.37.dist-info → titans_pytorch-0.0.38.dist-info}/licenses/LICENSE +0 -0
@@ -288,7 +288,8 @@ class MemoryAsContextTransformer(Module):
|
|
288
288
|
for (attn, ff), maybe_neural_mem in zip(self.layers, self.neural_mem_layers):
|
289
289
|
|
290
290
|
if exists(maybe_neural_mem):
|
291
|
-
|
291
|
+
x = maybe_neural_mem(x)
|
292
|
+
|
292
293
|
|
293
294
|
x = attn(x)
|
294
295
|
|
@@ -300,7 +301,7 @@ class MemoryAsContextTransformer(Module):
|
|
300
301
|
|
301
302
|
x, inverse_segment = pad_and_segment_with_inverse(x, total_segment_len)
|
302
303
|
|
303
|
-
x,
|
304
|
+
x, _ = unpack(x, mem_ps, 'b * d')
|
304
305
|
|
305
306
|
x = inverse_segment(x)
|
306
307
|
|
titans_pytorch/titans.py
CHANGED
@@ -27,9 +27,7 @@ n - sequence
|
|
27
27
|
d - feature dimension
|
28
28
|
c - intra-chunk
|
29
29
|
"""
|
30
|
-
|
31
|
-
# constants
|
32
|
-
|
30
|
+
7
|
33
31
|
LinearNoBias = partial(Linear, bias = False)
|
34
32
|
|
35
33
|
# functions
|
@@ -390,7 +388,10 @@ class NeuralMemory(Module):
|
|
390
388
|
|
391
389
|
padding = next_seq_len - curtailed_seq_len
|
392
390
|
|
393
|
-
|
391
|
+
needs_pad = padding > 0
|
392
|
+
|
393
|
+
if needs_pad:
|
394
|
+
seq = pad_at_dim(seq, (0, padding), dim = 1)
|
394
395
|
|
395
396
|
# the parameters of the memory model stores the memories of the key / values
|
396
397
|
# when the MLP has only 1 weight matrix, it is equivalent to `kv` fast weight memories from linear attention literature (recall fetching of memories is q @ (kv)) / schmidhuber's paper
|
@@ -442,7 +443,9 @@ class NeuralMemory(Module):
|
|
442
443
|
empty_memory_embeds = self.init_empty_memory_embed(values.shape[0], chunk_size - 1)
|
443
444
|
values = torch.cat((empty_memory_embeds, values), dim = -2)
|
444
445
|
|
445
|
-
|
446
|
+
if needs_pad:
|
447
|
+
values = values[:, :-padding]
|
448
|
+
|
446
449
|
return values
|
447
450
|
|
448
451
|
def forward(
|
@@ -0,0 +1,9 @@
|
|
1
|
+
titans_pytorch/__init__.py,sha256=F6pV8BamKCsbJFVo5x2hw69vzfJNLy54SwIKIueMdp4,142
|
2
|
+
titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
|
3
|
+
titans_pytorch/mac_transformer.py,sha256=5koIfEulJ841FNrs6URZfW2dp9LMuHzMkaySDrlbuP0,8393
|
4
|
+
titans_pytorch/titans.py,sha256=bv2Ceq-_4nNb5FNx4hLd2inC93m5MmJxO2-Mbf6PKK4,14378
|
5
|
+
titans_pytorch/titans_attn_memory.py,sha256=Rwx_-riGeISBefZg5Kjic8jzmmRRys-u93D2Kgb7Mos,12691
|
6
|
+
titans_pytorch-0.0.38.dist-info/METADATA,sha256=L6tEQTEOXCeAU_BuRLbwUO0-gmnbJE-WQNAZ83BNCWA,3938
|
7
|
+
titans_pytorch-0.0.38.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
8
|
+
titans_pytorch-0.0.38.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
9
|
+
titans_pytorch-0.0.38.dist-info/RECORD,,
|
@@ -1,9 +0,0 @@
|
|
1
|
-
titans_pytorch/__init__.py,sha256=F6pV8BamKCsbJFVo5x2hw69vzfJNLy54SwIKIueMdp4,142
|
2
|
-
titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
|
3
|
-
titans_pytorch/mac_transformer.py,sha256=JjKGEMBit_SvyAsxq5v08614YBcLVx3OkM6pf0rADsA,8400
|
4
|
-
titans_pytorch/titans.py,sha256=ALICGfc6p2bD2QkaibyIceVLvBIRKXmDm-w7RjnVOe4,14304
|
5
|
-
titans_pytorch/titans_attn_memory.py,sha256=Rwx_-riGeISBefZg5Kjic8jzmmRRys-u93D2Kgb7Mos,12691
|
6
|
-
titans_pytorch-0.0.37.dist-info/METADATA,sha256=RNokG8101_tlR0BiF-AxqYLZpXqafMSiN1Rg_pZe2-o,3938
|
7
|
-
titans_pytorch-0.0.37.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
8
|
-
titans_pytorch-0.0.37.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
9
|
-
titans_pytorch-0.0.37.dist-info/RECORD,,
|
File without changes
|
File without changes
|