titans-pytorch 0.0.21__tar.gz → 0.0.22__tar.gz
Sign up to get free protection for your applications and to get access to all the features.
Potentially problematic release.
This version of titans-pytorch might be problematic. Click here for more details.
- {titans_pytorch-0.0.21 → titans_pytorch-0.0.22}/PKG-INFO +1 -1
- {titans_pytorch-0.0.21 → titans_pytorch-0.0.22}/pyproject.toml +1 -1
- {titans_pytorch-0.0.21 → titans_pytorch-0.0.22}/tests/test_titans.py +4 -1
- {titans_pytorch-0.0.21 → titans_pytorch-0.0.22}/titans_pytorch/titans.py +2 -2
- {titans_pytorch-0.0.21 → titans_pytorch-0.0.22}/train.py +1 -0
- {titans_pytorch-0.0.21 → titans_pytorch-0.0.22}/.github/workflows/python-publish.yml +0 -0
- {titans_pytorch-0.0.21 → titans_pytorch-0.0.22}/.github/workflows/test.yaml +0 -0
- {titans_pytorch-0.0.21 → titans_pytorch-0.0.22}/.gitignore +0 -0
- {titans_pytorch-0.0.21 → titans_pytorch-0.0.22}/LICENSE +0 -0
- {titans_pytorch-0.0.21 → titans_pytorch-0.0.22}/README.md +0 -0
- {titans_pytorch-0.0.21 → titans_pytorch-0.0.22}/data/README.md +0 -0
- {titans_pytorch-0.0.21 → titans_pytorch-0.0.22}/data/enwik8.gz +0 -0
- {titans_pytorch-0.0.21 → titans_pytorch-0.0.22}/fig1.png +0 -0
- {titans_pytorch-0.0.21 → titans_pytorch-0.0.22}/fig2.png +0 -0
- {titans_pytorch-0.0.21 → titans_pytorch-0.0.22}/requirements.txt +0 -0
- {titans_pytorch-0.0.21 → titans_pytorch-0.0.22}/titans_pytorch/__init__.py +0 -0
- {titans_pytorch-0.0.21 → titans_pytorch-0.0.22}/titans_pytorch/associative_scan.py +0 -0
- {titans_pytorch-0.0.21 → titans_pytorch-0.0.22}/titans_pytorch/titans_attn_memory.py +0 -0
@@ -2,8 +2,10 @@ import torch
 import pytest
 
 @pytest.mark.parametrize('seq_len', (32, 1024, 77))
+@pytest.mark.parametrize('max_grad_norm', (None, 2.))
 def test_titans(
-    seq_len
+    seq_len,
+    max_grad_norm
 ):
 
     from titans_pytorch import NeuralMemory
@@ -11,6 +13,7 @@ def test_titans(
     mem = NeuralMemory(
         dim = 384,
         chunk_size = 64,
+        max_grad_norm = max_grad_norm
     )
 
     seq = torch.randn(2, seq_len, 384)
@@ -100,8 +100,8 @@ class MemoryMLP(Module):
 
 # main neural memory
 
-def default_adaptive_step_transform(adaptive_step):
-    return
+def default_adaptive_step_transform(adaptive_step, max_lr = 1e-1):
+    return adaptive_step.sigmoid() * max_lr
 
 def default_loss_fn(pred, target):
     return (pred - target).pow(2).mean(dim = -1).sum()
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|