titans-pytorch 0.0.18__py3-none-any.whl → 0.0.20__py3-none-any.whl
- titans_pytorch/titans.py +39 -5
- {titans_pytorch-0.0.18.dist-info → titans_pytorch-0.0.20.dist-info}/METADATA +1 -1
- titans_pytorch-0.0.20.dist-info/RECORD +8 -0
- titans_pytorch-0.0.18.dist-info/RECORD +0 -8
- {titans_pytorch-0.0.18.dist-info → titans_pytorch-0.0.20.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.0.18.dist-info → titans_pytorch-0.0.20.dist-info}/licenses/LICENSE +0 -0
titans_pytorch/titans.py
CHANGED
@@ -17,7 +17,7 @@ from titans_pytorch.associative_scan import (
 )
 
 import einx
-from einops import rearrange, pack, unpack
+from einops import rearrange, repeat, pack, unpack
 from einops.layers.torch import Rearrange, Reduce
 
 """
@@ -55,6 +55,21 @@ def pack_one_with_inverse(t, pattern):
 
     return packed, inverse
 
+# softclamping gradients
+
+def softclamp_max(t, max_value):
+    half_max_value = max_value / 2
+    return ((t / half_max_value).tanh() * half_max_value) + half_max_value
+
+def softclamp_grad_norm(t, max_value):
+    t, inverse = pack_one_with_inverse(t, 'bn *')
+
+    norm = t.norm(dim = -1, keepdim = True)
+    clamped_norm = softclamp_max(norm, max_value)
+
+    t = t * (clamped_norm / norm)
+    return inverse(t)
+
 # classes
 
 class MemoryMLP(Module):
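Note (not part of the diff): the two helpers added above are a smooth alternative to hard gradient clipping. softclamp_max squashes any input into the interval (0, max_value), saturating at max_value for large inputs, and softclamp_grad_norm rescales a tensor so that its flattened per-row norm is softclamped the same way. A minimal standalone sketch follows; max_value = 10. is arbitrary, and pack_one_with_inverse is re-implemented here as the usual einops pack/unpack wrapper, which is an assumption since that helper is not shown in this diff.

    import torch
    from einops import pack, unpack

    # assumed stand-in for the package's pack_one_with_inverse helper
    def pack_one_with_inverse(t, pattern):
        packed, ps = pack([t], pattern)
        return packed, lambda out: unpack(out, ps, pattern)[0]

    # copied from the hunk above
    def softclamp_max(t, max_value):
        half_max_value = max_value / 2
        return ((t / half_max_value).tanh() * half_max_value) + half_max_value

    def softclamp_grad_norm(t, max_value):
        t, inverse = pack_one_with_inverse(t, 'bn *')
        norm = t.norm(dim = -1, keepdim = True)
        clamped_norm = softclamp_max(norm, max_value)
        t = t * (clamped_norm / norm)
        return inverse(t)

    print(softclamp_max(torch.tensor([1., 5., 10., 100.]), 10.))
    # roughly tensor([6.0, 8.8, 9.8, 10.0]); bounded in (0, 10), saturating smoothly

    grads = torch.randn(8, 64, 64) * 100        # deliberately oversized gradients
    clamped = softclamp_grad_norm(grads, 10.)
    print(clamped.flatten(1).norm(dim = -1))    # each row norm is now about 10 instead of about 6400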
@@ -96,6 +111,7 @@ class NeuralMemory(Module):
         store_memory_loss_fn: Callable = default_loss_fn,
         pre_rmsnorm = True,
         post_rmsnorm = True,
+        max_grad_norm: float | None = None,
         use_accelerated_scan = False,
         default_mlp_kwargs: dict = dict(
             depth = 4
@@ -152,6 +168,11 @@ class NeuralMemory(Module):
         self.to_keys_values = LinearNoBias(dim, dim_inner * 2)
         self.store_memory_loss_fn = store_memory_loss_fn
 
+        # empty memory embed
+
+        self.empty_memory_embed = nn.Parameter(torch.zeros(dim))
+        nn.init.normal_(self.empty_memory_embed, std = 0.02)
+
         # learned adaptive learning rate and momentum
         # todo - explore mlp layerwise learned lr / momentum
 
@@ -167,6 +188,10 @@ class NeuralMemory(Module):
             Rearrange('b n h -> (b h) n')
         )
 
+        # allow for softclamp the gradient norms for storing memories
+
+        self.max_grad_norm = max_grad_norm
+
         # weight decay factor
 
         self.to_decay_factor = nn.Sequential(
@@ -187,6 +212,9 @@ class NeuralMemory(Module):
 
         return init_weights, init_momentum
 
+    def init_empty_memory_embed(self, batch, seq_len):
+        return repeat(self.empty_memory_embed, 'd -> b n d', b = batch, n = seq_len)
+
     def store_memories(
         self,
         seq,
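Note (not part of the diff): the new init_empty_memory_embed simply broadcasts the single learned dim-sized vector to one embedding per position, as in the sketch below (dim = 384 and the batch/length values are only illustrative).

    import torch
    from einops import repeat

    empty_memory_embed = torch.zeros(384)   # stands in for the learned nn.Parameter of size dim
    out = repeat(empty_memory_embed, 'd -> b n d', b = 2, n = 16)
    print(out.shape)                        # torch.Size([2, 16, 384])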
@@ -239,6 +267,11 @@ class NeuralMemory(Module):
 
         grads = TensorDict(grads)
 
+        # maybe softclamp grad norm
+
+        if exists(self.max_grad_norm):
+            grads = grads.apply(lambda t: softclamp_grad_norm(t, self.max_grad_norm))
+
         # restore batch and sequence dimension
 
         grads = grads.apply(lambda t: rearrange(t, '(b n) ... -> b n ...', b = batch))
@@ -372,11 +405,12 @@ class NeuralMemory(Module):
 
         values = self.post_rmsnorm(values)
 
-        # restore
+        # restore, pad with empty memory embed
 
-
-        values = values
+        empty_memory_embeds = self.init_empty_memory_embed(values.shape[0], chunk_size - 1)
+        values = torch.cat((empty_memory_embeds, values), dim = -2)
 
+        values = values[:, :-padding]
         return values
 
     def forward(
@@ -389,7 +423,7 @@ class NeuralMemory(Module):
         batch, seq_len = seq.shape[:2]
 
         if seq_len < self.chunk_size:
-            return
+            return self.init_empty_memory_embed(batch, seq_len)
 
         if exists(past_state):
             past_state = tuple(TensorDict(d) for d in past_state)
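Note (not part of the diff): taken together, the user-facing changes in 0.0.20 are the optional max_grad_norm argument, which softclamps per-parameter gradient norms when storing memories, and the learned empty-memory embedding that retrieval now returns (instead of None) when the sequence is shorter than chunk_size. A hedged usage sketch follows; dim and chunk_size follow the package README, while the concrete values, including max_grad_norm = 10., are illustrative assumptions.

    import torch
    from titans_pytorch import NeuralMemory

    mem = NeuralMemory(
        dim = 384,
        chunk_size = 64,
        max_grad_norm = 10.    # new in 0.0.20: softclamp gradient norms when storing memories
    )

    retrieved = mem(torch.randn(2, 1024, 384))
    print(retrieved.shape)     # torch.Size([2, 1024, 384])

    short = mem(torch.randn(2, 16, 384))    # shorter than chunk_size
    print(short.shape)         # torch.Size([2, 16, 384]): the learned empty-memory embeds,
                               # where 0.0.18 returned None here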
titans_pytorch-0.0.20.dist-info/RECORD
ADDED
@@ -0,0 +1,8 @@
+titans_pytorch/__init__.py,sha256=nB0873FZ_OyCda3qFeWTdpO4LbbrXDEvtAefVLzh6g0,71
+titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+titans_pytorch/titans.py,sha256=HpVjFy6jyzLGB_ilqjcYWGE-VtYmUrUwkXzmzqPrCXc,13370
+titans_pytorch/titans_attn_memory.py,sha256=Rwx_-riGeISBefZg5Kjic8jzmmRRys-u93D2Kgb7Mos,12691
+titans_pytorch-0.0.20.dist-info/METADATA,sha256=9qJWG-hwJ8IK9auhQV2XyEs54T0-LMvBAArF-iQ21IE,3811
+titans_pytorch-0.0.20.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.0.20.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.0.20.dist-info/RECORD,,
titans_pytorch-0.0.18.dist-info/RECORD
DELETED
@@ -1,8 +0,0 @@
-titans_pytorch/__init__.py,sha256=nB0873FZ_OyCda3qFeWTdpO4LbbrXDEvtAefVLzh6g0,71
-titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/titans.py,sha256=-Xv3ufD2vhprNFliuu1lGx27nx7AvHi6yFG2g9eHaqY,12295
-titans_pytorch/titans_attn_memory.py,sha256=Rwx_-riGeISBefZg5Kjic8jzmmRRys-u93D2Kgb7Mos,12691
-titans_pytorch-0.0.18.dist-info/METADATA,sha256=YX0EPMqVioQjAVxoI3CAKV8nWgwZZ0tw4djgud4bEqs,3811
-titans_pytorch-0.0.18.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.0.18.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.0.18.dist-info/RECORD,,
{titans_pytorch-0.0.18.dist-info → titans_pytorch-0.0.20.dist-info}/WHEEL
File without changes

{titans_pytorch-0.0.18.dist-info → titans_pytorch-0.0.20.dist-info}/licenses/LICENSE
File without changes