titans-pytorch 0.0.2__tar.gz → 0.0.4__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {titans_pytorch-0.0.2 → titans_pytorch-0.0.4}/PKG-INFO +2 -2
- {titans_pytorch-0.0.2 → titans_pytorch-0.0.4}/README.md +1 -1
- {titans_pytorch-0.0.2 → titans_pytorch-0.0.4}/pyproject.toml +1 -1
- {titans_pytorch-0.0.2 → titans_pytorch-0.0.4}/titans_pytorch/titans.py +46 -21
- {titans_pytorch-0.0.2 → titans_pytorch-0.0.4}/.github/workflows/python-publish.yml +0 -0
- {titans_pytorch-0.0.2 → titans_pytorch-0.0.4}/.gitignore +0 -0
- {titans_pytorch-0.0.2 → titans_pytorch-0.0.4}/LICENSE +0 -0
- {titans_pytorch-0.0.2 → titans_pytorch-0.0.4}/fig1.png +0 -0
- {titans_pytorch-0.0.2 → titans_pytorch-0.0.4}/fig2.png +0 -0
- {titans_pytorch-0.0.2 → titans_pytorch-0.0.4}/titans_pytorch/__init__.py +0 -0
- {titans_pytorch-0.0.2 → titans_pytorch-0.0.4}/titans_pytorch/associative_scan.py +0 -0
--- titans_pytorch-0.0.2/PKG-INFO
+++ titans_pytorch-0.0.4/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.0.2
+Version: 0.0.4
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -49,7 +49,7 @@ Description-Content-Type: text/markdown

 ## Titans - Pytorch (wip)

-Unofficial implementation of [Titans](https://arxiv.org/abs/2501.00663) in Pytorch. Will also contain some explorations into architectures beyond their simple 1-4 layer MLP for the neural memory module.
+Unofficial implementation of [Titans](https://arxiv.org/abs/2501.00663) in Pytorch. Will also contain some explorations into architectures beyond their simple 1-4 layer MLP for the neural memory module, if it works well to any degree.

 ## Install

--- titans_pytorch-0.0.2/README.md
+++ titans_pytorch-0.0.4/README.md
@@ -4,7 +4,7 @@

 ## Titans - Pytorch (wip)

-Unofficial implementation of [Titans](https://arxiv.org/abs/2501.00663) in Pytorch. Will also contain some explorations into architectures beyond their simple 1-4 layer MLP for the neural memory module.
+Unofficial implementation of [Titans](https://arxiv.org/abs/2501.00663) in Pytorch. Will also contain some explorations into architectures beyond their simple 1-4 layer MLP for the neural memory module, if it works well to any degree.

 ## Install

--- titans_pytorch-0.0.2/titans_pytorch/titans.py
+++ titans_pytorch-0.0.4/titans_pytorch/titans.py
@@ -16,7 +16,7 @@ from titans_pytorch.associative_scan import (

 import einx
 from einops import rearrange, pack, unpack
-from einops.layers.torch import Rearrange
+from einops.layers.torch import Rearrange, Reduce

 """
 ein notation:
@@ -84,6 +84,7 @@ class NeuralMemory(Module):
     def __init__(
         self,
         dim,
+        chunk_size = 1,
         model: Module | None = None,
         store_memory_loss_fn: Callable = default_loss_fn
     ):
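A quick construction sketch of the new argument. Hedged: the top-level import and the `dim` value are assumptions for illustration, not taken from this diff; only `chunk_size` is the new parameter in this release, and the default memory model is whatever `NeuralMemory` builds when `model` is omitted.

```python
# minimal sketch, assuming NeuralMemory is exported at the package root;
# dim = 384 is an illustrative value
from titans_pytorch import NeuralMemory

mem = NeuralMemory(
    dim = 384,
    chunk_size = 2  # adaptive step, momentum and weight decay are shared within each chunk of 2 tokens
)

assert mem.chunk_size == 2
```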
@@ -98,11 +99,15 @@ class NeuralMemory(Module):

         self.memory_model = model

+        # the chunk size within the paper where adaptive step, momentum, weight decay are shared
+
+        self.chunk_size = chunk_size
+
         # prepare function for per sample gradients from model above, using torch.func

         def forward_and_loss(params, inputs, target):
             pred = functional_call(self.memory_model, params, inputs)
-            loss = self.store_memory_loss_fn(pred, target) # simple mse loss in paper - eq (12) - |M(k)
+            loss = self.store_memory_loss_fn(pred, target) # simple mse loss in paper - eq (12) - |M(k) - v|²
             return loss

         self.per_sample_grad_and_value_fn = vmap(grad_and_value(forward_and_loss), in_dims = (None, 0, 0))
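The `vmap(grad_and_value(...), in_dims = (None, 0, 0))` line is the standard `torch.func` per-sample-gradient recipe: one parameter dict is broadcast over a batch of (input, target) pairs, yielding a separate gradient and loss per sample. A self-contained sketch of the same pattern, with a toy linear memory model and made-up shapes rather than the package's defaults:

```python
# illustrative sketch of per-sample gradients via torch.func,
# mirroring the pattern in NeuralMemory.__init__ (toy model, made-up shapes)
import torch
from torch import nn
from torch.func import functional_call, grad_and_value, vmap

memory_model = nn.Linear(16, 16)

def forward_and_loss(params, inputs, target):
    pred = functional_call(memory_model, params, inputs)
    return ((pred - target) ** 2).mean()  # simple mse, as in eq (12)

per_sample_grad_and_value_fn = vmap(grad_and_value(forward_and_loss), in_dims = (None, 0, 0))

params = dict(memory_model.named_parameters())
keys = torch.randn(8, 4, 16)    # 8 "samples", each a chunk of 4 tokens
values = torch.randn(8, 4, 16)

grads, losses = per_sample_grad_and_value_fn(params, keys, values)
print(grads['weight'].shape)  # torch.Size([8, 16, 16]) - one gradient per sample
print(losses.shape)           # torch.Size([8])
```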
@@ -119,9 +124,23 @@ class NeuralMemory(Module):
         # learned adaptive learning rate and momentum
         # todo - explore mlp layerwise learned lr / momentum

-        self.to_momentum =
-
-
+        self.to_momentum = nn.Sequential(
+            Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size),
+            LinearNoBias(dim, 1)
+        )
+
+        self.to_adaptive_step = nn.Sequential(
+            Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size),
+            LinearNoBias(dim, 1),
+            Rearrange('... 1 -> ...')
+        )
+
+        # weight decay factor
+
+        self.to_decay_factor = nn.Sequential(
+            Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size),
+            LinearNoBias(dim, 1)
+        )

     def init_weights_and_momentum(self):
         params = TensorDict(dict(self.memory_model.named_parameters()))
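The `Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size)` layer in front of each projection is what makes the adaptive step, momentum and decay per-chunk rather than per-token: every `c` consecutive tokens are mean-pooled before the linear head. A standalone shape check with illustrative sizes (plain `nn.Linear(..., bias = False)` stands in for the package's `LinearNoBias`):

```python
# shape sketch of the per-chunk pooling used by to_momentum / to_adaptive_step / to_decay_factor
import torch
from torch import nn
from einops.layers.torch import Rearrange, Reduce

dim, chunk_size = 16, 4

to_adaptive_step = nn.Sequential(
    Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size),  # mean-pool each chunk of 4 tokens
    nn.Linear(dim, 1, bias = False),
    Rearrange('... 1 -> ...')
)

seq = torch.randn(2, 32, dim)        # (batch, seq len, dim)
print(to_adaptive_step(seq).shape)   # torch.Size([2, 8]) - one adaptive step per chunk
```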
@@ -137,6 +156,16 @@ class NeuralMemory(Module):
         past_state: tuple[dict[str, Tensor], dict[str, Tensor]]
     ):

+        # curtail sequence by multiple of the chunk size
+        # only a complete chunk of the sequence provides the memory for the next chunk
+
+        seq_len = seq.shape[-2]
+        round_down_seq_len = round_down_multiple(seq_len, self.chunk_size)
+
+        seq = seq[:, :round_down_seq_len]
+
+        # curr weights + past weights, in the case that the initial weights are learned
+
         curr_weights = TensorDict(dict(self.memory_model.named_parameters()))

         past_state = tuple(TensorDict(d) for d in past_state)
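`round_down_multiple` is referenced by the added code but not shown in this diff; a stand-in with the obvious semantics illustrates the curtailing, which simply drops any trailing partial chunk before the memory update:

```python
# hedged sketch of the curtailing step; round_down_multiple here is a stand-in
# with the obvious definition, not copied from the package
import torch

def round_down_multiple(seq_len, mult):
    return seq_len // mult * mult

chunk_size = 4
seq = torch.randn(1, 10, 16)     # 10 tokens, chunk size 4

round_down_seq_len = round_down_multiple(seq.shape[-2], chunk_size)
seq = seq[:, :round_down_seq_len]

print(seq.shape)                 # torch.Size([1, 8, 16]) - trailing partial chunk dropped
```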
@@ -148,16 +177,19 @@ class NeuralMemory(Module):

         batch = seq.shape[0]

-        adaptive_lr = self.to_adaptive_step(seq).tanh() * 0.5 +
+        adaptive_lr = self.to_adaptive_step(seq).tanh() * 0.5 + 0.5

         adaptive_momentum = self.to_momentum(seq).sigmoid()
         decay_factor = self.to_decay_factor(seq).sigmoid()

         # keys and values

-        seq = rearrange(seq, 'b n d -> (b n) d')
         keys, values = self.to_keys_values(seq).chunk(2, dim = -1)

+        # take care of chunking
+
+        keys, values = tuple(rearrange(t, 'b (n c) d -> (b n) c d', c = self.chunk_size) for t in (keys, values))
+
         # get grads and extra auxiliary loss (for backwarding through qkv projection in base neural memory module)

         grads, aux_store_loss = self.per_sample_grad_and_value_fn(dict(curr_weights), keys, values)
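Note the change of granularity for the per-sample gradients: 0.0.2 flattened the sequence so every token was its own sample, while 0.0.4 makes each chunk of `c` tokens one sample of shape `(c, d)`. A shape-only sketch with illustrative sizes:

```python
# shape sketch of the new chunked layout fed to per_sample_grad_and_value_fn
import torch
from einops import rearrange

b, seq_len, d, c = 2, 12, 16, 4   # illustrative sizes

keys = torch.randn(b, seq_len, d)
values = torch.randn(b, seq_len, d)

# 0.0.2: every token was a separate sample      -> (b * seq_len, d)
# 0.0.4: every chunk of c tokens is one sample  -> (b * seq_len / c, c, d)
keys, values = tuple(rearrange(t, 'b (n c) d -> (b n) c d', c = c) for t in (keys, values))

print(keys.shape)                 # torch.Size([6, 4, 16])
```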
@@ -172,31 +204,24 @@ class NeuralMemory(Module):

         surprises = grads.apply(lambda t: einx.multiply('b n ..., b n -> b n ...', t, -adaptive_lr))

-        #
+        # momentum + weight decay - momentum is the new contribution, as most linear RNNs have learned forgetting gates

         next_momentum = TensorDict()
+        updates = TensorDict()

         for param_name, surprise in surprises.items():
             surprise, inverse_pack = pack_one_with_inverse(surprise, 'b n *')

-
-
-            momentum = inverse_pack(momentum)
-
-            next_momentum[param_name] = momentum
+            # derive momentum with associative scan - eq (10)

-
-
-            updates = TensorDict()
+            _, momentum = associative_scan(binary_operator, (adaptive_momentum, surprise)) # momentum is S / surprise in the paper

-
-            momentum, inverse_pack = pack_one_with_inverse(momentum, 'b n *')
+            # use associative scan again for learned forgetting (weight decay) - eq (13)

             _, update = associative_scan(binary_operator, (1. - decay_factor, momentum)) # momentum is S / surprise in the paper

-
-
-            updates[param_name] = update
+            updates[param_name] = inverse_pack(update)
+            next_momentum[param_name] = inverse_pack(momentum)

         # compute the next weight per batch
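Both `associative_scan(binary_operator, ...)` calls realize a first-order linear recurrence: the first builds momentum out of the already learning-rate-scaled, negated gradients (the code's "surprise", eq (10)), and the second applies the learned forgetting / weight decay on top (eq (13)). A hedged sequential reference for what each scan computes; the package does the same thing with a parallel scan, and `linear_recurrence` below is not its implementation:

```python
# hedged sketch: the recurrence h_t = gate_t * h_{t-1} + input_t that both scans realize
import torch

def linear_recurrence(gates, inputs):
    # plain sequential reference, h_0 = 0
    h = torch.zeros_like(inputs[:, 0])
    out = []
    for t in range(inputs.shape[1]):
        h = gates[:, t] * h + inputs[:, t]
        out.append(h)
    return torch.stack(out, dim = 1)

b, n = 2, 8                              # batch, number of chunks (illustrative)
surprise = torch.randn(b, n)             # -adaptive_lr * grad, flattened per parameter
adaptive_momentum = torch.rand(b, n)     # sigmoid output in (0, 1)
decay_factor = torch.rand(b, n)          # sigmoid output in (0, 1)

momentum = linear_recurrence(adaptive_momentum, surprise)    # eq (10): S_t = eta_t * S_{t-1} + u_t
update = linear_recurrence(1. - decay_factor, momentum)      # eq (13): learned weight decay / forgetting
```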