titans-pytorch 0.0.19__tar.gz → 0.0.21__tar.gz
- {titans_pytorch-0.0.19 → titans_pytorch-0.0.21}/PKG-INFO +1 -1
- {titans_pytorch-0.0.19 → titans_pytorch-0.0.21}/pyproject.toml +1 -1
- {titans_pytorch-0.0.19 → titans_pytorch-0.0.21}/titans_pytorch/titans.py +36 -1
- {titans_pytorch-0.0.19 → titans_pytorch-0.0.21}/train.py +3 -3
- {titans_pytorch-0.0.19 → titans_pytorch-0.0.21}/.github/workflows/python-publish.yml +0 -0
- {titans_pytorch-0.0.19 → titans_pytorch-0.0.21}/.github/workflows/test.yaml +0 -0
- {titans_pytorch-0.0.19 → titans_pytorch-0.0.21}/.gitignore +0 -0
- {titans_pytorch-0.0.19 → titans_pytorch-0.0.21}/LICENSE +0 -0
- {titans_pytorch-0.0.19 → titans_pytorch-0.0.21}/README.md +0 -0
- {titans_pytorch-0.0.19 → titans_pytorch-0.0.21}/data/README.md +0 -0
- {titans_pytorch-0.0.19 → titans_pytorch-0.0.21}/data/enwik8.gz +0 -0
- {titans_pytorch-0.0.19 → titans_pytorch-0.0.21}/fig1.png +0 -0
- {titans_pytorch-0.0.19 → titans_pytorch-0.0.21}/fig2.png +0 -0
- {titans_pytorch-0.0.19 → titans_pytorch-0.0.21}/requirements.txt +0 -0
- {titans_pytorch-0.0.19 → titans_pytorch-0.0.21}/tests/test_titans.py +0 -0
- {titans_pytorch-0.0.19 → titans_pytorch-0.0.21}/titans_pytorch/__init__.py +0 -0
- {titans_pytorch-0.0.19 → titans_pytorch-0.0.21}/titans_pytorch/associative_scan.py +0 -0
- {titans_pytorch-0.0.19 → titans_pytorch-0.0.21}/titans_pytorch/titans_attn_memory.py +0 -0
titans_pytorch/titans.py +36 -1

@@ -40,6 +40,9 @@ def exists(v):
 def default(v, d):
     return v if exists(v) else d
 
+def identity(t):
+    return t
+
 def round_down_multiple(seq, mult):
     return seq // mult * mult
 
@@ -55,6 +58,21 @@ def pack_one_with_inverse(t, pattern):
 
     return packed, inverse
 
+# softclamping gradients
+
+def softclamp_max(t, max_value):
+    half_max_value = max_value / 2
+    return ((t / half_max_value).tanh() * half_max_value) + half_max_value
+
+def softclamp_grad_norm(t, max_value):
+    t, inverse = pack_one_with_inverse(t, 'bn *')
+
+    norm = t.norm(dim = -1, keepdim = True)
+    clamped_norm = softclamp_max(norm, max_value)
+
+    t = t * (clamped_norm / norm)
+    return inverse(t)
+
 # classes
 
 class MemoryMLP(Module):
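For orientation (not part of the diff): `softclamp_max` squashes any real input smoothly into the open interval `(0, max_value)` via a shifted, scaled tanh, so an input of 0 lands at `max_value / 2` and large inputs saturate just below `max_value`; `softclamp_grad_norm` applies that curve to the last-dimension norm of each packed row and rescales the row accordingly. A minimal standalone sketch of the curve, re-declaring the helper with plain torch so it runs on its own:

```python
# Sketch only: re-declares softclamp_max as in the diff so the example is self-contained.
import torch

def softclamp_max(t, max_value):
    half_max_value = max_value / 2
    return ((t / half_max_value).tanh() * half_max_value) + half_max_value

x = torch.tensor([-100., 0., 1., 10., 100.])
print(softclamp_max(x, max_value = 10.))
# ~[0.00, 5.00, 5.99, 9.82, 10.00] - bounded in (0, 10), with 0 mapped to the midpoint 5
```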
@@ -82,6 +100,9 @@ class MemoryMLP(Module):
 
 # main neural memory
 
+def default_adaptive_step_transform(adaptive_step):
+    return torch.exp(adaptive_step.sigmoid() * -15) # from 1. - 1e-7
+
 def default_loss_fn(pred, target):
     return (pred - target).pow(2).mean(dim = -1).sum()
 
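A quick range check (not from the diff): `sigmoid` maps the raw projection into `(0, 1)`, so the exponent lies in `(-15, 0)` and the default transform bounds the learned step size to roughly `(3.1e-7, 1)`:

```python
# Sketch: evaluates the new default transform on a few raw values to show its bounds.
import torch

raw = torch.tensor([-20., 0., 20.])
step = torch.exp(raw.sigmoid() * -15)
print(step)  # ~[1.0, 5.5e-4, 3.1e-7]
```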
@@ -94,8 +115,10 @@ class NeuralMemory(Module):
         heads = 1,
         model: Module | None = None,
         store_memory_loss_fn: Callable = default_loss_fn,
+        adaptive_step_transform: Callable = default_adaptive_step_transform,
         pre_rmsnorm = True,
         post_rmsnorm = True,
+        max_grad_norm: float | None = None,
         use_accelerated_scan = False,
         default_mlp_kwargs: dict = dict(
             depth = 4
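A hedged usage sketch of the two new constructor options; the dimensions, chunk size, and the custom transform below are illustrative placeholders rather than values taken from the repository, and the forward call follows the package's README-style usage:

```python
# Illustrative only: values and the custom transform are placeholders, not package defaults.
import torch
from titans_pytorch import NeuralMemory

def my_step_transform(adaptive_step):
    # map the raw per-token projection to a small positive step size (hypothetical choice)
    return adaptive_step.sigmoid() * 1e-2

mem = NeuralMemory(
    dim = 384,
    chunk_size = 4,
    adaptive_step_transform = my_step_transform,  # replaces default_adaptive_step_transform
    max_grad_norm = 1.,                           # softclamp stored-memory gradient norms
)

seq = torch.randn(2, 1024, 384)
retrieved = mem(seq)
```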
@@ -172,6 +195,12 @@ class NeuralMemory(Module):
             Rearrange('b n h -> (b h) n')
         )
 
+        self.adaptive_step_transform = adaptive_step_transform
+
+        # allow for softclamp the gradient norms for storing memories
+
+        self.max_grad_norm = max_grad_norm
+
         # weight decay factor
 
         self.to_decay_factor = nn.Sequential(
@@ -222,7 +251,8 @@ class NeuralMemory(Module):
 
         # pack batch and sequence dimension
 
-        adaptive_lr =
+        adaptive_lr = self.to_adaptive_step(seq)
+        adaptive_lr = self.adaptive_step_transform(adaptive_lr)
 
         adaptive_momentum = self.to_momentum(seq).sigmoid()
         decay_factor = self.to_decay_factor(seq).sigmoid()
@@ -247,6 +277,11 @@ class NeuralMemory(Module):
 
         grads = TensorDict(grads)
 
+        # maybe softclamp grad norm
+
+        if exists(self.max_grad_norm):
+            grads = grads.apply(lambda t: softclamp_grad_norm(t, self.max_grad_norm))
+
         # restore batch and sequence dimension
 
         grads = grads.apply(lambda t: rearrange(t, '(b n) ... -> b n ...', b = batch))
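A toy check (a sketch, not taken from the repository): after this optional step, every per-row gradient norm ends up strictly below `max_grad_norm`, with no hard cut-off as in an ordinary clip. A plain dict and 2-D tensors stand in for the `TensorDict` and the packed `'bn *'` layout used above:

```python
# Sketch: plain-dict stand-in for the TensorDict gradient clamping shown in the diff.
import torch

def softclamp_max(t, max_value):
    half_max_value = max_value / 2
    return ((t / half_max_value).tanh() * half_max_value) + half_max_value

def softclamp_grad_norm_2d(t, max_value):
    # 2-D stand-in for the packed 'bn *' layout
    norm = t.norm(dim = -1, keepdim = True)
    return t * (softclamp_max(norm, max_value) / norm)

grads = {name: torch.randn(8, 64) * 10 for name in ('weights.0', 'weights.1')}
clamped = {name: softclamp_grad_norm_2d(g, max_value = 1.) for name, g in grads.items()}

assert all(c.norm(dim = -1).amax() < 1. for c in clamped.values())
```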
train.py +3 -3

@@ -62,11 +62,11 @@ def decode_tokens(tokens):
 
 titans_neural_memory = NeuralMemory(
     dim = 384,
-    chunk_size =
+    chunk_size = 4,
     pre_rmsnorm = True,
     post_rmsnorm = True,
-    dim_head =
-    heads =
+    dim_head = 64,
+    heads = 4,
     use_accelerated_scan = True,
     default_mlp_kwargs = dict(
         depth = NEURAL_MEMORY_DEPTH