titans-pytorch 0.0.21__tar.gz → 0.0.23__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.0.21
+Version: 0.0.23
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -1,6 +1,6 @@
 [project]
 name = "titans-pytorch"
-version = "0.0.21"
+version = "0.0.23"
 description = "Titans"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -2,8 +2,10 @@ import torch
 import pytest
 
 @pytest.mark.parametrize('seq_len', (32, 1024, 77))
+@pytest.mark.parametrize('max_grad_norm', (None, 2.))
 def test_titans(
-    seq_len
+    seq_len,
+    max_grad_norm
 ):
 
     from titans_pytorch import NeuralMemory
@@ -11,6 +13,7 @@ def test_titans(
     mem = NeuralMemory(
         dim = 384,
         chunk_size = 64,
+        max_grad_norm = max_grad_norm
     )
 
     seq = torch.randn(2, seq_len, 384)
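The test now also parametrizes max_grad_norm and passes it straight through to the NeuralMemory constructor. A minimal sketch of that call pattern, assuming the retrieval call and shape assertion that the truncated hunk implies:

    import torch
    from titans_pytorch import NeuralMemory

    # gradient norms on the surprise updates are clipped at 2.; None leaves them unclipped
    mem = NeuralMemory(
        dim = 384,
        chunk_size = 64,
        max_grad_norm = 2.
    )

    seq = torch.randn(2, 1024, 384)
    retrieved = mem(seq)                  # assumed call signature, cut off in the hunk above
    assert retrieved.shape == seq.shape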
@@ -100,11 +100,11 @@ class MemoryMLP(Module):
 
 # main neural memory
 
-def default_adaptive_step_transform(adaptive_step):
-    return torch.exp(adaptive_step.sigmoid() * -15) # from 1. - 1e-7
+def default_adaptive_step_transform(adaptive_step, max_lr = 1e-2):
+    return adaptive_step.sigmoid() * max_lr
 
 def default_loss_fn(pred, target):
-    return (pred - target).pow(2).mean(dim = -1).sum()
+    return (pred - target).pow(2).mean(dim = -1)
 
 class NeuralMemory(Module):
     def __init__(
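The default adaptive step transform changes from an exponential squashing, exp(sigmoid(x) * -15) with a range of roughly 3e-7 to 1, to a plain sigmoid scaled by max_lr = 1e-2, so the learned per-token step size is now bounded between 0 and 0.01. The default loss also stops summing over tokens and returns a per-token MSE, which is what allows the adaptive lr to be applied as a loss weight further down. A quick numeric comparison of the two transforms, purely illustrative:

    import torch

    x = torch.linspace(-6., 6., 5)

    old = torch.exp(x.sigmoid() * -15)   # ~0.96 down to ~3e-7, skewed toward vanishing steps
    new = x.sigmoid() * 1e-2             # smoothly bounded between 0 and 0.01

    print(old, new, sep = '\n')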
@@ -142,6 +142,12 @@ class NeuralMemory(Module):
         self.merge_heads = Rearrange('(b h) n d -> b n (h d)', h = heads)
         self.combine_heads = LinearNoBias(dim_inner, dim) if heads > 1 else nn.Identity()
 
+        self.retrieve_gate = nn.Sequential(
+            LinearNoBias(dim, heads),
+            Rearrange('b n h -> (b h) n 1'),
+            nn.Sigmoid()
+        ) if heads > 1 else None
+
         # memory mlp
 
         if not exists(model):
@@ -159,12 +165,13 @@ class NeuralMemory(Module):
 
         # prepare function for per sample gradients from model above, using torch.func
 
-        def forward_and_loss(params, inputs, target):
+        def forward_and_loss(params, inputs, loss_weights, target):
             pred = functional_call(self.memory_model, params, inputs)
             loss = self.store_memory_loss_fn(pred, target) # simple mse loss in paper - eq (12) - |M(k) - v|²
-            return loss
+            loss = loss * loss_weights
+            return loss.sum()
 
-        self.per_sample_grad_fn = vmap(grad(forward_and_loss), in_dims = (None, 0, 0))
+        self.per_sample_grad_fn = vmap(grad(forward_and_loss), in_dims = (None, 0, 0, 0))
 
         # queries for retrieving from the model
 
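forward_and_loss now takes the adaptive learning rates as a loss_weights argument and weights the per-token loss before reducing, with the vmap in_dims gaining a matching 0 so each chunk gets its own weights. A self-contained sketch of the same torch.func pattern on a toy linear memory model; the model, sizes, and names here are illustrative, not the library's:

    import torch
    from torch import nn
    from torch.func import functional_call, vmap, grad

    model = nn.Linear(16, 16, bias = False)
    params = dict(model.named_parameters())

    def forward_and_loss(params, inputs, loss_weights, target):
        pred = functional_call(model, params, inputs)
        loss = (pred - target).pow(2).mean(dim = -1)   # per-token loss, shape (chunk_size,)
        return (loss * loss_weights).sum()             # weight each token before reducing

    # one gradient per chunk: params are broadcast, the remaining args are batched over dim 0
    per_sample_grad_fn = vmap(grad(forward_and_loss), in_dims = (None, 0, 0, 0))

    chunks, chunk_size, dim = 8, 4, 16
    keys    = torch.randn(chunks, chunk_size, dim)
    values  = torch.randn(chunks, chunk_size, dim)
    lrs     = torch.rand(chunks, chunk_size)           # adaptive per-token learning rates

    grads = per_sample_grad_fn(params, keys, lrs, values)
    print(grads['weight'].shape)                       # torch.Size([8, 16, 16])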
@@ -190,7 +197,6 @@ class NeuralMemory(Module):
         )
 
         self.to_adaptive_step = nn.Sequential(
-            Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size),
            LinearNoBias(dim, heads),
            Rearrange('b n h -> (b h) n')
        )
@@ -271,9 +277,11 @@ class NeuralMemory(Module):
 
         keys, values = tuple(rearrange(t, 'b (n c) d -> (b n) c d', c = self.chunk_size) for t in (keys, values))
 
+        adaptive_lr = rearrange(adaptive_lr, 'b (n c) -> (b n) c', c = self.chunk_size)
+
         # get grads and extra auxiliary loss (for backwarding through qkv projection in base neural memory module)
 
-        grads = self.per_sample_grad_fn(dict(curr_weights), keys, values)
+        grads = self.per_sample_grad_fn(dict(curr_weights), keys, adaptive_lr, values)
 
         grads = TensorDict(grads)
 
@@ -286,9 +294,9 @@ class NeuralMemory(Module):
 
         grads = grads.apply(lambda t: rearrange(t, '(b n) ... -> b n ...', b = batch))
 
-        # multiply gradients with learned adaptive step size
+        # negative gradients, adaptive lr already applied as loss weight
 
-        surprises = grads.apply(lambda t: einx.multiply('b n ..., b n -> b n ...', t, -adaptive_lr))
+        surprises = grads.apply(lambda t: -t)
 
         # determine scan function
 
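Dropping the einx multiply is not a behavioral change: differentiation is linear, so weighting the per-token loss by the adaptive lr inside forward_and_loss yields the same gradients as scaling the unweighted gradients afterwards, and the surprises reduce to plain negated grads. A tiny sanity check of that equivalence, purely illustrative:

    import torch

    w = torch.randn(3, requires_grad = True)
    x, lr = torch.randn(3), torch.rand(3)

    # weight the per-element loss before reducing ...
    ((w * x).pow(2) * lr).sum().backward()
    g_from_weighted_loss = w.grad.clone(); w.grad = None

    # ... versus scaling the gradient of the unweighted loss
    (w * x).pow(2).sum().backward()
    assert torch.allclose(g_from_weighted_loss, w.grad * lr)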
@@ -405,6 +413,11 @@ class NeuralMemory(Module):
 
         values = rearrange(values, '(b n) c d -> b (n c) d', b = batch)
 
+        # maybe gate
+
+        if exists(self.retrieve_gate):
+            values = values * self.retrieve_gate(seq)
+
         # maybe merge heads and combine
 
         values = self.merge_heads(values)
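At retrieval, the new per-head gate maps the input sequence to one sigmoid scalar per head and token, reshaped so it broadcasts against the head-folded batch of retrieved values. A shape-only sketch of how the gate lines up with what it multiplies (nn.Linear without bias stands in for LinearNoBias; the dimensions are illustrative):

    import torch
    from torch import nn
    from einops.layers.torch import Rearrange

    batch, seq_len, dim, heads, dim_head = 2, 64, 384, 4, 96

    retrieve_gate = nn.Sequential(
        nn.Linear(dim, heads, bias = False),
        Rearrange('b n h -> (b h) n 1'),
        nn.Sigmoid()
    )

    seq    = torch.randn(batch, seq_len, dim)
    values = torch.randn(batch * heads, seq_len, dim_head)   # retrieved memories, heads folded into batch

    gated = values * retrieve_gate(seq)                      # gate broadcasts over dim_head
    print(gated.shape)                                       # torch.Size([8, 64, 96])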
@@ -27,7 +27,7 @@ LEARNING_RATE = 2e-4
 VALIDATE_EVERY = 100
 GENERATE_EVERY = 500
 GENERATE_LENGTH = 512
-SHOULD_GENERATE = False
+SHOULD_GENERATE = True
 SEQ_LEN = 512
 
 PROJECT_NAME = 'titans-neural-memory'
@@ -67,6 +67,7 @@ titans_neural_memory = NeuralMemory(
     post_rmsnorm = True,
     dim_head = 64,
     heads = 4,
+    max_grad_norm = 1.,
     use_accelerated_scan = True,
     default_mlp_kwargs = dict(
         depth = NEURAL_MEMORY_DEPTH