titans-pytorch: titans_pytorch-0.0.20-py3-none-any.whl → titans_pytorch-0.0.22-py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
Potentially problematic release.
This version of titans-pytorch might be problematic. Click here for more details.
- titans_pytorch/titans.py +11 -1
- {titans_pytorch-0.0.20.dist-info → titans_pytorch-0.0.22.dist-info}/METADATA +1 -1
- titans_pytorch-0.0.22.dist-info/RECORD +8 -0
- titans_pytorch-0.0.20.dist-info/RECORD +0 -8
- {titans_pytorch-0.0.20.dist-info → titans_pytorch-0.0.22.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.0.20.dist-info → titans_pytorch-0.0.22.dist-info}/licenses/LICENSE +0 -0
titans_pytorch/titans.py
CHANGED
@@ -40,6 +40,9 @@ def exists(v):
|
|
40
40
|
def default(v, d):
    """Return ``v`` when ``exists(v)`` is true, otherwise the fallback ``d``."""
    if not exists(v):
        return d
    return v
|
42
42
|
|
43
|
+
def identity(t):
    """Return ``t`` unchanged, without copying it."""
    result = t
    return result
|
45
|
+
|
43
46
|
def round_down_multiple(seq, mult):
    """Round ``seq`` down to the nearest multiple of ``mult``.

    Uses floor division (``//``), so for Python ints negative values round
    toward negative infinity and ``mult == 0`` raises ``ZeroDivisionError``.
    """
    whole_chunks = seq // mult
    return whole_chunks * mult
|
45
48
|
|
@@ -97,6 +100,9 @@ class MemoryMLP(Module):
|
|
97
100
|
|
98
101
|
# main neural memory
|
99
102
|
|
103
|
+
def default_adaptive_step_transform(adaptive_step, max_lr = 1e-1):
    """Squash raw adaptive-step values with a sigmoid and scale by ``max_lr``.

    For a positive ``max_lr`` every output lies between 0 and ``max_lr``;
    a raw value of 0 maps to ``max_lr / 2``.

    Args:
        adaptive_step: tensor-like object exposing ``.sigmoid()``
            (presumably a torch tensor, TODO confirm).
        max_lr: upper bound of the scaled output.
    """
    gate = adaptive_step.sigmoid()
    return gate * max_lr
|
105
|
+
|
100
106
|
def default_loss_fn(pred, target):
    """Squared-error loss averaged over the last dimension, then summed.

    The element-wise squared difference is averaged along ``dim = -1`` and
    the remaining values are summed into a single scalar. Presumably
    ``pred`` and ``target`` are same-shaped tensors; TODO confirm against
    callers.
    """
    squared_error = (pred - target) ** 2
    per_position = squared_error.mean(dim = -1)
    return per_position.sum()
|
102
108
|
|
@@ -109,6 +115,7 @@ class NeuralMemory(Module):
|
|
109
115
|
heads = 1,
|
110
116
|
model: Module | None = None,
|
111
117
|
store_memory_loss_fn: Callable = default_loss_fn,
|
118
|
+
adaptive_step_transform: Callable = default_adaptive_step_transform,
|
112
119
|
pre_rmsnorm = True,
|
113
120
|
post_rmsnorm = True,
|
114
121
|
max_grad_norm: float | None = None,
|
@@ -188,6 +195,8 @@ class NeuralMemory(Module):
|
|
188
195
|
Rearrange('b n h -> (b h) n')
|
189
196
|
)
|
190
197
|
|
198
|
+
self.adaptive_step_transform = adaptive_step_transform
|
199
|
+
|
191
200
|
# allow for softclamp the gradient norms for storing memories
|
192
201
|
|
193
202
|
self.max_grad_norm = max_grad_norm
|
@@ -242,7 +251,8 @@ class NeuralMemory(Module):
|
|
242
251
|
|
243
252
|
# pack batch and sequence dimension
|
244
253
|
|
245
|
-
adaptive_lr =
|
254
|
+
adaptive_lr = self.to_adaptive_step(seq)
|
255
|
+
adaptive_lr = self.adaptive_step_transform(adaptive_lr)
|
246
256
|
|
247
257
|
adaptive_momentum = self.to_momentum(seq).sigmoid()
|
248
258
|
decay_factor = self.to_decay_factor(seq).sigmoid()
|
@@ -0,0 +1,8 @@
|
|
1
|
+
titans_pytorch/__init__.py,sha256=nB0873FZ_OyCda3qFeWTdpO4LbbrXDEvtAefVLzh6g0,71
|
2
|
+
titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
|
3
|
+
titans_pytorch/titans.py,sha256=s7KVqId7hhHw1Ck77FJUXfKC5rpwS4N7Nw2mKFtP8-s,13677
|
4
|
+
titans_pytorch/titans_attn_memory.py,sha256=Rwx_-riGeISBefZg5Kjic8jzmmRRys-u93D2Kgb7Mos,12691
|
5
|
+
titans_pytorch-0.0.22.dist-info/METADATA,sha256=acTS0vWh84sGtTIEBzdTsxNi-S4V-Cr8NL4U4Vg0eY0,3811
|
6
|
+
titans_pytorch-0.0.22.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
7
|
+
titans_pytorch-0.0.22.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
8
|
+
titans_pytorch-0.0.22.dist-info/RECORD,,
|
@@ -1,8 +0,0 @@
|
|
1
|
-
titans_pytorch/__init__.py,sha256=nB0873FZ_OyCda3qFeWTdpO4LbbrXDEvtAefVLzh6g0,71
|
2
|
-
titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
|
3
|
-
titans_pytorch/titans.py,sha256=HpVjFy6jyzLGB_ilqjcYWGE-VtYmUrUwkXzmzqPrCXc,13370
|
4
|
-
titans_pytorch/titans_attn_memory.py,sha256=Rwx_-riGeISBefZg5Kjic8jzmmRRys-u93D2Kgb7Mos,12691
|
5
|
-
titans_pytorch-0.0.20.dist-info/METADATA,sha256=9qJWG-hwJ8IK9auhQV2XyEs54T0-LMvBAArF-iQ21IE,3811
|
6
|
-
titans_pytorch-0.0.20.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
7
|
-
titans_pytorch-0.0.20.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
8
|
-
titans_pytorch-0.0.20.dist-info/RECORD,,
|
File without changes
|
File without changes
|