titans-pytorch 0.0.22__py3-none-any.whl → 0.0.24__py3-none-any.whl

This diff compares two publicly available versions of the package as released to a supported registry. It is provided for informational purposes only and reflects the changes between the package versions as they appear in that registry.
titans_pytorch/titans.py CHANGED
@@ -73,6 +73,17 @@ def softclamp_grad_norm(t, max_value):
     t = t * (clamped_norm / norm)
     return inverse(t)
 
+# multi head rmsnorm
+
+class MultiheadRMSNorm(Module):
+    def __init__(self, dim, heads):
+        super().__init__()
+        self.rmsnorm = nn.RMSNorm(dim, elementwise_affine = False)
+        self.gamma = nn.Parameter(torch.zeros(heads, 1, dim))
+
+    def forward(self, x):
+        return self.rmsnorm(x) * (self.gamma + 1.)
+
 # classes
 
 class MemoryMLP(Module):
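Note: the new MultiheadRMSNorm normalizes each head's values with a parameter-free RMSNorm (nn.RMSNorm, PyTorch 2.4+) and scales them by a learned per-head gain; since gamma is initialized to zero, the gain (gamma + 1) starts at exactly 1. A minimal standalone sketch of that behavior, with the class body taken from the hunk above and illustrative shapes:

    import torch
    from torch import nn

    class MultiheadRMSNorm(nn.Module):
        def __init__(self, dim, heads):
            super().__init__()
            # parameter-free RMSNorm over the last dim, plus a learned per-head gain
            self.rmsnorm = nn.RMSNorm(dim, elementwise_affine = False)
            self.gamma = nn.Parameter(torch.zeros(heads, 1, dim))

        def forward(self, x):
            # gamma starts at zero, so the initial gain is 1 and the module begins as a plain RMSNorm
            return self.rmsnorm(x) * (self.gamma + 1.)

    norm = MultiheadRMSNorm(dim = 64, heads = 4)
    x = torch.randn(2, 4, 128, 64)   # (batch, heads, seq, dim_head) - illustrative shapes
    assert norm(x).shape == x.shape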
@@ -100,11 +111,11 @@ class MemoryMLP(Module):
 
 # main neural memory
 
-def default_adaptive_step_transform(adaptive_step, max_lr = 1e-1):
+def default_adaptive_step_transform(adaptive_step, max_lr = 1e-2):
     return adaptive_step.sigmoid() * max_lr
 
 def default_loss_fn(pred, target):
-    return (pred - target).pow(2).mean(dim = -1).sum()
+    return (pred - target).pow(2).mean(dim = -1)
 
 class NeuralMemory(Module):
     def __init__(
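Note: two defaults change together here. The loss function now returns one value per position (mean over the feature dimension only, with no final .sum()), and the default max_lr for the adaptive step drops from 1e-1 to 1e-2. Keeping the loss un-reduced is what allows the adaptive learning rate to act as a per-position loss weight further down. A small sketch of how the two defaults compose, with illustrative shapes:

    import torch

    def default_adaptive_step_transform(adaptive_step, max_lr = 1e-2):
        # squash the raw projection into (0, max_lr)
        return adaptive_step.sigmoid() * max_lr

    def default_loss_fn(pred, target):
        # per-position MSE: mean over features only, reduction is deferred
        return (pred - target).pow(2).mean(dim = -1)

    pred, target = torch.randn(2, 8, 16), torch.randn(2, 8, 16)
    loss = default_loss_fn(pred, target)                       # shape (2, 8)
    lr = default_adaptive_step_transform(torch.randn(2, 8))    # values in (0, 1e-2)
    weighted = (loss * lr).sum()                               # adaptive lr applied as a loss weight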
@@ -131,17 +142,25 @@ class NeuralMemory(Module):
         self.retrieve_norm = nn.RMSNorm(dim) if pre_rmsnorm else nn.Identity()
         self.store_norm = nn.RMSNorm(dim) if pre_rmsnorm else nn.Identity()
 
-        self.post_rmsnorm = nn.RMSNorm(dim) if post_rmsnorm else nn.Identity()
+        self.multihead_rmsnorm = MultiheadRMSNorm(dim_head, heads) if post_rmsnorm else nn.Identity()
 
         # maybe multi-headed
 
         dim_head = default(dim_head, dim)
         dim_inner = dim_head * heads
 
+        self.heads = heads
+
         self.split_heads = Rearrange('b n (h d) -> (b h) n d', h = heads)
-        self.merge_heads = Rearrange('(b h) n d -> b n (h d)', h = heads)
+        self.merge_heads = Rearrange('b h n d -> b n (h d)')
         self.combine_heads = LinearNoBias(dim_inner, dim) if heads > 1 else nn.Identity()
 
+        self.retrieve_gate = nn.Sequential(
+            LinearNoBias(dim, heads),
+            Rearrange('b n h -> b h n 1'),
+            nn.Sigmoid()
+        ) if heads > 1 else None
+
         # memory mlp
 
         if not exists(model):
@@ -159,12 +178,13 @@ class NeuralMemory(Module):
 
         # prepare function for per sample gradients from model above, using torch.func
 
-        def forward_and_loss(params, inputs, target):
+        def forward_and_loss(params, inputs, loss_weights, target):
             pred = functional_call(self.memory_model, params, inputs)
             loss = self.store_memory_loss_fn(pred, target) # simple mse loss in paper - eq (12) - |M(k) - v|²
-            return loss
+            loss = loss * loss_weights
+            return loss.sum()
 
-        self.per_sample_grad_fn = vmap(grad(forward_and_loss), in_dims = (None, 0, 0))
+        self.per_sample_grad_fn = vmap(grad(forward_and_loss), in_dims = (None, 0, 0, 0))
 
         # queries for retrieving from the model
 
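Note: forward_and_loss gains a loss_weights argument, and per_sample_grad_fn maps over it with in_dims = (None, 0, 0, 0): the weights (the adaptive learning rates) are folded into the loss before grad is taken, so each chunk's gradient already comes out scaled. A self-contained sketch of the same torch.func pattern, using a plain nn.Linear as a stand-in for the package's memory model; variable names and shapes here are illustrative, not the library's API:

    import torch
    from torch import nn
    from torch.func import functional_call, grad, vmap

    model = nn.Linear(16, 16, bias = False)        # toy stand-in for the memory MLP
    params = dict(model.named_parameters())

    def forward_and_loss(params, inputs, loss_weights, target):
        pred = functional_call(model, params, inputs)
        loss = (pred - target).pow(2).mean(dim = -1)   # per-position loss, as in default_loss_fn
        return (loss * loss_weights).sum()             # adaptive lr enters as a loss weight

    # params are shared (None); the other args carry a leading batch-of-chunks dim (0)
    per_sample_grad_fn = vmap(grad(forward_and_loss), in_dims = (None, 0, 0, 0))

    keys   = torch.randn(6, 4, 16)    # (batch * chunks, chunk_size, dim)
    values = torch.randn(6, 4, 16)
    lr     = torch.rand(6, 4) * 1e-2  # per-position adaptive step, already sigmoid-bounded

    grads = per_sample_grad_fn(params, keys, lr, values)
    print(grads['weight'].shape)      # torch.Size([6, 16, 16]) - one gradient per chunk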
 
@@ -190,7 +210,6 @@ class NeuralMemory(Module):
         )
 
         self.to_adaptive_step = nn.Sequential(
-            Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size),
             LinearNoBias(dim, heads),
             Rearrange('b n h -> (b h) n')
         )
@@ -271,9 +290,11 @@ class NeuralMemory(Module):
 
         keys, values = tuple(rearrange(t, 'b (n c) d -> (b n) c d', c = self.chunk_size) for t in (keys, values))
 
+        adaptive_lr = rearrange(adaptive_lr, 'b (n c) -> (b n) c', c = self.chunk_size)
+
         # get grads and extra auxiliary loss (for backwarding through qkv projection in base neural memory module)
 
-        grads = self.per_sample_grad_fn(dict(curr_weights), keys, values)
+        grads = self.per_sample_grad_fn(dict(curr_weights), keys, adaptive_lr, values)
 
         grads = TensorDict(grads)
 
@@ -286,9 +307,9 @@ class NeuralMemory(Module):
 
         grads = grads.apply(lambda t: rearrange(t, '(b n) ... -> b n ...', b = batch))
 
-        # multiply gradients with learned adaptive step size
+        # negative gradients, adaptive lr already applied as loss weight
 
-        surprises = grads.apply(lambda t: einx.multiply('b n ..., b n -> b n ...', t, -adaptive_lr))
+        surprises = grads.apply(lambda t: -t)
 
         # determine scan function
 
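Note: because the adaptive learning rate is now applied inside the loss, the stored "surprise" is simply the negated gradient; weighting a loss by lr before differentiating gives the same result as multiplying the unweighted gradient by lr afterwards. A tiny numerical check of that equivalence on toy tensors (not the library code):

    import torch

    w = torch.randn(3, requires_grad = True)
    x, target = torch.randn(3), torch.randn(3)
    lr = torch.tensor(0.01)

    # gradient of (lr * loss)
    (lr * (w * x - target).pow(2).mean()).backward()
    grad_weighted = w.grad.clone()

    # lr * gradient of loss
    w.grad = None
    (w * x - target).pow(2).mean().backward()
    assert torch.allclose(grad_weighted, lr * w.grad)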
 
@@ -356,7 +377,7 @@ class NeuralMemory(Module):
         past_weights: dict[str, Tensor] | None = None,
     ):
         chunk_size = self.chunk_size
-        seq_len = seq.shape[1]
+        batch, seq_len = seq.shape[:2]
 
         seq = self.retrieve_norm(seq)
 
@@ -390,8 +411,6 @@ class NeuralMemory(Module):
 
         queries = self.split_heads(queries)
 
-        batch = queries.shape[0]
-
         # fetch values from memory model
 
         curr_weights = curr_weights.apply(lambda t: rearrange(t, 'b n ... -> (b n) ...'))
@@ -403,7 +422,14 @@ class NeuralMemory(Module):
 
         # reconstitute batch dimension
 
-        values = rearrange(values, '(b n) c d -> b (n c) d', b = batch)
+        values = rearrange(values, '(b h n) c d -> b h (n c) d', b = batch, h = self.heads)
+
+        values = self.multihead_rmsnorm(values)
+
+        # maybe gate
+
+        if exists(self.retrieve_gate):
+            values = values * self.retrieve_gate(seq)
 
         # maybe merge heads and combine
 
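Note: on the retrieval path the per-head values are now kept in a (batch, heads, seq, dim_head) layout, normalized per head by the new MultiheadRMSNorm, and (when heads > 1) gated per head and position by a sigmoid projection of the input sequence before the heads are merged. A shape-level sketch of the gating and merge, with a plain nn.Linear standing in for LinearNoBias and illustrative dimensions:

    import torch
    from torch import nn
    from einops.layers.torch import Rearrange

    batch, heads, seq, dim, dim_head = 2, 4, 128, 256, 64

    # per-head, per-position gate in (0, 1), mirroring the added retrieve_gate
    retrieve_gate = nn.Sequential(
        nn.Linear(dim, heads, bias = False),
        Rearrange('b n h -> b h n 1'),
        nn.Sigmoid()
    )

    seq_input = torch.randn(batch, seq, dim)
    values = torch.randn(batch, heads, seq, dim_head)   # retrieved memories in (b h n d)

    values = values * retrieve_gate(seq_input)          # gate (b h n 1) broadcasts over dim_head

    merge_heads = Rearrange('b h n d -> b n (h d)')
    out = merge_heads(values)
    assert out.shape == (batch, seq, heads * dim_head)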
 
@@ -411,10 +437,6 @@ class NeuralMemory(Module):
 
         values = self.combine_heads(values)
 
-        # post norm, somehow could not stabilize this without it, not in paper
-
-        values = self.post_rmsnorm(values)
-
         # restore, pad with empty memory embed
 
         empty_memory_embeds = self.init_empty_memory_embed(values.shape[0], chunk_size - 1)
titans_pytorch-0.0.22.dist-info/METADATA → titans_pytorch-0.0.24.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.0.22
+Version: 0.0.24
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
titans_pytorch-0.0.24.dist-info/RECORD ADDED
@@ -0,0 +1,8 @@
+titans_pytorch/__init__.py,sha256=nB0873FZ_OyCda3qFeWTdpO4LbbrXDEvtAefVLzh6g0,71
+titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+titans_pytorch/titans.py,sha256=3UxRZl_uwQBly11jQAWjfnNzHSoOUKiw-Ux2lXu2ilI,14304
+titans_pytorch/titans_attn_memory.py,sha256=Rwx_-riGeISBefZg5Kjic8jzmmRRys-u93D2Kgb7Mos,12691
+titans_pytorch-0.0.24.dist-info/METADATA,sha256=WGxo4oVx9HCq7LvSH8u_isp1tjxVXb3Ao_GrgjdFzSo,3811
+titans_pytorch-0.0.24.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.0.24.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.0.24.dist-info/RECORD,,
titans_pytorch-0.0.22.dist-info/RECORD DELETED
@@ -1,8 +0,0 @@
-titans_pytorch/__init__.py,sha256=nB0873FZ_OyCda3qFeWTdpO4LbbrXDEvtAefVLzh6g0,71
-titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/titans.py,sha256=s7KVqId7hhHw1Ck77FJUXfKC5rpwS4N7Nw2mKFtP8-s,13677
-titans_pytorch/titans_attn_memory.py,sha256=Rwx_-riGeISBefZg5Kjic8jzmmRRys-u93D2Kgb7Mos,12691
-titans_pytorch-0.0.22.dist-info/METADATA,sha256=acTS0vWh84sGtTIEBzdTsxNi-S4V-Cr8NL4U4Vg0eY0,3811
-titans_pytorch-0.0.22.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.0.22.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.0.22.dist-info/RECORD,,