titans-pytorch 0.0.22__py3-none-any.whl → 0.0.23__py3-none-any.whl

titans_pytorch/titans.py CHANGED
@@ -100,11 +100,11 @@ class MemoryMLP(Module):
 
 # main neural memory
 
-def default_adaptive_step_transform(adaptive_step, max_lr = 1e-1):
+def default_adaptive_step_transform(adaptive_step, max_lr = 1e-2):
    return adaptive_step.sigmoid() * max_lr
 
 def default_loss_fn(pred, target):
-    return (pred - target).pow(2).mean(dim = -1).sum()
+    return (pred - target).pow(2).mean(dim = -1)
 
 class NeuralMemory(Module):
     def __init__(
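
Two defaults change here: max_lr for the adaptive step transform drops from 1e-1 to 1e-2, and default_loss_fn no longer reduces to a scalar. The second change is what enables the loss-weighting hunks below: the loss now comes back per position, so a per-position learning rate can be multiplied in before the final sum. A minimal sketch of the new shape, with illustrative dimensions:

    import torch

    def default_loss_fn(pred, target):
        # per-position MSE: mean over the feature dim only, no trailing .sum()
        return (pred - target).pow(2).mean(dim = -1)

    pred, target = torch.randn(16, 64), torch.randn(16, 64)  # assumed (chunk_size, dim)
    loss = default_loss_fn(pred, target)
    assert loss.shape == (16,)  # one loss per position, ready to be weighted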
@@ -142,6 +142,12 @@ class NeuralMemory(Module):
         self.merge_heads = Rearrange('(b h) n d -> b n (h d)', h = heads)
         self.combine_heads = LinearNoBias(dim_inner, dim) if heads > 1 else nn.Identity()
 
+        self.retrieve_gate = nn.Sequential(
+            LinearNoBias(dim, heads),
+            Rearrange('b n h -> (b h) n 1'),
+            nn.Sigmoid()
+        ) if heads > 1 else None
+
         # memory mlp
 
         if not exists(model):
@@ -159,12 +165,13 @@ class NeuralMemory(Module):
 
         # prepare function for per sample gradients from model above, using torch.func
 
-        def forward_and_loss(params, inputs, target):
+        def forward_and_loss(params, inputs, loss_weights, target):
             pred = functional_call(self.memory_model, params, inputs)
             loss = self.store_memory_loss_fn(pred, target) # simple mse loss in paper - eq (12) - |M(k) - v|²
-            return loss
+            loss = loss * loss_weights
+            return loss.sum()
 
-        self.per_sample_grad_fn = vmap(grad(forward_and_loss), in_dims = (None, 0, 0))
+        self.per_sample_grad_fn = vmap(grad(forward_and_loss), in_dims = (None, 0, 0, 0))
 
         # queries for retrieving from the model
 
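The new loss_weights argument folds the adaptive learning rate into the loss before grad is taken, and the extra 0 in in_dims maps it over the same leading batch dimension as inputs and target. A self-contained sketch of this setup, using a toy nn.Linear in place of the library's memory model (model and shapes are stand-ins, not from the source):

    import torch
    from torch import nn
    from torch.func import functional_call, grad, vmap

    model = nn.Linear(4, 4, bias = False)
    params = dict(model.named_parameters())

    def forward_and_loss(params, inputs, loss_weights, target):
        pred = functional_call(model, params, inputs)
        loss = (pred - target).pow(2).mean(dim = -1)  # per-position loss, shape (chunk_size,)
        return (loss * loss_weights).sum()            # weight by adaptive lr, then reduce

    # grad differentiates w.r.t. params (argument 0); vmap maps the leading
    # (batch * chunks) dim of inputs, loss_weights and target, sharing params
    per_sample_grad_fn = vmap(grad(forward_and_loss), in_dims = (None, 0, 0, 0))

    inputs  = torch.randn(8, 16, 4)  # illustrative (batch * chunks, chunk_size, dim)
    target  = torch.randn(8, 16, 4)
    weights = torch.rand(8, 16)      # one adaptive lr per position
    grads = per_sample_grad_fn(params, inputs, weights, target)  # dict of (8, *param.shape) tensors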
@@ -190,7 +197,6 @@ class NeuralMemory(Module):
         )
 
         self.to_adaptive_step = nn.Sequential(
-            Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size),
             LinearNoBias(dim, heads),
             Rearrange('b n h -> (b h) n')
         )
@@ -271,9 +277,11 @@ class NeuralMemory(Module):
 
         keys, values = tuple(rearrange(t, 'b (n c) d -> (b n) c d', c = self.chunk_size) for t in (keys, values))
 
+        adaptive_lr = rearrange(adaptive_lr, 'b (n c) -> (b n) c', c = self.chunk_size)
+
         # get grads and extra auxiliary loss (for backwarding through qkv projection in base neural memory module)
 
-        grads = self.per_sample_grad_fn(dict(curr_weights), keys, values)
+        grads = self.per_sample_grad_fn(dict(curr_weights), keys, adaptive_lr, values)
 
         grads = TensorDict(grads)
 
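With the Reduce dropped from to_adaptive_step above, the step size is now predicted per token rather than averaged per chunk, and this rearrange regroups the per-token values to match the (batch * chunks, chunk_size) layout of keys and values. A shape-only sketch with illustrative sizes:

    import torch
    from einops import rearrange

    batch_heads, chunks, chunk_size = 2, 4, 16  # illustrative sizes
    adaptive_lr = torch.rand(batch_heads, chunks * chunk_size)  # '(b h) n': one lr per token

    adaptive_lr = rearrange(adaptive_lr, 'b (n c) -> (b n) c', c = chunk_size)
    assert adaptive_lr.shape == (batch_heads * chunks, chunk_size)  # matches the keys/values chunking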
@@ -286,9 +294,9 @@ class NeuralMemory(Module):
 
         grads = grads.apply(lambda t: rearrange(t, '(b n) ... -> b n ...', b = batch))
 
-        # multiply gradients with learned adaptive step size
+        # negative gradients, adaptive lr already applied as loss weight
 
-        surprises = grads.apply(lambda t: einx.multiply('b n ..., b n -> b n ...', t, -adaptive_lr))
+        surprises = grads.apply(lambda t: -t)
 
         # determine scan function
 
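Simply negating the raw gradient is equivalent to the old einx.multiply with -adaptive_lr because the lr enters the loss linearly: the gradient of lr * loss is lr times the gradient of loss. A scalar sanity check of that identity:

    import torch

    w = torch.tensor(2.0, requires_grad = True)
    x, lr = torch.tensor(3.0), 0.5

    loss = lr * (w * x).pow(2)  # lr folded into the loss before backward ...
    loss.backward()
    assert torch.allclose(w.grad, lr * (2 * w * x * x))  # ... scales the gradient by lr

The practical difference is granularity: the weight can now differ per position within a chunk, whereas the old path (with the Reduce in to_adaptive_step) applied a single averaged lr per chunk.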
@@ -405,6 +413,11 @@ class NeuralMemory(Module):
 
         values = rearrange(values, '(b n) c d -> b (n c) d', b = batch)
 
+        # maybe gate
+
+        if exists(self.retrieve_gate):
+            values = values * self.retrieve_gate(seq)
+
         # maybe merge heads and combine
 
         values = self.merge_heads(values)
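
The gate built in __init__ (only when heads > 1) produces one sigmoid scalar per head and position from the raw input sequence, which broadcasts over the feature dimension of the retrieved values before heads are merged. A minimal standalone sketch (nn.Linear stands in for LinearNoBias; sizes are illustrative):

    import torch
    from torch import nn
    from einops.layers.torch import Rearrange

    batch, seq_len, dim, heads = 2, 16, 64, 4

    retrieve_gate = nn.Sequential(
        nn.Linear(dim, heads, bias = False),  # stand-in for LinearNoBias
        Rearrange('b n h -> (b h) n 1'),      # one gate scalar per head and position
        nn.Sigmoid()
    )

    seq = torch.randn(batch, seq_len, dim)                      # raw input sequence
    values = torch.randn(batch * heads, seq_len, dim // heads)  # retrieved memories, heads folded into batch
    gated = values * retrieve_gate(seq)                         # broadcasts over the feature dim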
{titans_pytorch-0.0.22.dist-info → titans_pytorch-0.0.23.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.0.22
+Version: 0.0.23
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
titans_pytorch-0.0.23.dist-info/RECORD ADDED
@@ -0,0 +1,8 @@
+titans_pytorch/__init__.py,sha256=nB0873FZ_OyCda3qFeWTdpO4LbbrXDEvtAefVLzh6g0,71
+titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+titans_pytorch/titans.py,sha256=AU2mf3RkClSRIP0IUwnnqsA5O1udNYGbTRb0lVBLA78,14024
+titans_pytorch/titans_attn_memory.py,sha256=Rwx_-riGeISBefZg5Kjic8jzmmRRys-u93D2Kgb7Mos,12691
+titans_pytorch-0.0.23.dist-info/METADATA,sha256=U80_8U_mwaQqwWKPWlu76-O3-CTCfrZ7t_HdE3Zl_qE,3811
+titans_pytorch-0.0.23.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.0.23.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.0.23.dist-info/RECORD,,
titans_pytorch-0.0.22.dist-info/RECORD DELETED
@@ -1,8 +0,0 @@
-titans_pytorch/__init__.py,sha256=nB0873FZ_OyCda3qFeWTdpO4LbbrXDEvtAefVLzh6g0,71
-titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/titans.py,sha256=s7KVqId7hhHw1Ck77FJUXfKC5rpwS4N7Nw2mKFtP8-s,13677
-titans_pytorch/titans_attn_memory.py,sha256=Rwx_-riGeISBefZg5Kjic8jzmmRRys-u93D2Kgb7Mos,12691
-titans_pytorch-0.0.22.dist-info/METADATA,sha256=acTS0vWh84sGtTIEBzdTsxNi-S4V-Cr8NL4U4Vg0eY0,3811
-titans_pytorch-0.0.22.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.0.22.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.0.22.dist-info/RECORD,,