titans-pytorch 0.0.1__py3-none-any.whl → 0.0.3__py3-none-any.whl
This diff shows the contents of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those package versions as they appear in their respective public registries.
- titans_pytorch/titans.py +4 -4
- {titans_pytorch-0.0.1.dist-info → titans_pytorch-0.0.3.dist-info}/METADATA +1 -1
- titans_pytorch-0.0.3.dist-info/RECORD +7 -0
- titans_pytorch-0.0.1.dist-info/RECORD +0 -7
- {titans_pytorch-0.0.1.dist-info → titans_pytorch-0.0.3.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.0.1.dist-info → titans_pytorch-0.0.3.dist-info}/licenses/LICENSE +0 -0
titans_pytorch/titans.py
CHANGED
|
@@ -121,7 +121,7 @@ class NeuralMemory(Module):
|
|
|
121
121
|
|
|
122
122
|
self.to_momentum = LinearNoBias(dim, 1)
|
|
123
123
|
self.to_adaptive_step = nn.Sequential(LinearNoBias(dim, 1), Rearrange('... 1 -> ...'))
|
|
124
|
-
self.to_decay_factor =
|
|
124
|
+
self.to_decay_factor = LinearNoBias(dim, 1) # weight decay factor
|
|
125
125
|
|
|
126
126
|
def init_weights_and_momentum(self):
|
|
127
127
|
params = TensorDict(dict(self.memory_model.named_parameters()))
|
|
@@ -148,10 +148,10 @@ class NeuralMemory(Module):
|
|
|
148
148
|
|
|
149
149
|
batch = seq.shape[0]
|
|
150
150
|
|
|
151
|
-
adaptive_lr = self.to_adaptive_step(seq)
|
|
152
|
-
adaptive_momentum = self.to_momentum(seq)
|
|
151
|
+
adaptive_lr = self.to_adaptive_step(seq).tanh() * 0.5 + 0.5.
|
|
153
152
|
|
|
154
|
-
|
|
153
|
+
adaptive_momentum = self.to_momentum(seq).sigmoid()
|
|
154
|
+
decay_factor = self.to_decay_factor(seq).sigmoid()
|
|
155
155
|
|
|
156
156
|
# keys and values
|
|
157
157
|
|
|
@@ -0,0 +1,7 @@
|
|
|
1
|
+
titans_pytorch/__init__.py,sha256=QKuJPCOJCdgtaPeKoHEkYkiQe65_LV9_8-cIMbBPU30,55
|
|
2
|
+
titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
|
|
3
|
+
titans_pytorch/titans.py,sha256=0Mh9LJv5hLVbB2MvRJX5QanAeTtU9LAuj6YOQUwsyUQ,7813
|
|
4
|
+
titans_pytorch-0.0.3.dist-info/METADATA,sha256=AXfDl_MTIu24VRagi_rgiH8rHXFBU5euwSD6DMwLgsg,2968
|
|
5
|
+
titans_pytorch-0.0.3.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
6
|
+
titans_pytorch-0.0.3.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
|
7
|
+
titans_pytorch-0.0.3.dist-info/RECORD,,
|
|
@@ -1,7 +0,0 @@
|
|
|
1
|
-
titans_pytorch/__init__.py,sha256=QKuJPCOJCdgtaPeKoHEkYkiQe65_LV9_8-cIMbBPU30,55
|
|
2
|
-
titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
|
|
3
|
-
titans_pytorch/titans.py,sha256=xty74Q3xQ174uycscfOnh-zgxGMH882lrIA_KGvxTUU,7802
|
|
4
|
-
titans_pytorch-0.0.1.dist-info/METADATA,sha256=HqR3VxpV5e-dPLLEbuOekC161-2r2WwKBCvK7E2MhAs,2968
|
|
5
|
-
titans_pytorch-0.0.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
6
|
-
titans_pytorch-0.0.1.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
|
7
|
-
titans_pytorch-0.0.1.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|