titans-pytorch 0.0.22__tar.gz → 0.0.23__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {titans_pytorch-0.0.22 → titans_pytorch-0.0.23}/PKG-INFO +1 -1
- {titans_pytorch-0.0.22 → titans_pytorch-0.0.23}/pyproject.toml +1 -1
- {titans_pytorch-0.0.22 → titans_pytorch-0.0.23}/titans_pytorch/titans.py +22 -9
- {titans_pytorch-0.0.22 → titans_pytorch-0.0.23}/train.py +1 -1
- {titans_pytorch-0.0.22 → titans_pytorch-0.0.23}/.github/workflows/python-publish.yml +0 -0
- {titans_pytorch-0.0.22 → titans_pytorch-0.0.23}/.github/workflows/test.yaml +0 -0
- {titans_pytorch-0.0.22 → titans_pytorch-0.0.23}/.gitignore +0 -0
- {titans_pytorch-0.0.22 → titans_pytorch-0.0.23}/LICENSE +0 -0
- {titans_pytorch-0.0.22 → titans_pytorch-0.0.23}/README.md +0 -0
- {titans_pytorch-0.0.22 → titans_pytorch-0.0.23}/data/README.md +0 -0
- {titans_pytorch-0.0.22 → titans_pytorch-0.0.23}/data/enwik8.gz +0 -0
- {titans_pytorch-0.0.22 → titans_pytorch-0.0.23}/fig1.png +0 -0
- {titans_pytorch-0.0.22 → titans_pytorch-0.0.23}/fig2.png +0 -0
- {titans_pytorch-0.0.22 → titans_pytorch-0.0.23}/requirements.txt +0 -0
- {titans_pytorch-0.0.22 → titans_pytorch-0.0.23}/tests/test_titans.py +0 -0
- {titans_pytorch-0.0.22 → titans_pytorch-0.0.23}/titans_pytorch/__init__.py +0 -0
- {titans_pytorch-0.0.22 → titans_pytorch-0.0.23}/titans_pytorch/associative_scan.py +0 -0
- {titans_pytorch-0.0.22 → titans_pytorch-0.0.23}/titans_pytorch/titans_attn_memory.py +0 -0
titans_pytorch/titans.py
@@ -100,11 +100,11 @@ class MemoryMLP(Module):
 
 # main neural memory
 
-def default_adaptive_step_transform(adaptive_step, max_lr = 1e-
+def default_adaptive_step_transform(adaptive_step, max_lr = 1e-2):
     return adaptive_step.sigmoid() * max_lr
 
 def default_loss_fn(pred, target):
-    return (pred - target).pow(2).mean(dim = -1)
+    return (pred - target).pow(2).mean(dim = -1)
 
 class NeuralMemory(Module):
     def __init__(
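For context, the transform changed above bounds the learned step size with a sigmoid, so the per-token "learning rate" can never exceed max_lr. A standalone sketch (the input values are invented for illustration):

# Illustrative only: a sigmoid squashes an unbounded per-token projection into (0, max_lr).
import torch

def default_adaptive_step_transform(adaptive_step, max_lr = 1e-2):
    return adaptive_step.sigmoid() * max_lr

raw = torch.tensor([-4., 0., 4.])
print(default_adaptive_step_transform(raw))  # ~tensor([0.0002, 0.0050, 0.0098]), all below max_lr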
@@ -142,6 +142,12 @@ class NeuralMemory(Module):
         self.merge_heads = Rearrange('(b h) n d -> b n (h d)', h = heads)
         self.combine_heads = LinearNoBias(dim_inner, dim) if heads > 1 else nn.Identity()
 
+        self.retrieve_gate = nn.Sequential(
+            LinearNoBias(dim, heads),
+            Rearrange('b n h -> (b h) n 1'),
+            nn.Sigmoid()
+        ) if heads > 1 else None
+
         # memory mlp
 
         if not exists(model):
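To see what the new retrieve_gate computes, here is a shape-level sketch with made-up sizes (dim = 8, heads = 2, dim_head = 4); LinearNoBias is assumed to be the bias-free Linear alias used elsewhere in titans.py:

# Sketch of the retrieve_gate shapes (illustrative sizes, not the package's defaults).
import torch
from torch import nn
from einops.layers.torch import Rearrange

dim, heads, dim_head = 8, 2, 4
LinearNoBias = lambda dim_in, dim_out: nn.Linear(dim_in, dim_out, bias = False)  # assumed alias

retrieve_gate = nn.Sequential(
    LinearNoBias(dim, heads),          # one gate logit per head and token
    Rearrange('b n h -> (b h) n 1'),   # align with the (b h) n d layout of per-head values
    nn.Sigmoid()
)

seq = torch.randn(1, 6, dim)                 # (batch, seq, dim)
values = torch.randn(1 * heads, 6, dim_head) # per-head retrieved values in (b h) n d layout
gated = values * retrieve_gate(seq)          # gate broadcasts over the feature dimension
print(gated.shape)                           # torch.Size([2, 6, 4])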
@@ -159,12 +165,13 @@ class NeuralMemory(Module):
 
         # prepare function for per sample gradients from model above, using torch.func
 
-        def forward_and_loss(params, inputs, target):
+        def forward_and_loss(params, inputs, loss_weights, target):
             pred = functional_call(self.memory_model, params, inputs)
             loss = self.store_memory_loss_fn(pred, target) # simple mse loss in paper - eq (12) - |M(k) - v|²
-
+            loss = loss * loss_weights
+            return loss.sum()
 
-        self.per_sample_grad_fn = vmap(grad(forward_and_loss), in_dims = (None, 0, 0))
+        self.per_sample_grad_fn = vmap(grad(forward_and_loss), in_dims = (None, 0, 0, 0))
 
         # queries for retrieving from the model
 
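The signature change above feeds the adaptive learning rate into the per-sample gradient function as a loss weight. A self-contained toy version of that torch.func pattern (the Linear stand-in and tensor sizes are invented for illustration):

# Toy per-chunk gradients: params are shared (in_dims = None), the data axes are vmapped.
import torch
from torch import nn
from torch.func import functional_call, grad, vmap

model = nn.Linear(4, 4, bias = False)   # stand-in for the memory MLP
params = dict(model.named_parameters())

def forward_and_loss(params, inputs, loss_weights, target):
    pred = functional_call(model, params, inputs)
    loss = (pred - target).pow(2).mean(dim = -1)  # per-token loss, as in default_loss_fn
    loss = loss * loss_weights                    # adaptive lr folded in as a loss weight
    return loss.sum()

per_sample_grad_fn = vmap(grad(forward_and_loss), in_dims = (None, 0, 0, 0))

chunks, chunk_size, dim = 3, 2, 4
keys        = torch.randn(chunks, chunk_size, dim)
values      = torch.randn(chunks, chunk_size, dim)
adaptive_lr = torch.rand(chunks, chunk_size)      # one weight per token in each chunk

grads = per_sample_grad_fn(params, keys, adaptive_lr, values)
print(grads['weight'].shape)                      # torch.Size([3, 4, 4]): one gradient per chunk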
@@ -190,7 +197,6 @@ class NeuralMemory(Module):
         )
 
         self.to_adaptive_step = nn.Sequential(
-            Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size),
             LinearNoBias(dim, heads),
             Rearrange('b n h -> (b h) n')
         )
@@ -271,9 +277,11 @@ class NeuralMemory(Module):
 
         keys, values = tuple(rearrange(t, 'b (n c) d -> (b n) c d', c = self.chunk_size) for t in (keys, values))
 
+        adaptive_lr = rearrange(adaptive_lr, 'b (n c) -> (b n) c', c = self.chunk_size)
+
         # get grads and extra auxiliary loss (for backwarding through qkv projection in base neural memory module)
 
-        grads = self.per_sample_grad_fn(dict(curr_weights), keys, values)
+        grads = self.per_sample_grad_fn(dict(curr_weights), keys, adaptive_lr, values)
 
         grads = TensorDict(grads)
 
@@ -286,9 +294,9 @@ class NeuralMemory(Module):
 
         grads = grads.apply(lambda t: rearrange(t, '(b n) ... -> b n ...', b = batch))
 
-        #
+        # negative gradients, adaptive lr already applied as loss weight
 
-        surprises = grads.apply(lambda t:
+        surprises = grads.apply(lambda t: -t)
 
         # determine scan function
 
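The new comment leans on gradient linearity: scaling the loss by the adaptive lr scales its gradient by the same factor, so simply negating the raw gradients here reproduces the earlier behavior of multiplying negative gradients by the adaptive lr. A quick standalone check, not part of the package:

# d/dw (lr * loss(w)) == lr * d/dw loss(w) — why the adaptive lr can move into the loss.
import torch

lr = 0.3
w1 = torch.tensor(2.0, requires_grad = True)
(lr * (w1 - 1.0).pow(2)).backward()   # lr folded into the loss

w2 = torch.tensor(2.0, requires_grad = True)
(w2 - 1.0).pow(2).backward()          # plain loss, lr applied to the gradient afterwards

assert torch.allclose(w1.grad, lr * w2.grad)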
@@ -405,6 +413,11 @@ class NeuralMemory(Module):
 
         values = rearrange(values, '(b n) c d -> b (n c) d', b = batch)
 
+        # maybe gate
+
+        if exists(self.retrieve_gate):
+            values = values * self.retrieve_gate(seq)
+
         # maybe merge heads and combine
 
         values = self.merge_heads(values)
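For orientation, basic usage of the NeuralMemory module these hunks modify looks roughly like the package README example below; the argument values are illustrative, and the exact forward return should be checked against the installed version.

# Usage sketch; dim and chunk_size appear in the diff, the concrete values here are arbitrary.
import torch
from titans_pytorch import NeuralMemory

mem = NeuralMemory(
    dim = 384,
    chunk_size = 64
)

seq = torch.randn(2, 1024, 384)
retrieved = mem(seq)  # assumes forward returns the retrieved sequence, same shape as seq

assert retrieved.shape == seq.shape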