titans-pytorch 0.0.22__tar.gz → 0.0.23__tar.gz


@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.0.22
+Version: 0.0.23
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -1,6 +1,6 @@
 [project]
 name = "titans-pytorch"
-version = "0.0.22"
+version = "0.0.23"
 description = "Titans"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -100,11 +100,11 @@ class MemoryMLP(Module):

 # main neural memory

-def default_adaptive_step_transform(adaptive_step, max_lr = 1e-1):
+def default_adaptive_step_transform(adaptive_step, max_lr = 1e-2):
     return adaptive_step.sigmoid() * max_lr

 def default_loss_fn(pred, target):
-    return (pred - target).pow(2).mean(dim = -1).sum()
+    return (pred - target).pow(2).mean(dim = -1)

 class NeuralMemory(Module):
     def __init__(
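A minimal sketch (not part of the diff; shapes assumed) of what the loss change does: the old default reduced to a scalar, while the new one keeps one loss per chunk position, which is what later allows a per-token weight to be applied before reduction.

    import torch

    pred   = torch.randn(16, 64)   # hypothetical (chunk_size, dim) prediction
    target = torch.randn(16, 64)

    old_loss = (pred - target).pow(2).mean(dim = -1).sum()   # scalar
    new_loss = (pred - target).pow(2).mean(dim = -1)         # shape (16,), one loss per position

    assert torch.allclose(old_loss, new_loss.sum())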
@@ -142,6 +142,12 @@ class NeuralMemory(Module):
         self.merge_heads = Rearrange('(b h) n d -> b n (h d)', h = heads)
         self.combine_heads = LinearNoBias(dim_inner, dim) if heads > 1 else nn.Identity()

+        self.retrieve_gate = nn.Sequential(
+            LinearNoBias(dim, heads),
+            Rearrange('b n h -> (b h) n 1'),
+            nn.Sigmoid()
+        ) if heads > 1 else None
+
         # memory mlp

         if not exists(model):
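A shape-only sketch of the new retrieve gate, with assumed dims and nn.Linear standing in for LinearNoBias: it produces one sigmoid gate per head and per token, with a trailing singleton dim ready to broadcast over the retrieved value dimension.

    import torch
    from torch import nn
    from einops.layers.torch import Rearrange

    batch, seq_len, dim, heads = 2, 8, 64, 4

    retrieve_gate = nn.Sequential(
        nn.Linear(dim, heads, bias = False),   # stand-in for LinearNoBias
        Rearrange('b n h -> (b h) n 1'),
        nn.Sigmoid()
    )

    gate = retrieve_gate(torch.randn(batch, seq_len, dim))
    assert gate.shape == (batch * heads, seq_len, 1)
    assert (gate >= 0).all() and (gate <= 1).all()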
@@ -159,12 +165,13 @@ class NeuralMemory(Module):

         # prepare function for per sample gradients from model above, using torch.func

-        def forward_and_loss(params, inputs, target):
+        def forward_and_loss(params, inputs, loss_weights, target):
             pred = functional_call(self.memory_model, params, inputs)
             loss = self.store_memory_loss_fn(pred, target) # simple mse loss in paper - eq (12) - |M(k) - v|²
-            return loss
+            loss = loss * loss_weights
+            return loss.sum()

-        self.per_sample_grad_fn = vmap(grad(forward_and_loss), in_dims = (None, 0, 0))
+        self.per_sample_grad_fn = vmap(grad(forward_and_loss), in_dims = (None, 0, 0, 0))

         # queries for retrieving from the model
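A minimal, self-contained sketch of the new per-sample-gradient signature using torch.func, with a plain Linear as a stand-in memory model and assumed shapes: the extra loss_weights input (one weight per chunk position) scales the per-token loss before the sum that grad requires.

    import torch
    from torch.func import functional_call, grad, vmap

    model = torch.nn.Linear(4, 4, bias = False)
    params = dict(model.named_parameters())

    def forward_and_loss(params, inputs, loss_weights, target):
        pred = functional_call(model, params, inputs)
        loss = (pred - target).pow(2).mean(dim = -1)   # per-token loss, shape (chunk_size,)
        return (loss * loss_weights).sum()             # weighted scalar for grad

    per_sample_grad_fn = vmap(grad(forward_and_loss), in_dims = (None, 0, 0, 0))

    chunks, chunk_size = 3, 5
    keys        = torch.randn(chunks, chunk_size, 4)
    values      = torch.randn(chunks, chunk_size, 4)
    adaptive_lr = torch.rand(chunks, chunk_size)

    grads = per_sample_grad_fn(params, keys, adaptive_lr, values)
    assert grads['weight'].shape == (chunks, 4, 4)   # one gradient per chunk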
@@ -190,7 +197,6 @@ class NeuralMemory(Module):
         )

         self.to_adaptive_step = nn.Sequential(
-            Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size),
             LinearNoBias(dim, heads),
             Rearrange('b n h -> (b h) n')
         )
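A shape-only sketch (assumed dims, nn.Linear in place of LinearNoBias) of what dropping the Reduce means: the adaptive step is no longer averaged within each chunk, so the module now emits one step size per token rather than one per chunk.

    import torch
    from torch import nn
    from einops.layers.torch import Rearrange

    batch, seq_len, dim, heads = 2, 16, 64, 4

    to_adaptive_step = nn.Sequential(
        nn.Linear(dim, heads, bias = False),
        Rearrange('b n h -> (b h) n')
    )

    adaptive_step = to_adaptive_step(torch.randn(batch, seq_len, dim))
    assert adaptive_step.shape == (batch * heads, seq_len)   # per token, not per chunk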
@@ -271,9 +277,11 @@ class NeuralMemory(Module):

         keys, values = tuple(rearrange(t, 'b (n c) d -> (b n) c d', c = self.chunk_size) for t in (keys, values))

+        adaptive_lr = rearrange(adaptive_lr, 'b (n c) -> (b n) c', c = self.chunk_size)
+
         # get grads and extra auxiliary loss (for backwarding through qkv projection in base neural memory module)

-        grads = self.per_sample_grad_fn(dict(curr_weights), keys, values)
+        grads = self.per_sample_grad_fn(dict(curr_weights), keys, adaptive_lr, values)

         grads = TensorDict(grads)
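A minimal sketch (assumed dims) of the new rearrange: the per-token learning rates are folded into the same (batch * num_chunks, chunk_size) layout as keys and values, so vmap can pair each chunk with its own weights.

    import torch
    from einops import rearrange

    batch_heads, seq_len, chunk_size = 8, 16, 4

    adaptive_lr = torch.rand(batch_heads, seq_len)   # per-token lr from to_adaptive_step
    adaptive_lr = rearrange(adaptive_lr, 'b (n c) -> (b n) c', c = chunk_size)

    assert adaptive_lr.shape == (batch_heads * (seq_len // chunk_size), chunk_size)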
@@ -286,9 +294,9 @@ class NeuralMemory(Module):

         grads = grads.apply(lambda t: rearrange(t, '(b n) ... -> b n ...', b = batch))

-        # multiply gradients with learned adaptive step size
+        # negative gradients, adaptive lr already applied as loss weight

-        surprises = grads.apply(lambda t: einx.multiply('b n ..., b n -> b n ...', t, -adaptive_lr))
+        surprises = grads.apply(lambda t: -t)

         # determine scan function
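A small numeric check (not from the diff) of why the explicit einx.multiply could be dropped: for a weight w, the gradient of the weighted loss w * l equals w times the gradient of l, so folding the adaptive lr into the loss (now per token, before the sum) subsumes the old per-chunk multiply, leaving only the negation.

    import torch

    w = 0.3
    x = torch.randn(4, requires_grad = True)
    g_plain, = torch.autograd.grad(x.pow(2).sum(), x)

    x2 = x.detach().clone().requires_grad_()
    g_weighted, = torch.autograd.grad((w * x2.pow(2)).sum(), x2)

    assert torch.allclose(g_weighted, w * g_plain)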
@@ -405,6 +413,11 @@ class NeuralMemory(Module):

         values = rearrange(values, '(b n) c d -> b (n c) d', b = batch)

+        # maybe gate
+
+        if exists(self.retrieve_gate):
+            values = values * self.retrieve_gate(seq)
+
         # maybe merge heads and combine

         values = self.merge_heads(values)
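A minimal sketch of the gating step above (names mirror the diff, dims assumed, nn.Linear for LinearNoBias): retrieved values carry heads folded into batch as (b h) n d, so the (b h) n 1 gate scales each head's tokens independently before the heads are merged back.

    import torch
    from torch import nn
    from einops.layers.torch import Rearrange

    batch, seq_len, dim, heads, dim_head = 2, 8, 64, 4, 16

    retrieve_gate = nn.Sequential(
        nn.Linear(dim, heads, bias = False),
        Rearrange('b n h -> (b h) n 1'),
        nn.Sigmoid()
    )

    seq    = torch.randn(batch, seq_len, dim)               # original input sequence
    values = torch.randn(batch * heads, seq_len, dim_head)  # retrieved memories

    values = values * retrieve_gate(seq)                    # per-head, per-token gate
    assert values.shape == (batch * heads, seq_len, dim_head)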
@@ -27,7 +27,7 @@ LEARNING_RATE = 2e-4
 VALIDATE_EVERY = 100
 GENERATE_EVERY = 500
 GENERATE_LENGTH = 512
-SHOULD_GENERATE = False
+SHOULD_GENERATE = True
 SEQ_LEN = 512

 PROJECT_NAME = 'titans-neural-memory'