titans-pytorch 0.0.20__tar.gz → 0.0.21__tar.gz
- {titans_pytorch-0.0.20 → titans_pytorch-0.0.21}/PKG-INFO +1 -1
- {titans_pytorch-0.0.20 → titans_pytorch-0.0.21}/pyproject.toml +1 -1
- {titans_pytorch-0.0.20 → titans_pytorch-0.0.21}/titans_pytorch/titans.py +11 -1
- {titans_pytorch-0.0.20 → titans_pytorch-0.0.21}/train.py +3 -3
- {titans_pytorch-0.0.20 → titans_pytorch-0.0.21}/.github/workflows/python-publish.yml +0 -0
- {titans_pytorch-0.0.20 → titans_pytorch-0.0.21}/.github/workflows/test.yaml +0 -0
- {titans_pytorch-0.0.20 → titans_pytorch-0.0.21}/.gitignore +0 -0
- {titans_pytorch-0.0.20 → titans_pytorch-0.0.21}/LICENSE +0 -0
- {titans_pytorch-0.0.20 → titans_pytorch-0.0.21}/README.md +0 -0
- {titans_pytorch-0.0.20 → titans_pytorch-0.0.21}/data/README.md +0 -0
- {titans_pytorch-0.0.20 → titans_pytorch-0.0.21}/data/enwik8.gz +0 -0
- {titans_pytorch-0.0.20 → titans_pytorch-0.0.21}/fig1.png +0 -0
- {titans_pytorch-0.0.20 → titans_pytorch-0.0.21}/fig2.png +0 -0
- {titans_pytorch-0.0.20 → titans_pytorch-0.0.21}/requirements.txt +0 -0
- {titans_pytorch-0.0.20 → titans_pytorch-0.0.21}/tests/test_titans.py +0 -0
- {titans_pytorch-0.0.20 → titans_pytorch-0.0.21}/titans_pytorch/__init__.py +0 -0
- {titans_pytorch-0.0.20 → titans_pytorch-0.0.21}/titans_pytorch/associative_scan.py +0 -0
- {titans_pytorch-0.0.20 → titans_pytorch-0.0.21}/titans_pytorch/titans_attn_memory.py +0 -0
titans_pytorch/titans.py

@@ -40,6 +40,9 @@ def exists(v):
 def default(v, d):
     return v if exists(v) else d
 
+def identity(t):
+    return t
+
 def round_down_multiple(seq, mult):
     return seq // mult * mult
 
@@ -97,6 +100,9 @@ class MemoryMLP(Module):
 
 # main neural memory
 
+def default_adaptive_step_transform(adaptive_step):
+    return torch.exp(adaptive_step.sigmoid() * -15) # from 1. - 1e-7
+
 def default_loss_fn(pred, target):
     return (pred - target).pow(2).mean(dim = -1).sum()
 
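The new default transform bounds each token's learned update step: sigmoid squashes the raw projection into (0, 1), and `exp(x * -15)` maps that into roughly (3e-7, 1). A quick standalone sanity check of that range (not part of the package):

```python
import torch

# Same expression as default_adaptive_step_transform above.
def default_adaptive_step_transform(adaptive_step):
    return torch.exp(adaptive_step.sigmoid() * -15)

raw = torch.randn(8)  # stand-in for the to_adaptive_step projection output
step = default_adaptive_step_transform(raw)

# exp of a value in (-15, 0) lies in (exp(-15), 1) ~ (3e-7, 1)
assert ((step > 0.) & (step < 1.)).all()
```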
@@ -109,6 +115,7 @@ class NeuralMemory(Module):
         heads = 1,
         model: Module | None = None,
         store_memory_loss_fn: Callable = default_loss_fn,
+        adaptive_step_transform: Callable = default_adaptive_step_transform,
         pre_rmsnorm = True,
         post_rmsnorm = True,
         max_grad_norm: float | None = None,
@@ -188,6 +195,8 @@ class NeuralMemory(Module):
             Rearrange('b n h -> (b h) n')
         )
 
+        self.adaptive_step_transform = adaptive_step_transform
+
         # allow for softclamp the gradient norms for storing memories
 
         self.max_grad_norm = max_grad_norm
@@ -242,7 +251,8 @@ class NeuralMemory(Module):
 
         # pack batch and sequence dimension
 
-        adaptive_lr =
+        adaptive_lr = self.to_adaptive_step(seq)
+        adaptive_lr = self.adaptive_step_transform(adaptive_lr)
 
         adaptive_momentum = self.to_momentum(seq).sigmoid()
         decay_factor = self.to_decay_factor(seq).sigmoid()
train.py

@@ -62,11 +62,11 @@ def decode_tokens(tokens):
 
 titans_neural_memory = NeuralMemory(
     dim = 384,
-    chunk_size =
+    chunk_size = 4,
     pre_rmsnorm = True,
     post_rmsnorm = True,
-    dim_head =
-    heads =
+    dim_head = 64,
+    heads = 4,
     use_accelerated_scan = True,
     default_mlp_kwargs = dict(
         depth = NEURAL_MEMORY_DEPTH
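One practical consequence of the smaller `chunk_size` in train.py: the `round_down_multiple` helper from titans.py (first hunk) suggests sequences get truncated to a multiple of the chunk size, so a smaller chunk wastes fewer trailing tokens. An illustrative check (the chunk sizes compared here are chosen for the example, not taken from the diff):

```python
# Helper copied from titans.py
def round_down_multiple(seq, mult):
    return seq // mult * mult

# A 1000-token sequence keeps 960 tokens under a chunk size of 64,
# but all 1000 under the new chunk size of 4.
assert round_down_multiple(1000, 64) == 960
assert round_down_multiple(1000, 4) == 1000
```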