titans-pytorch 0.0.17__tar.gz → 0.0.19__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {titans_pytorch-0.0.17 → titans_pytorch-0.0.19}/PKG-INFO +1 -1
- {titans_pytorch-0.0.17 → titans_pytorch-0.0.19}/pyproject.toml +1 -1
- {titans_pytorch-0.0.17 → titans_pytorch-0.0.19}/tests/test_titans.py +6 -2
- {titans_pytorch-0.0.17 → titans_pytorch-0.0.19}/titans_pytorch/titans.py +15 -6
- {titans_pytorch-0.0.17 → titans_pytorch-0.0.19}/train.py +0 -1
- {titans_pytorch-0.0.17 → titans_pytorch-0.0.19}/.github/workflows/python-publish.yml +0 -0
- {titans_pytorch-0.0.17 → titans_pytorch-0.0.19}/.github/workflows/test.yaml +0 -0
- {titans_pytorch-0.0.17 → titans_pytorch-0.0.19}/.gitignore +0 -0
- {titans_pytorch-0.0.17 → titans_pytorch-0.0.19}/LICENSE +0 -0
- {titans_pytorch-0.0.17 → titans_pytorch-0.0.19}/README.md +0 -0
- {titans_pytorch-0.0.17 → titans_pytorch-0.0.19}/data/README.md +0 -0
- {titans_pytorch-0.0.17 → titans_pytorch-0.0.19}/data/enwik8.gz +0 -0
- {titans_pytorch-0.0.17 → titans_pytorch-0.0.19}/fig1.png +0 -0
- {titans_pytorch-0.0.17 → titans_pytorch-0.0.19}/fig2.png +0 -0
- {titans_pytorch-0.0.17 → titans_pytorch-0.0.19}/requirements.txt +0 -0
- {titans_pytorch-0.0.17 → titans_pytorch-0.0.19}/titans_pytorch/__init__.py +0 -0
- {titans_pytorch-0.0.17 → titans_pytorch-0.0.19}/titans_pytorch/associative_scan.py +0 -0
- {titans_pytorch-0.0.17 → titans_pytorch-0.0.19}/titans_pytorch/titans_attn_memory.py +0 -0
tests/test_titans.py

```diff
@@ -1,7 +1,11 @@
 import torch
 import pytest
 
-def test_titans():
+@pytest.mark.parametrize('seq_len', (32, 1024, 77))
+def test_titans(
+    seq_len
+):
+
     from titans_pytorch import NeuralMemory
 
     mem = NeuralMemory(
@@ -9,7 +13,7 @@ def test_titans():
         chunk_size = 64,
     )
 
-    seq = torch.randn(2, 1024, 384)
+    seq = torch.randn(2, seq_len, 384)
     retrieved = mem(seq)
 
     assert seq.shape == retrieved.shape
```
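The three parametrized lengths are not arbitrary: with `chunk_size = 64`, 32 exercises the new shorter-than-one-chunk path (see the last hunk below), 1024 an exact chunk multiple, and 77 the ragged pad-and-trim retrieval path. A minimal standalone sketch of the ragged case; the constructor arguments mirror the test, with `dim = 384` inferred from the 384-wide test tensors:

```python
import torch
from titans_pytorch import NeuralMemory

mem = NeuralMemory(
    dim = 384,        # inferred from the 384-dim test tensors
    chunk_size = 64,
)

# 77 is not a multiple of 64, so retrieval right-pads to a chunk
# boundary and trims back after prepending empty-memory embeddings
seq = torch.randn(2, 77, 384)
retrieved = mem(seq)
assert retrieved.shape == seq.shape
```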
titans_pytorch/titans.py

```diff
@@ -17,7 +17,7 @@ from titans_pytorch.associative_scan import (
 )
 
 import einx
-from einops import rearrange, pack, unpack
+from einops import rearrange, repeat, pack, unpack
 from einops.layers.torch import Rearrange, Reduce
 
 """
```
```diff
@@ -152,6 +152,11 @@ class NeuralMemory(Module):
         self.to_keys_values = LinearNoBias(dim, dim_inner * 2)
         self.store_memory_loss_fn = store_memory_loss_fn
 
+        # empty memory embed
+
+        self.empty_memory_embed = nn.Parameter(torch.zeros(dim))
+        nn.init.normal_(self.empty_memory_embed, std = 0.02)
+
         # learned adaptive learning rate and momentum
         # todo - explore mlp layerwise learned lr / momentum
 
```
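The new parameter follows the usual learned-embedding recipe: allocate at zero, then overwrite with a small Gaussian init (std 0.02, the common transformer embedding scale). A standalone sketch of the same pattern, also showing the broadcast that `init_empty_memory_embed` (next hunk) performs; the `dim`, `batch`, and `seq_len` values here are illustrative only:

```python
import torch
from torch import nn
from einops import repeat

dim = 384  # illustrative width, matching the test

# learned "empty memory" token: zero-allocated, then given a
# small gaussian init (std = 0.02), as in the hunk above
empty_memory_embed = nn.Parameter(torch.zeros(dim))
nn.init.normal_(empty_memory_embed, std = 0.02)

# broadcast the single d-dim vector to a (batch, seq_len, dim) block,
# mirroring init_empty_memory_embed in the next hunk
batch, seq_len = 2, 63
embeds = repeat(empty_memory_embed, 'd -> b n d', b = batch, n = seq_len)
assert embeds.shape == (batch, seq_len, dim)
```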
```diff
@@ -187,6 +192,9 @@ class NeuralMemory(Module):
 
         return init_weights, init_momentum
 
+    def init_empty_memory_embed(self, batch, seq_len):
+        return repeat(self.empty_memory_embed, 'd -> b n d', b = batch, n = seq_len)
+
     def store_memories(
         self,
         seq,
```
```diff
@@ -269,7 +277,7 @@ class NeuralMemory(Module):
         gates = F.pad(gates, (0, next_power_two_seq_len - seq_len))
         inputs = F.pad(inputs, (0, next_power_two_seq_len - seq_len))
 
-        outputs = scan(gates, inputs)
+        outputs = scan(gates.contiguous(), inputs.contiguous())
 
         outputs = outputs[..., :seq_len]
         outputs = rearrange(outputs, 'b d n -> b n d')
```
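Calling `.contiguous()` before the scan is a defensive layout fix, not a logic change: axis-permuting ops like `rearrange` return strided views, and custom kernels such as a Triton-backed associative scan generally assume densely packed inputs. A generic PyTorch illustration (not specific to this repo's `scan`):

```python
import torch
from einops import rearrange

t = torch.randn(2, 8, 4)

# rearranging axes returns a strided view, not a fresh tensor
view = rearrange(t, 'b n d -> b d n')
assert not view.is_contiguous()

# .contiguous() copies into a fresh, densely laid-out tensor,
# which is what custom kernels typically require
flat = view.contiguous()
assert flat.is_contiguous()
assert torch.equal(flat, view)  # same values, new memory layout
```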
```diff
@@ -372,11 +380,12 @@ class NeuralMemory(Module):
 
         values = self.post_rmsnorm(values)
 
-        # restore
+        # restore, pad with empty memory embed
 
-
-        values = values
+        empty_memory_embeds = self.init_empty_memory_embed(values.shape[0], chunk_size - 1)
+        values = torch.cat((empty_memory_embeds, values), dim = -2)
 
+        values = values[:, :-padding]
         return values
 
     def forward(
```
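Positions before the first complete chunk have no stored memory to retrieve from, so the retrieved values are shifted into place by prepending `chunk_size - 1` copies of the learned empty-memory embedding, then the earlier right-padding is trimmed. A shape-only sketch; the `curtailed` / `padding` arithmetic is an assumption inferred from the test's `seq.shape == retrieved.shape` invariant, and zeros stand in for the learned embedding:

```python
import torch

batch, chunk_size, dim = 2, 64, 384
seq_len = 77

# assumption: retrieval covers the seq_len - (chunk_size - 1) positions
# that have a full chunk behind them, right-padded to a chunk multiple
curtailed = seq_len - (chunk_size - 1)     # 14
padding = -curtailed % chunk_size          # 50
values = torch.randn(batch, curtailed + padding, dim)

# prepend chunk_size - 1 empty-memory embeddings (zeros as a stand-in),
# then trim the right padding
empty = torch.zeros(batch, chunk_size - 1, dim)
values = torch.cat((empty, values), dim = -2)
values = values[:, :-padding]

assert values.shape == (batch, seq_len, dim)   # matches the input length
```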
```diff
@@ -389,7 +398,7 @@ class NeuralMemory(Module):
         batch, seq_len = seq.shape[:2]
 
         if seq_len < self.chunk_size:
-            return
+            return self.init_empty_memory_embed(batch, seq_len)
 
         if exists(past_state):
             past_state = tuple(TensorDict(d) for d in past_state)
```
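Previously a sequence shorter than one chunk made `forward` return `None`, leaving callers to special-case it; it now returns the learned empty-memory embedding broadcast to the input shape, so the output is always a tensor. A minimal sketch (constructor arguments as in the test):

```python
import torch
from titans_pytorch import NeuralMemory

mem = NeuralMemory(dim = 384, chunk_size = 64)

seq = torch.randn(2, 16, 384)   # 16 < chunk_size

out = mem(seq)                  # 0.0.17 returned None here
assert out.shape == seq.shape   # now: empty-memory embeds, same shape
```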