titans-pytorch 0.0.20-py3-none-any.whl → 0.0.21-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- titans_pytorch/titans.py +11 -1
- {titans_pytorch-0.0.20.dist-info → titans_pytorch-0.0.21.dist-info}/METADATA +1 -1
- titans_pytorch-0.0.21.dist-info/RECORD +8 -0
- titans_pytorch-0.0.20.dist-info/RECORD +0 -8
- {titans_pytorch-0.0.20.dist-info → titans_pytorch-0.0.21.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.0.20.dist-info → titans_pytorch-0.0.21.dist-info}/licenses/LICENSE +0 -0
titans_pytorch/titans.py
CHANGED
```diff
@@ -40,6 +40,9 @@ def exists(v):
 def default(v, d):
     return v if exists(v) else d
 
+def identity(t):
+    return t
+
 def round_down_multiple(seq, mult):
     return seq // mult * mult
 
@@ -97,6 +100,9 @@ class MemoryMLP(Module):
 
 # main neural memory
 
+def default_adaptive_step_transform(adaptive_step):
+    return torch.exp(adaptive_step.sigmoid() * -15) # from 1. - 1e-7
+
 def default_loss_fn(pred, target):
     return (pred - target).pow(2).mean(dim = -1).sum()
 
@@ -109,6 +115,7 @@ class NeuralMemory(Module):
         heads = 1,
         model: Module | None = None,
         store_memory_loss_fn: Callable = default_loss_fn,
+        adaptive_step_transform: Callable = default_adaptive_step_transform,
         pre_rmsnorm = True,
         post_rmsnorm = True,
         max_grad_norm: float | None = None,
@@ -188,6 +195,8 @@ class NeuralMemory(Module):
             Rearrange('b n h -> (b h) n')
         )
 
+        self.adaptive_step_transform = adaptive_step_transform
+
        # allow for softclamp the gradient norms for storing memories
 
         self.max_grad_norm = max_grad_norm
@@ -242,7 +251,8 @@ class NeuralMemory(Module):
 
         # pack batch and sequence dimension
 
-        adaptive_lr =
+        adaptive_lr = self.to_adaptive_step(seq)
+        adaptive_lr = self.adaptive_step_transform(adaptive_lr)
 
         adaptive_momentum = self.to_momentum(seq).sigmoid()
         decay_factor = self.to_decay_factor(seq).sigmoid()
```
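Taken together, this release adds an `adaptive_step_transform` hook to `NeuralMemory`: the raw per-token step logits from `to_adaptive_step` now pass through a configurable callable, with `default_adaptive_step_transform` reproducing the prior squashing of the logit into roughly (exp(-15), 1). A minimal sketch of how the new keyword might be used; the custom transform and the `dim`/`chunk_size` values below are illustrative assumptions, not part of this diff:

```python
import torch
from titans_pytorch import NeuralMemory

# hypothetical transform: bound the per-token step to (0, 1e-2)
# instead of the default torch.exp(adaptive_step.sigmoid() * -15)
def clamped_step_transform(adaptive_step):
    return adaptive_step.sigmoid() * 1e-2

mem = NeuralMemory(
    dim = 384,                                         # illustrative width
    chunk_size = 64,                                   # illustrative chunk size
    adaptive_step_transform = clamped_step_transform   # new keyword in 0.0.21
)

seq = torch.randn(2, 1024, 384)
retrieved = mem(seq)  # retrieval path is unchanged; only the step transform differs
```

Call sites that omit the keyword fall back to `default_adaptive_step_transform`, so existing 0.0.20 usage keeps its behaviour.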
titans_pytorch-0.0.21.dist-info/RECORD
ADDED
```diff
@@ -0,0 +1,8 @@
+titans_pytorch/__init__.py,sha256=nB0873FZ_OyCda3qFeWTdpO4LbbrXDEvtAefVLzh6g0,71
+titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+titans_pytorch/titans.py,sha256=rsgpGs-nnqvsvF6Mu2GfmfNlRgYfHAV9MCEmd1ohUUI,13687
+titans_pytorch/titans_attn_memory.py,sha256=Rwx_-riGeISBefZg5Kjic8jzmmRRys-u93D2Kgb7Mos,12691
+titans_pytorch-0.0.21.dist-info/METADATA,sha256=3z2LGviHUicQq4_ThlRasOUCDXIe0xX2aIhxtKLWI0Q,3811
+titans_pytorch-0.0.21.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.0.21.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.0.21.dist-info/RECORD,,
```
titans_pytorch-0.0.20.dist-info/RECORD
DELETED
```diff
@@ -1,8 +0,0 @@
-titans_pytorch/__init__.py,sha256=nB0873FZ_OyCda3qFeWTdpO4LbbrXDEvtAefVLzh6g0,71
-titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/titans.py,sha256=HpVjFy6jyzLGB_ilqjcYWGE-VtYmUrUwkXzmzqPrCXc,13370
-titans_pytorch/titans_attn_memory.py,sha256=Rwx_-riGeISBefZg5Kjic8jzmmRRys-u93D2Kgb7Mos,12691
-titans_pytorch-0.0.20.dist-info/METADATA,sha256=9qJWG-hwJ8IK9auhQV2XyEs54T0-LMvBAArF-iQ21IE,3811
-titans_pytorch-0.0.20.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.0.20.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.0.20.dist-info/RECORD,,
```
{titans_pytorch-0.0.20.dist-info → titans_pytorch-0.0.21.dist-info}/WHEEL
File without changes

{titans_pytorch-0.0.20.dist-info → titans_pytorch-0.0.21.dist-info}/licenses/LICENSE
File without changes