titans-pytorch 0.0.46.tar.gz → 0.0.49.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.0.46
+Version: 0.0.49
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -1,6 +1,6 @@
 [project]
 name = "titans-pytorch"
-version = "0.0.46"
+version = "0.0.49"
 description = "Titans"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -95,7 +95,7 @@ class SegmentedAttention(Module):
         dim_head = 64,
         heads = 8,
         accept_value_residual = False,
-        attend_kwargs: dict = dict()
+        attend_kwargs: dict = dict(),
     ):
         super().__init__()
         self.norm = nn.RMSNorm(dim)
@@ -201,6 +201,7 @@ class MemoryAsContextTransformer(Module):
         num_residual_streams = 4,
         neural_memory_kwargs: dict = dict(),
         neural_memory_layers: tuple[int, ...] | None = None,
+        aux_kv_recon_loss_weight = 0.
     ):
         super().__init__()

@@ -276,10 +277,18 @@ class MemoryAsContextTransformer(Module):

         self.to_logits = LinearNoBias(dim, num_tokens)

+        # auxiliary loss on kv recon
+
+        self.has_aux_kv_recon_loss = aux_kv_recon_loss_weight > 0.
+        self.aux_kv_recon_loss_weight = aux_kv_recon_loss_weight
+
+        self.register_buffer('zero', torch.tensor(0.), persistent = False)
+
     def forward(
         self,
         x,
-        return_loss = False
+        return_loss = False,
+        return_loss_breakdown = False
     ):

         if return_loss:
@@ -317,6 +326,10 @@ class MemoryAsContextTransformer(Module):

         value_residual = None

+        # aux losses
+
+        kv_recon_losses = self.zero
+
         # expand and reduce streams for hyper connections

         x = self.expand_streams(x)
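The kv_recon_losses accumulator above starts from the `zero` buffer registered in `__init__`. A standalone sketch of that pattern (the module and names below are illustrative, not from this diff):

```python
import torch
from torch import nn

class Accumulator(nn.Module):
    def __init__(self):
        super().__init__()
        # non-persistent zero buffer: follows the module's device, and gives the
        # running loss a valid default even when no layer contributes a term
        self.register_buffer('zero', torch.tensor(0.), persistent = False)

    def forward(self, parts):
        total = self.zero
        for part in parts:
            total = total + part
        return total
```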
@@ -324,7 +337,8 @@ class MemoryAsContextTransformer(Module):
         for (attn, ff), maybe_neural_mem in zip(self.layers, self.neural_mem_layers):

             if exists(maybe_neural_mem):
-                x = maybe_neural_mem(x)
+                x, aux_kv_loss = maybe_neural_mem(x, return_aux_kv_loss = True)
+                kv_recon_losses = kv_recon_losses + aux_kv_loss

             x, values = attn(x, value_residual = value_residual)

@@ -351,4 +365,14 @@ class MemoryAsContextTransformer(Module):
         if not return_loss:
             return logits

-        return F.cross_entropy(rearrange(logits, 'b n l -> b l n'), labels)
+        ar_loss = F.cross_entropy(rearrange(logits, 'b n l -> b l n'), labels)
+
+        losses = ar_loss
+
+        if self.has_aux_kv_recon_loss:
+            losses = losses + kv_recon_losses * self.aux_kv_recon_loss_weight
+
+        if not return_loss_breakdown:
+            return losses
+
+        return losses, (ar_loss, kv_recon_losses)
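Taken together with the new forward signature, a caller can now request the loss breakdown directly. A minimal usage sketch, assuming the import path and constructor arguments used in the training script further down (the hyperparameter values here are illustrative placeholders):

```python
import torch
from titans_pytorch import MemoryAsContextTransformer

model = MemoryAsContextTransformer(
    num_tokens = 256,
    dim = 384,
    depth = 8,
    segment_len = 32,
    num_persist_mem_tokens = 4,
    num_longterm_mem_tokens = 4,
    neural_memory_layers = (2, 4),
    aux_kv_recon_loss_weight = 0.1  # new in 0.0.49; the default of 0. disables the auxiliary term
)

ids = torch.randint(0, 256, (2, 512))

# total loss plus its components: (autoregressive loss, kv reconstruction loss)
loss, (ar_loss, kv_recon_loss) = model(ids, return_loss = True, return_loss_breakdown = True)
loss.backward()
```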
@@ -217,10 +217,10 @@ class NeuralMemory(Module):
         def forward_and_loss(params, inputs, loss_weights, target):
             pred = functional_call(self.memory_model, params, inputs)
             loss = self.store_memory_loss_fn(pred, target) # simple mse loss in paper - eq (12) - |M(k) - v|²
-            loss = loss * loss_weights
-            return loss.sum()
+            weighted_loss = loss * loss_weights
+            return weighted_loss.sum(), weighted_loss.mean()

-        self.per_sample_grad_fn = vmap(grad(forward_and_loss), in_dims = (None, 0, 0, 0))
+        self.per_sample_grad_fn = vmap(grad(forward_and_loss, has_aux = True), in_dims = (None, 0, 0, 0))

         # queries for retrieving from the model

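The `has_aux = True` change follows the `torch.func` convention: the differentiated function must return an `(output, aux)` pair, and the resulting gradient function returns `(grads, aux)`. A self-contained sketch of that pattern, independent of this repo:

```python
import torch
from torch.func import grad, vmap

def loss_and_aux(w, x, y):
    pred = x @ w
    err = (pred - y) ** 2
    # the first element is differentiated, the second is passed through untouched
    return err.sum(), err.mean()

# per-sample gradients: weights shared (None), data batched over dim 0
per_sample_grad = vmap(grad(loss_and_aux, has_aux = True), in_dims = (None, 0, 0))

w = torch.randn(4)
x = torch.randn(8, 3, 4)
y = torch.randn(8, 3)

grads, aux = per_sample_grad(w, x, y)  # grads: (8, 4), aux: (8,)
```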
@@ -282,7 +282,8 @@ class NeuralMemory(Module):
     def store_memories(
         self,
         seq,
-        past_state: tuple[dict[str, Tensor], dict[str, Tensor]]
+        past_state: tuple[dict[str, Tensor], dict[str, Tensor]],
+        return_aux_kv_loss = False
     ):

         seq = self.store_norm(seq)
@@ -330,7 +331,7 @@ class NeuralMemory(Module):

         # get grads and extra auxiliary loss (for backwarding through qkv projection in base neural memory module)

-        grads = self.per_sample_grad_fn(dict(curr_weights), keys, adaptive_lr, values)
+        grads, aux_kv_recon_loss = self.per_sample_grad_fn(dict(curr_weights), keys, adaptive_lr, values)

         grads = TensorDict(grads)

@@ -405,7 +406,10 @@ class NeuralMemory(Module):

         next_state = (curr_weights + last_update, next_momentum)

-        return updates, next_state
+        if not return_aux_kv_loss:
+            return updates, next_state
+
+        return updates, next_state, aux_kv_recon_loss.mean()

     def retrieve_memories(
         self,
@@ -484,7 +488,7 @@ class NeuralMemory(Module):
         seq,
         store_seq = None,
         past_state: tuple[dict[str, Tensor], dict[str, Tensor]] | None = None,
-        return_next_memories = False
+        return_aux_kv_loss = False
     ):
         batch, seq_len = seq.shape[:2]

@@ -499,13 +503,13 @@ class NeuralMemory(Module):

         store_seq = default(store_seq, seq)

-        updates, next_memories = self.store_memories(store_seq, past_state)
+        updates, next_memories, aux_kv_recon_loss = self.store_memories(store_seq, past_state, return_aux_kv_loss = True)

         past_weights, _ = past_state

         retrieved = self.retrieve_memories(seq, past_weights + updates)

-        if not return_next_memories:
+        if not return_aux_kv_loss:
             return retrieved

-        return retrieved, next_memories
+        return retrieved, aux_kv_recon_loss
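With `return_next_memories` renamed, `NeuralMemory.forward` now optionally returns the auxiliary reconstruction loss rather than the next memory state. A minimal usage sketch (the constructor arguments below are illustrative and may differ between versions):

```python
import torch
from titans_pytorch import NeuralMemory

mem = NeuralMemory(dim = 384, chunk_size = 64)

seq = torch.randn(2, 1024, 384)

retrieved = mem(seq)                                                 # retrieval only, as before
retrieved, aux_kv_recon_loss = mem(seq, return_aux_kv_loss = True)   # also return the aux kv loss
```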
@@ -24,13 +24,14 @@ SHOULD_GENERATE = False
 SEQ_LEN = 512

 PROJECT_NAME = 'titans-mac-transformer'
-WANDB_ONLINE = False # turn this on to pipe experiment to cloud
+WANDB_ONLINE = True # turn this on to pipe experiment to cloud
 NEURAL_MEMORY_DEPTH = 2
 NUM_PERSIST_MEM = 4
 NUM_LONGTERM_MEM = 4
 NEURAL_MEM_LAYERS = (2, 4)
 WINDOW_SIZE = 32
-RUN_NAME = f'mac - {NUM_LONGTERM_MEM} longterm mems, layers {NEURAL_MEM_LAYERS}'
+KV_RECON_LOSS_WEIGHT = 0.1
+RUN_NAME = f'mac - {NUM_LONGTERM_MEM} longterm mems, layers {NEURAL_MEM_LAYERS} - kv recon loss 0.1'

 # wandb experiment tracker

@@ -63,6 +64,7 @@ model = MemoryAsContextTransformer(
     num_longterm_mem_tokens = NUM_LONGTERM_MEM,
     neural_memory_layers = NEURAL_MEM_LAYERS,
     neural_memory_segment_len = WINDOW_SIZE // 2,
+    aux_kv_recon_loss_weight = KV_RECON_LOSS_WEIGHT,
     neural_memory_kwargs = dict(
         dim_head = 64,
         heads = 4,
@@ -108,20 +110,20 @@ for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
     model.train()

     for __ in range(GRADIENT_ACCUMULATE_EVERY):
-        loss = model(next(train_loader), return_loss = True)
+        loss, (ar_loss, kv_recon_losses) = model(next(train_loader), return_loss = True, return_loss_breakdown = True)
         loss.backward()

-    print(f'training loss: {loss.item()}')
+    print(f'training loss: {ar_loss.item()}')
     torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
     optim.step()
     optim.zero_grad()
-    wandb.log(dict(loss = loss.item()))
+    wandb.log(dict(loss = ar_loss.item()))

     if i % VALIDATE_EVERY == 0:
         model.eval()
         with torch.no_grad():
-            loss = model(next(val_loader), return_loss = True)
-            print(f'validation loss: {loss.item()}')
+            loss, (ar_loss, _) = model(next(val_loader), return_loss = True, return_loss_breakdown = True)
+            print(f'validation loss: {ar_loss.item()}')

     if SHOULD_GENERATE and i % GENERATE_EVERY == 0:
         model.eval()
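With the breakdown available in the loop, both components could also be logged; a possible extension of the `wandb.log` call above (not part of this diff):

```python
# inside the training step above, after the optimizer update
wandb.log(dict(
    loss = ar_loss.item(),
    kv_recon_loss = kv_recon_losses.item(),
    total_loss = loss.item()
))
```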