titans-pytorch 0.0.19.tar.gz → 0.0.20.tar.gz
- {titans_pytorch-0.0.19 → titans_pytorch-0.0.20}/PKG-INFO +1 -1
- {titans_pytorch-0.0.19 → titans_pytorch-0.0.20}/pyproject.toml +1 -1
- {titans_pytorch-0.0.19 → titans_pytorch-0.0.20}/titans_pytorch/titans.py +25 -0
- {titans_pytorch-0.0.19 → titans_pytorch-0.0.20}/.github/workflows/python-publish.yml +0 -0
- {titans_pytorch-0.0.19 → titans_pytorch-0.0.20}/.github/workflows/test.yaml +0 -0
- {titans_pytorch-0.0.19 → titans_pytorch-0.0.20}/.gitignore +0 -0
- {titans_pytorch-0.0.19 → titans_pytorch-0.0.20}/LICENSE +0 -0
- {titans_pytorch-0.0.19 → titans_pytorch-0.0.20}/README.md +0 -0
- {titans_pytorch-0.0.19 → titans_pytorch-0.0.20}/data/README.md +0 -0
- {titans_pytorch-0.0.19 → titans_pytorch-0.0.20}/data/enwik8.gz +0 -0
- {titans_pytorch-0.0.19 → titans_pytorch-0.0.20}/fig1.png +0 -0
- {titans_pytorch-0.0.19 → titans_pytorch-0.0.20}/fig2.png +0 -0
- {titans_pytorch-0.0.19 → titans_pytorch-0.0.20}/requirements.txt +0 -0
- {titans_pytorch-0.0.19 → titans_pytorch-0.0.20}/tests/test_titans.py +0 -0
- {titans_pytorch-0.0.19 → titans_pytorch-0.0.20}/titans_pytorch/__init__.py +0 -0
- {titans_pytorch-0.0.19 → titans_pytorch-0.0.20}/titans_pytorch/associative_scan.py +0 -0
- {titans_pytorch-0.0.19 → titans_pytorch-0.0.20}/titans_pytorch/titans_attn_memory.py +0 -0
- {titans_pytorch-0.0.19 → titans_pytorch-0.0.20}/train.py +0 -0
--- titans_pytorch-0.0.19/titans_pytorch/titans.py
+++ titans_pytorch-0.0.20/titans_pytorch/titans.py
@@ -55,6 +55,21 @@ def pack_one_with_inverse(t, pattern):
 
     return packed, inverse
 
+# softclamping gradients
+
+def softclamp_max(t, max_value):
+    half_max_value = max_value / 2
+    return ((t / half_max_value).tanh() * half_max_value) + half_max_value
+
+def softclamp_grad_norm(t, max_value):
+    t, inverse = pack_one_with_inverse(t, 'bn *')
+
+    norm = t.norm(dim = -1, keepdim = True)
+    clamped_norm = softclamp_max(norm, max_value)
+
+    t = t * (clamped_norm / norm)
+    return inverse(t)
+
 # classes
 
 class MemoryMLP(Module):
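For context, the two helpers added in this hunk implement a smooth, tanh-based clamp on per-sample gradient norms rather than a hard clip. Below is a minimal, self-contained sketch of the behavior: the `pack_one_with_inverse` stand-in and the demo tensor are illustrative only, while the two softclamp functions are taken from the hunk above (with added comments).

```python
# Sketch only: pack_one_with_inverse is a stand-in for the package's own helper;
# the two softclamp functions mirror the code added in 0.0.20.
import torch
from einops import pack, unpack

def pack_one_with_inverse(t, pattern):
    packed, packed_shape = pack([t], pattern)

    def inverse(out):
        return unpack(out, packed_shape, pattern)[0]

    return packed, inverse

def softclamp_max(t, max_value):
    # smoothly squashes the input into (0, max_value) via tanh
    half_max_value = max_value / 2
    return ((t / half_max_value).tanh() * half_max_value) + half_max_value

def softclamp_grad_norm(t, max_value):
    # flatten everything after the leading dim, then rescale by the clamped norm
    t, inverse = pack_one_with_inverse(t, 'bn *')

    norm = t.norm(dim = -1, keepdim = True)
    clamped_norm = softclamp_max(norm, max_value)

    t = t * (clamped_norm / norm)
    return inverse(t)

# a fake per-sample gradient with deliberately large norms
grad = torch.randn(4, 16) * 100.

clamped = softclamp_grad_norm(grad, max_value = 10.)

print(grad.norm(dim = -1))     # large norms, on the order of a few hundred
print(clamped.norm(dim = -1))  # all strictly below 10.
```

Because of the tanh, rescaled norms land in (max_value / 2, max_value) for any positive input, so this is a smooth squash of the norm rather than a clip that only activates above the threshold.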
@@ -96,6 +111,7 @@ class NeuralMemory(Module):
         store_memory_loss_fn: Callable = default_loss_fn,
         pre_rmsnorm = True,
         post_rmsnorm = True,
+        max_grad_norm: float | None = None,
         use_accelerated_scan = False,
         default_mlp_kwargs: dict = dict(
             depth = 4
@@ -172,6 +188,10 @@ class NeuralMemory(Module):
             Rearrange('b n h -> (b h) n')
         )
 
+        # allow for softclamp the gradient norms for storing memories
+
+        self.max_grad_norm = max_grad_norm
+
         # weight decay factor
 
         self.to_decay_factor = nn.Sequential(
@@ -247,6 +267,11 @@ class NeuralMemory(Module):
 
         grads = TensorDict(grads)
 
+        # maybe softclamp grad norm
+
+        if exists(self.max_grad_norm):
+            grads = grads.apply(lambda t: softclamp_grad_norm(t, self.max_grad_norm))
+
         # restore batch and sequence dimension
 
         grads = grads.apply(lambda t: rearrange(t, '(b n) ... -> b n ...', b = batch))
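Taken together, the release boils down to one new constructor argument. A minimal usage sketch follows, assuming the `dim` and `chunk_size` arguments and the call signature shown in the project README; only `max_grad_norm` is new in 0.0.20, and leaving it at `None` keeps the previous behavior.

```python
import torch
from titans_pytorch import NeuralMemory

# dim / chunk_size follow the README example (assumed unchanged in this release);
# max_grad_norm is the new option that softclamps per-sample gradient norms
# when storing memories
mem = NeuralMemory(
    dim = 384,
    chunk_size = 64,
    max_grad_norm = 10.
)

seq = torch.randn(2, 1024, 384)
retrieved = mem(seq)

assert retrieved.shape == seq.shape
```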