titans-pytorch 0.0.19-py3-none-any.whl → 0.0.21-py3-none-any.whl
This diff compares the contents of publicly available package versions as released to one of the supported registries, and is provided for informational purposes only.
- titans_pytorch/titans.py +36 -1
- {titans_pytorch-0.0.19.dist-info → titans_pytorch-0.0.21.dist-info}/METADATA +1 -1
- titans_pytorch-0.0.21.dist-info/RECORD +8 -0
- titans_pytorch-0.0.19.dist-info/RECORD +0 -8
- {titans_pytorch-0.0.19.dist-info → titans_pytorch-0.0.21.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.0.19.dist-info → titans_pytorch-0.0.21.dist-info}/licenses/LICENSE +0 -0
titans_pytorch/titans.py
CHANGED
```diff
@@ -40,6 +40,9 @@ def exists(v):
 def default(v, d):
     return v if exists(v) else d
 
+def identity(t):
+    return t
+
 def round_down_multiple(seq, mult):
     return seq // mult * mult
 
@@ -55,6 +58,21 @@ def pack_one_with_inverse(t, pattern):
 
     return packed, inverse
 
+# softclamping gradients
+
+def softclamp_max(t, max_value):
+    half_max_value = max_value / 2
+    return ((t / half_max_value).tanh() * half_max_value) + half_max_value
+
+def softclamp_grad_norm(t, max_value):
+    t, inverse = pack_one_with_inverse(t, 'bn *')
+
+    norm = t.norm(dim = -1, keepdim = True)
+    clamped_norm = softclamp_max(norm, max_value)
+
+    t = t * (clamped_norm / norm)
+    return inverse(t)
+
 # classes
 
 class MemoryMLP(Module):
```
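To see what these helpers do numerically, here is a minimal standalone sketch (not the package's code): `softclamp_max` squashes any real input smoothly into (0, max_value), and the gradient variant rescales a tensor so its last-dimension norm stays softly bounded.

```python
# Minimal standalone sketch of the soft clamping above (assumes only torch).
# tanh maps to (-1, 1); scaling by half_max_value and shifting by half_max_value
# lands the output strictly inside (0, max_value).
import torch

def softclamp_max(t, max_value):
    half_max_value = max_value / 2
    return ((t / half_max_value).tanh() * half_max_value) + half_max_value

x = torch.tensor([-100., 0., 100.])
print(softclamp_max(x, 10.))    # ≈ [0.0000, 5.0000, 10.0000]

# rescale an exaggerated "gradient" so its norm is softly bounded,
# mirroring softclamp_grad_norm without the einops packing
g = torch.randn(4, 16) * 50.
norm = g.norm(dim = -1, keepdim = True)
g = g * (softclamp_max(norm, 10.) / norm)
print(g.norm(dim = -1))         # every norm now below 10
```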
```diff
@@ -82,6 +100,9 @@ class MemoryMLP(Module):
 
 # main neural memory
 
+def default_adaptive_step_transform(adaptive_step):
+    return torch.exp(adaptive_step.sigmoid() * -15) # from 1. - 1e-7
+
 def default_loss_fn(pred, target):
     return (pred - target).pow(2).mean(dim = -1).sum()
 
```
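As a quick sanity check (a sketch, not package code), this default transform maps an unbounded network output to a per-token step size in (e^-15, 1) ≈ (3e-7, 1), matching the inline comment:

```python
import torch

# sigmoid(x) ∈ (0, 1), so sigmoid(x) * -15 ∈ (-15, 0),
# and exp of that lies in (e^-15, 1) ≈ (3.1e-7, 1)
raw = torch.tensor([-10., 0., 10.])     # stand-ins for unbounded to_adaptive_step outputs
print(torch.exp(raw.sigmoid() * -15))   # ≈ [9.99e-1, 5.53e-4, 3.06e-7]
```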
```diff
@@ -94,8 +115,10 @@ class NeuralMemory(Module):
         heads = 1,
         model: Module | None = None,
         store_memory_loss_fn: Callable = default_loss_fn,
+        adaptive_step_transform: Callable = default_adaptive_step_transform,
         pre_rmsnorm = True,
         post_rmsnorm = True,
+        max_grad_norm: float | None = None,
         use_accelerated_scan = False,
         default_mlp_kwargs: dict = dict(
             depth = 4
```
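A hypothetical usage sketch of the two new constructor arguments; the `dim` and `chunk_size` values are illustrative, and the import path is assumed from the package README:

```python
import torch
from titans_pytorch import NeuralMemory

mem = NeuralMemory(
    dim = 384,
    chunk_size = 64,
    max_grad_norm = 10.,    # new: softclamp per-parameter gradient norms when storing memories
    adaptive_step_transform = lambda t: torch.exp(t.sigmoid() * -15),   # new: override the default transform
)
```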
```diff
@@ -172,6 +195,12 @@ class NeuralMemory(Module):
             Rearrange('b n h -> (b h) n')
         )
 
+        self.adaptive_step_transform = adaptive_step_transform
+
+        # allow for softclamp the gradient norms for storing memories
+
+        self.max_grad_norm = max_grad_norm
+
         # weight decay factor
 
         self.to_decay_factor = nn.Sequential(
@@ -222,7 +251,8 @@ class NeuralMemory(Module):
 
         # pack batch and sequence dimension
 
-        adaptive_lr =
+        adaptive_lr = self.to_adaptive_step(seq)
+        adaptive_lr = self.adaptive_step_transform(adaptive_lr)
 
         adaptive_momentum = self.to_momentum(seq).sigmoid()
         decay_factor = self.to_decay_factor(seq).sigmoid()
```
```diff
@@ -247,6 +277,11 @@ class NeuralMemory(Module):
 
         grads = TensorDict(grads)
 
+        # maybe softclamp grad norm
+
+        if exists(self.max_grad_norm):
+            grads = grads.apply(lambda t: softclamp_grad_norm(t, self.max_grad_norm))
+
 
         # restore batch and sequence dimension
 
         grads = grads.apply(lambda t: rearrange(t, '(b n) ... -> b n ...', b = batch))
```
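For context, `TensorDict.apply` (from the `tensordict` package) maps a function over every tensor the dict holds, which is what lets a single lambda clamp every per-parameter gradient at once. A self-contained sketch, with a hard norm clamp standing in for `softclamp_grad_norm` and hypothetical key names:

```python
import torch
from tensordict import TensorDict

# hypothetical per-parameter gradients, keyed like a state dict
grads = TensorDict({
    'weights.0': torch.randn(32, 64) * 100.,   # exaggerated gradients
    'weights.1': torch.randn(64, 64),
}, batch_size = [])

def clamp_norm(t, max_value = 10.):
    norm = t.norm(dim = -1, keepdim = True)
    return t * (norm.clamp(max = max_value) / norm)

grads = grads.apply(clamp_norm)   # applied leaf-wise to every tensor in the dict
```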
titans_pytorch-0.0.21.dist-info/RECORD
ADDED
```diff
@@ -0,0 +1,8 @@
+titans_pytorch/__init__.py,sha256=nB0873FZ_OyCda3qFeWTdpO4LbbrXDEvtAefVLzh6g0,71
+titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+titans_pytorch/titans.py,sha256=rsgpGs-nnqvsvF6Mu2GfmfNlRgYfHAV9MCEmd1ohUUI,13687
+titans_pytorch/titans_attn_memory.py,sha256=Rwx_-riGeISBefZg5Kjic8jzmmRRys-u93D2Kgb7Mos,12691
+titans_pytorch-0.0.21.dist-info/METADATA,sha256=3z2LGviHUicQq4_ThlRasOUCDXIe0xX2aIhxtKLWI0Q,3811
+titans_pytorch-0.0.21.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.0.21.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.0.21.dist-info/RECORD,,
```
titans_pytorch-0.0.19.dist-info/RECORD
REMOVED
```diff
@@ -1,8 +0,0 @@
-titans_pytorch/__init__.py,sha256=nB0873FZ_OyCda3qFeWTdpO4LbbrXDEvtAefVLzh6g0,71
-titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/titans.py,sha256=CxbJgNdIS9NbbCDdgotFXAnrV16xmvufUErerKe7qJA,12636
-titans_pytorch/titans_attn_memory.py,sha256=Rwx_-riGeISBefZg5Kjic8jzmmRRys-u93D2Kgb7Mos,12691
-titans_pytorch-0.0.19.dist-info/METADATA,sha256=5Wpk79HYI4z8LeNRV__UaamKppiGcJ2HdIlll1JSZr8,3811
-titans_pytorch-0.0.19.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.0.19.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.0.19.dist-info/RECORD,,
```
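Each RECORD line follows the wheel spec's `path,sha256=digest,size` format. Only the `titans.py` entry changed in both hash and size (12636 → 13687 bytes), consistent with the +36/-1 line change above; METADATA changed hash but not size with the version bump.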
{titans_pytorch-0.0.19.dist-info → titans_pytorch-0.0.21.dist-info}/WHEEL
File without changes
{titans_pytorch-0.0.19.dist-info → titans_pytorch-0.0.21.dist-info}/licenses/LICENSE
File without changes