titans-pytorch 0.0.22__tar.gz → 0.0.24__tar.gz

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.0.22
+Version: 0.0.24
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch

@@ -1,6 +1,6 @@
 [project]
 name = "titans-pytorch"
-version = "0.0.22"
+version = "0.0.24"
 description = "Titans"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }

@@ -73,6 +73,17 @@ def softclamp_grad_norm(t, max_value):
     t = t * (clamped_norm / norm)
     return inverse(t)
 
+# multi head rmsnorm
+
+class MultiheadRMSNorm(Module):
+    def __init__(self, dim, heads):
+        super().__init__()
+        self.rmsnorm = nn.RMSNorm(dim, elementwise_affine = False)
+        self.gamma = nn.Parameter(torch.zeros(heads, 1, dim))
+
+    def forward(self, x):
+        return self.rmsnorm(x) * (self.gamma + 1.)
+
 # classes
 
 class MemoryMLP(Module):

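Note on the hunk above: the new MultiheadRMSNorm normalizes each head's retrieved values separately, with a zero-initialized gamma so that (gamma + 1.) starts as an identity scale. A minimal standalone sketch of the shapes involved (assumes PyTorch >= 2.4 for nn.RMSNorm; batch, head and sequence sizes are made up for illustration):

import torch
from torch import nn

# same module as in the hunk above: RMSNorm over the last dim without its own affine,
# followed by a learned per-head scale that starts at identity (zeros + 1)
class MultiheadRMSNorm(nn.Module):
    def __init__(self, dim, heads):
        super().__init__()
        self.rmsnorm = nn.RMSNorm(dim, elementwise_affine = False)
        self.gamma = nn.Parameter(torch.zeros(heads, 1, dim))

    def forward(self, x):
        # x: (batch, heads, seq, dim_head) - gamma (heads, 1, dim_head) broadcasts over batch and seq
        return self.rmsnorm(x) * (self.gamma + 1.)

norm = MultiheadRMSNorm(dim = 64, heads = 4)
out = norm(torch.randn(2, 4, 128, 64))    # -> (2, 4, 128, 64)
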
@@ -100,11 +111,11 @@ class MemoryMLP(Module):
 
 # main neural memory
 
-def default_adaptive_step_transform(adaptive_step, max_lr = 1e-1):
+def default_adaptive_step_transform(adaptive_step, max_lr = 1e-2):
     return adaptive_step.sigmoid() * max_lr
 
 def default_loss_fn(pred, target):
-    return (pred - target).pow(2).mean(dim = -1).sum()
+    return (pred - target).pow(2).mean(dim = -1)
 
 class NeuralMemory(Module):
     def __init__(

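Note on the two defaults changed above: the maximum adaptive step drops from 1e-1 to 1e-2, and the store loss now returns a per-token loss (the trailing .sum() is gone) so the adaptive learning rate can be applied as a loss weight before reduction, as the later forward_and_loss hunk shows. A small illustrative sketch of how the new defaults compose (tensor shapes are assumptions):

import torch

def default_adaptive_step_transform(adaptive_step, max_lr = 1e-2):
    return adaptive_step.sigmoid() * max_lr           # per-token step size in (0, max_lr)

def default_loss_fn(pred, target):
    return (pred - target).pow(2).mean(dim = -1)      # per-token mse, no .sum()

pred, target = torch.randn(2, 8, 16), torch.randn(2, 8, 16)
adaptive_lr = default_adaptive_step_transform(torch.randn(2, 8))

# the learned step size is folded in as a loss weight, then reduced to a scalar for grad()
weighted_loss = (default_loss_fn(pred, target) * adaptive_lr).sum()
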
@@ -131,17 +142,25 @@ class NeuralMemory(Module):
         self.retrieve_norm = nn.RMSNorm(dim) if pre_rmsnorm else nn.Identity()
         self.store_norm = nn.RMSNorm(dim) if pre_rmsnorm else nn.Identity()
 
-        self.post_rmsnorm = nn.RMSNorm(dim) if post_rmsnorm else nn.Identity()
+        self.multihead_rmsnorm = MultiheadRMSNorm(dim_head, heads) if post_rmsnorm else nn.Identity()
 
         # maybe multi-headed
 
         dim_head = default(dim_head, dim)
         dim_inner = dim_head * heads
 
+        self.heads = heads
+
         self.split_heads = Rearrange('b n (h d) -> (b h) n d', h = heads)
-        self.merge_heads = Rearrange('(b h) n d -> b n (h d)', h = heads)
+        self.merge_heads = Rearrange('b h n d -> b n (h d)')
         self.combine_heads = LinearNoBias(dim_inner, dim) if heads > 1 else nn.Identity()
 
+        self.retrieve_gate = nn.Sequential(
+            LinearNoBias(dim, heads),
+            Rearrange('b n h -> b h n 1'),
+            nn.Sigmoid()
+        ) if heads > 1 else None
+
         # memory mlp
 
         if not exists(model):

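Note on the hunk above: besides keeping heads on their own axis until merge_heads, this version adds a per-head retrieve gate, a sigmoid computed from the input sequence that scales each head's retrieved memories (only created when heads > 1). A shape-level sketch, using nn.Linear with bias = False as a stand-in for the repo's LinearNoBias:

import torch
from torch import nn
from einops.layers.torch import Rearrange

dim, heads, dim_head = 64, 4, 16

retrieve_gate = nn.Sequential(
    nn.Linear(dim, heads, bias = False),    # stand-in for LinearNoBias
    Rearrange('b n h -> b h n 1'),          # one gate per head and per token
    nn.Sigmoid()
)

seq    = torch.randn(2, 128, dim)                  # input sequence
values = torch.randn(2, heads, 128, dim_head)      # retrieved memories, per head

gated = values * retrieve_gate(seq)                # gate (2, heads, 128, 1) broadcasts over dim_head
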
@@ -159,12 +178,13 @@ class NeuralMemory(Module):
 
         # prepare function for per sample gradients from model above, using torch.func
 
-        def forward_and_loss(params, inputs, target):
+        def forward_and_loss(params, inputs, loss_weights, target):
             pred = functional_call(self.memory_model, params, inputs)
             loss = self.store_memory_loss_fn(pred, target) # simple mse loss in paper - eq (12) - |M(k) - v|²
-            return loss
+            loss = loss * loss_weights
+            return loss.sum()
 
-        self.per_sample_grad_fn = vmap(grad(forward_and_loss), in_dims = (None, 0, 0))
+        self.per_sample_grad_fn = vmap(grad(forward_and_loss), in_dims = (None, 0, 0, 0))
 
         # queries for retrieving from the model
 
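Note on the hunk above: forward_and_loss now also takes the per-token adaptive learning rates as loss weights, so vmap maps over four arguments (the parameter dict stays shared via in_dims = None) and each chunk gets its own gradient with the step sizes already baked in. A self-contained sketch of the same torch.func pattern, with a plain nn.Linear standing in for the memory MLP and made-up shapes:

import torch
from torch import nn
from torch.func import functional_call, grad, vmap

memory_model = nn.Linear(16, 16, bias = False)      # stand-in for the memory model

def forward_and_loss(params, inputs, loss_weights, target):
    pred = functional_call(memory_model, params, inputs)
    loss = (pred - target).pow(2).mean(dim = -1)     # per-token mse, eq (12)
    loss = loss * loss_weights                       # adaptive lr applied as loss weight
    return loss.sum()

per_sample_grad_fn = vmap(grad(forward_and_loss), in_dims = (None, 0, 0, 0))

params      = dict(memory_model.named_parameters())
keys        = torch.randn(6, 4, 16)                  # (batch * chunks, chunk_size, dim)
values      = torch.randn(6, 4, 16)
adaptive_lr = torch.rand(6, 4) * 1e-2                # (batch * chunks, chunk_size)

grads = per_sample_grad_fn(params, keys, adaptive_lr, values)
# grads['weight'].shape == (6, 16, 16): one gradient per chunk
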
@@ -190,7 +210,6 @@ class NeuralMemory(Module):
         )
 
         self.to_adaptive_step = nn.Sequential(
-            Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size),
             LinearNoBias(dim, heads),
             Rearrange('b n h -> (b h) n')
         )

@@ -271,9 +290,11 @@ class NeuralMemory(Module):
 
         keys, values = tuple(rearrange(t, 'b (n c) d -> (b n) c d', c = self.chunk_size) for t in (keys, values))
 
+        adaptive_lr = rearrange(adaptive_lr, 'b (n c) -> (b n) c', c = self.chunk_size)
+
         # get grads and extra auxiliary loss (for backwarding through qkv projection in base neural memory module)
 
-        grads = self.per_sample_grad_fn(dict(curr_weights), keys, values)
+        grads = self.per_sample_grad_fn(dict(curr_weights), keys, adaptive_lr, values)
 
         grads = TensorDict(grads)
 
@@ -286,9 +307,9 @@ class NeuralMemory(Module):
 
         grads = grads.apply(lambda t: rearrange(t, '(b n) ... -> b n ...', b = batch))
 
-        # multiply gradients with learned adaptive step size
+        # negative gradients, adaptive lr already applied as loss weight
 
-        surprises = grads.apply(lambda t: einx.multiply('b n ..., b n -> b n ...', t, -adaptive_lr))
+        surprises = grads.apply(lambda t: -t)
 
         # determine scan function
 
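Note on the hunk above: because gradients are linear in per-token loss weights, weighting the loss by the adaptive lr before summing yields the same gradient as computing unweighted per-token gradients and scaling them by the lr afterwards, which is why the explicit einx.multiply is gone and the surprise is simply the negated gradient. A tiny numerical check of that equivalence on a toy linear memory (purely illustrative, not the library's code):

import torch

w  = torch.randn(16, 16, requires_grad = True)        # toy memory weights
k  = torch.randn(4, 16)                               # keys for one chunk
v  = torch.randn(4, 16)                               # values for one chunk
lr = torch.rand(4) * 1e-2                             # per-token adaptive lr

loss_per_token = ((k @ w) - v).pow(2).mean(dim = -1)

# new formulation: weight the per-token loss, reduce, take one gradient
g_weighted, = torch.autograd.grad((loss_per_token * lr).sum(), w, retain_graph = True)

# equivalent: take each token's gradient separately and scale it by its lr
g_scaled = sum(
    lr[i] * torch.autograd.grad(loss_per_token[i], w, retain_graph = True)[0]
    for i in range(4)
)

assert torch.allclose(g_weighted, g_scaled, atol = 1e-5)
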
@@ -356,7 +377,7 @@ class NeuralMemory(Module):
         past_weights: dict[str, Tensor] | None = None,
     ):
         chunk_size = self.chunk_size
-        seq_len = seq.shape[1]
+        batch, seq_len = seq.shape[:2]
 
         seq = self.retrieve_norm(seq)
 
@@ -390,8 +411,6 @@ class NeuralMemory(Module):
 
         queries = self.split_heads(queries)
 
-        batch = queries.shape[0]
-
         # fetch values from memory model
 
         curr_weights = curr_weights.apply(lambda t: rearrange(t, 'b n ... -> (b n) ...'))

@@ -403,7 +422,14 @@ class NeuralMemory(Module):
 
         # reconstitute batch dimension
 
-        values = rearrange(values, '(b n) c d -> b (n c) d', b = batch)
+        values = rearrange(values, '(b h n) c d -> b h (n c) d', b = batch, h = self.heads)
+
+        values = self.multihead_rmsnorm(values)
+
+        # maybe gate
+
+        if exists(self.retrieve_gate):
+            values = values * self.retrieve_gate(seq)
 
         # maybe merge heads and combine
 
@@ -411,10 +437,6 @@ class NeuralMemory(Module):
 
         values = self.combine_heads(values)
 
-        # post norm, somehow could not stabilize this without it, not in paper
-
-        values = self.post_rmsnorm(values)
-
         # restore, pad with empty memory embed
 
         empty_memory_embeds = self.init_empty_memory_embed(values.shape[0], chunk_size - 1)

@@ -27,12 +27,12 @@ LEARNING_RATE = 2e-4
 VALIDATE_EVERY = 100
 GENERATE_EVERY = 500
 GENERATE_LENGTH = 512
-SHOULD_GENERATE = False
+SHOULD_GENERATE = True
 SEQ_LEN = 512
 
 PROJECT_NAME = 'titans-neural-memory'
 WANDB_ONLINE = False # turn this on to pipe experiment to cloud
-GLOBAL_LAYERS = (4, 5)
+GLOBAL_LAYERS = (2, 4)
 USE_TITANS_MEMORY = True
 NEURAL_MEMORY_DEPTH = 2
 WINDOW_SIZE = 64