titans-pytorch 0.0.37__py3-none-any.whl → 0.0.38__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- titans_pytorch/mac_transformer.py +3 -2
- titans_pytorch/titans.py +8 -5
- {titans_pytorch-0.0.37.dist-info → titans_pytorch-0.0.38.dist-info}/METADATA +1 -1
- titans_pytorch-0.0.38.dist-info/RECORD +9 -0
- titans_pytorch-0.0.37.dist-info/RECORD +0 -9
- {titans_pytorch-0.0.37.dist-info → titans_pytorch-0.0.38.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.0.37.dist-info → titans_pytorch-0.0.38.dist-info}/licenses/LICENSE +0 -0
@@ -288,7 +288,8 @@ class MemoryAsContextTransformer(Module):
|
|
288
288
|
for (attn, ff), maybe_neural_mem in zip(self.layers, self.neural_mem_layers):
|
289
289
|
|
290
290
|
if exists(maybe_neural_mem):
|
291
|
-
|
291
|
+
x = maybe_neural_mem(x)
|
292
|
+
|
292
293
|
|
293
294
|
x = attn(x)
|
294
295
|
|
@@ -300,7 +301,7 @@ class MemoryAsContextTransformer(Module):
|
|
300
301
|
|
301
302
|
x, inverse_segment = pad_and_segment_with_inverse(x, total_segment_len)
|
302
303
|
|
303
|
-
x,
|
304
|
+
x, _ = unpack(x, mem_ps, 'b * d')
|
304
305
|
|
305
306
|
x = inverse_segment(x)
|
306
307
|
|
titans_pytorch/titans.py
CHANGED
@@ -27,9 +27,7 @@ n - sequence
|
|
27
27
|
d - feature dimension
|
28
28
|
c - intra-chunk
|
29
29
|
"""
|
30
|
-
|
31
|
-
# constants
|
32
|
-
|
30
|
+
7
|
33
31
|
LinearNoBias = partial(Linear, bias = False)
|
34
32
|
|
35
33
|
# functions
|
@@ -390,7 +388,10 @@ class NeuralMemory(Module):
|
|
390
388
|
|
391
389
|
padding = next_seq_len - curtailed_seq_len
|
392
390
|
|
393
|
-
|
391
|
+
needs_pad = padding > 0
|
392
|
+
|
393
|
+
if needs_pad:
|
394
|
+
seq = pad_at_dim(seq, (0, padding), dim = 1)
|
394
395
|
|
395
396
|
# the parameters of the memory model stores the memories of the key / values
|
396
397
|
# when the MLP has only 1 weight matrix, it is equivalent to `kv` fast weight memories from linear attention literature (recall fetching of memories is q @ (kv)) / schmidhuber's paper
|
@@ -442,7 +443,9 @@ class NeuralMemory(Module):
|
|
442
443
|
empty_memory_embeds = self.init_empty_memory_embed(values.shape[0], chunk_size - 1)
|
443
444
|
values = torch.cat((empty_memory_embeds, values), dim = -2)
|
444
445
|
|
445
|
-
|
446
|
+
if needs_pad:
|
447
|
+
values = values[:, :-padding]
|
448
|
+
|
446
449
|
return values
|
447
450
|
|
448
451
|
def forward(
|
@@ -0,0 +1,9 @@
|
|
1
|
+
titans_pytorch/__init__.py,sha256=F6pV8BamKCsbJFVo5x2hw69vzfJNLy54SwIKIueMdp4,142
|
2
|
+
titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
|
3
|
+
titans_pytorch/mac_transformer.py,sha256=5koIfEulJ841FNrs6URZfW2dp9LMuHzMkaySDrlbuP0,8393
|
4
|
+
titans_pytorch/titans.py,sha256=bv2Ceq-_4nNb5FNx4hLd2inC93m5MmJxO2-Mbf6PKK4,14378
|
5
|
+
titans_pytorch/titans_attn_memory.py,sha256=Rwx_-riGeISBefZg5Kjic8jzmmRRys-u93D2Kgb7Mos,12691
|
6
|
+
titans_pytorch-0.0.38.dist-info/METADATA,sha256=L6tEQTEOXCeAU_BuRLbwUO0-gmnbJE-WQNAZ83BNCWA,3938
|
7
|
+
titans_pytorch-0.0.38.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
8
|
+
titans_pytorch-0.0.38.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
9
|
+
titans_pytorch-0.0.38.dist-info/RECORD,,
|
@@ -1,9 +0,0 @@
|
|
1
|
-
titans_pytorch/__init__.py,sha256=F6pV8BamKCsbJFVo5x2hw69vzfJNLy54SwIKIueMdp4,142
|
2
|
-
titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
|
3
|
-
titans_pytorch/mac_transformer.py,sha256=JjKGEMBit_SvyAsxq5v08614YBcLVx3OkM6pf0rADsA,8400
|
4
|
-
titans_pytorch/titans.py,sha256=ALICGfc6p2bD2QkaibyIceVLvBIRKXmDm-w7RjnVOe4,14304
|
5
|
-
titans_pytorch/titans_attn_memory.py,sha256=Rwx_-riGeISBefZg5Kjic8jzmmRRys-u93D2Kgb7Mos,12691
|
6
|
-
titans_pytorch-0.0.37.dist-info/METADATA,sha256=RNokG8101_tlR0BiF-AxqYLZpXqafMSiN1Rg_pZe2-o,3938
|
7
|
-
titans_pytorch-0.0.37.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
8
|
-
titans_pytorch-0.0.37.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
9
|
-
titans_pytorch-0.0.37.dist-info/RECORD,,
|
File without changes
|
File without changes
|