titans-pytorch 0.0.45__tar.gz → 0.0.47__tar.gz

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries, and is provided for informational purposes only.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.0.45
+Version: 0.0.47
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -1,6 +1,6 @@
 [project]
 name = "titans-pytorch"
-version = "0.0.45"
+version = "0.0.47"
 description = "Titans"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -311,7 +311,7 @@ class MemoryAsContextTransformer(Module):
 
         pos_emb = self.axial_pos_emb((neural_mem_windows, neural_mem_segment_len), flatten = True)
 
-        pos_emb = pos_emb[:seq_len_with_mem]
+        x = x + pos_emb[:seq_len_with_mem]
 
         # value residual
 
@@ -324,7 +324,7 @@ class MemoryAsContextTransformer(Module):
         for (attn, ff), maybe_neural_mem in zip(self.layers, self.neural_mem_layers):
 
             if exists(maybe_neural_mem):
-                x = maybe_neural_mem(x, pos_emb = pos_emb)
+                x = maybe_neural_mem(x)
 
             x, values = attn(x, value_residual = value_residual)
 
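Both hunks above amount to one refactor: the axial positional embedding is added to the residual stream once, before the layer loop, instead of being handed to each neural memory layer. A minimal self-contained sketch of that ordering (toy modules for illustration, not the library code):

import torch
from torch import nn

# toy stand-in layers; the real transformer interleaves attention, feedforward
# and optional neural memory modules
dim, seq_len = 16, 12
layers = nn.ModuleList([nn.Linear(dim, dim) for _ in range(2)])
pos_emb_table = torch.randn(64, dim)    # longer than any sequence we expect

x = torch.randn(1, seq_len, dim)

# positions are mixed into the residual stream up front, once ...
x = x + pos_emb_table[:seq_len]

# ... so downstream blocks (including a neural memory layer) only ever see x
for layer in layers:
    x = x + layer(x)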
@@ -217,10 +217,10 @@ class NeuralMemory(Module):
         def forward_and_loss(params, inputs, loss_weights, target):
             pred = functional_call(self.memory_model, params, inputs)
             loss = self.store_memory_loss_fn(pred, target) # simple mse loss in paper - eq (12) - |M(k) - v|²
-            loss = loss * loss_weights
-            return loss.sum()
+            weighted_loss = loss * loss_weights
+            return weighted_loss.sum(), loss.mean()
 
-        self.per_sample_grad_fn = vmap(grad(forward_and_loss), in_dims = (None, 0, 0, 0))
+        self.per_sample_grad_fn = vmap(grad(forward_and_loss, has_aux = True), in_dims = (None, 0, 0, 0))
 
         # queries for retrieving from the model
 
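The change above switches the per-sample gradient function to torch.func's grad(..., has_aux = True): the wrapped function now returns a (weighted_loss, aux) pair, grad differentiates only the first element, and the auxiliary value is passed through, so vmap yields per-sample gradients alongside a per-sample unweighted loss. A self-contained sketch of that pattern with a toy linear memory model (all names below are illustrative, not the library's):

import torch
from torch.func import functional_call, grad, vmap

# toy stand-in for the memory model
model = torch.nn.Linear(4, 4, bias = False)

def forward_and_loss(params, inputs, loss_weights, target):
    pred = functional_call(model, params, inputs)
    loss = (pred - target).pow(2).mean(dim = -1)    # per-position loss, in the spirit of |M(k) - v|²
    weighted_loss = loss * loss_weights
    # grad(..., has_aux = True) differentiates the first element only;
    # the second is returned untouched as the auxiliary value
    return weighted_loss.sum(), loss.mean()

per_sample_grad_fn = vmap(grad(forward_and_loss, has_aux = True), in_dims = (None, 0, 0, 0))

params  = dict(model.named_parameters())
inputs  = torch.randn(8, 16, 4)     # (batch, seq, dim)
targets = torch.randn(8, 16, 4)
weights = torch.rand(8, 16)         # per-position adaptive weights

grads, aux_loss = per_sample_grad_fn(params, inputs, weights, targets)
# grads['weight'].shape == (8, 4, 4)   one gradient per batch element
# aux_loss.shape == (8,)               unweighted mean loss per batch element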
@@ -282,7 +282,8 @@ class NeuralMemory(Module):
     def store_memories(
         self,
         seq,
-        past_state: tuple[dict[str, Tensor], dict[str, Tensor]]
+        past_state: tuple[dict[str, Tensor], dict[str, Tensor]],
+        return_aux_kv_loss = False
     ):
 
         seq = self.store_norm(seq)
@@ -330,7 +331,7 @@ class NeuralMemory(Module):
 
         # get grads and extra auxiliary loss (for backwarding through qkv projection in base neural memory module)
 
-        grads = self.per_sample_grad_fn(dict(curr_weights), keys, adaptive_lr, values)
+        grads, aux_kv_recon_loss = self.per_sample_grad_fn(dict(curr_weights), keys, adaptive_lr, values)
 
         grads = TensorDict(grads)
 
@@ -405,7 +406,10 @@ class NeuralMemory(Module):
 
         next_state = (curr_weights + last_update, next_momentum)
 
-        return updates, next_state
+        if not return_aux_kv_loss:
+            return updates, next_state
+
+        return updates, next_state, aux_kv_recon_loss
 
     def retrieve_memories(
         self,
@@ -484,14 +488,10 @@ class NeuralMemory(Module):
         seq,
         store_seq = None,
         past_state: tuple[dict[str, Tensor], dict[str, Tensor]] | None = None,
-        return_next_memories = False,
-        pos_emb: Tensor | None = None
+        return_aux_kv_loss = False
     ):
         batch, seq_len = seq.shape[:2]
 
-        if exists(pos_emb):
-            seq = seq + pos_emb
-
         if seq_len < self.chunk_size:
             return self.init_empty_memory_embed(batch, seq_len)
 
@@ -503,13 +503,13 @@ class NeuralMemory(Module):
 
         store_seq = default(store_seq, seq)
 
-        updates, next_memories = self.store_memories(store_seq, past_state)
+        updates, next_memories, aux_kv_recon_loss = self.store_memories(store_seq, past_state, return_aux_kv_loss = True)
 
         past_weights, _ = past_state
 
         retrieved = self.retrieve_memories(seq, past_weights + updates)
 
-        if not return_next_memories:
+        if not return_aux_kv_loss:
             return retrieved
 
-        return retrieved, next_memories
+        return retrieved, aux_kv_recon_loss
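Taken together, NeuralMemory.forward now threads the auxiliary key/value reconstruction loss out to the caller in place of the old return_next_memories path. A hedged usage sketch, adapted from the project README; the constructor arguments and the loss weighting below are assumptions for illustration:

import torch
from titans_pytorch import NeuralMemory

# constructor arguments follow the README example and are assumptions here
mem = NeuralMemory(
    dim = 384,
    chunk_size = 64
)

seq = torch.randn(2, 1024, 384)

retrieved = mem(seq)                                             # old call shape still works
retrieved, aux_kv_recon_loss = mem(seq, return_aux_kv_loss = True)

# per the source comment, the auxiliary loss exists to backward through the
# qkv projections; how it is reduced and weighted into the main objective is
# an assumption, e.g.:
# total_loss = main_loss + 0.1 * aux_kv_recon_loss.mean()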