titans-pytorch 0.2.8__py3-none-any.whl → 0.2.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
titans_pytorch/mac_transformer.py

@@ -488,11 +488,8 @@ class MemoryAsContextTransformer(Module):
  neural_memory_model: Module | None = None,
  neural_memory_kwargs: dict = dict(),
  neural_memory_layers: tuple[int, ...] | None = None,
- aux_kv_recon_loss_weight = 1.,
  use_flex_attn = False,
  sliding_window_attn = False,
- weight_tie_memory_model = False,
- prev_neural_mem_update_for_weights = None
  ):
  super().__init__()

@@ -526,16 +523,6 @@ class MemoryAsContextTransformer(Module):

  neural_memory_layers = default(neural_memory_layers, layers)

- # weight tying neural memory model
-
- maybe_copy = deepcopy if not weight_tie_memory_model else identity
-
- if weight_tie_memory_model:
- assert exists(neural_memory_model), '`neural_memory_model` must be explicitly set'
-
- self.weight_tie_memory_model = weight_tie_memory_model
- self.prev_neural_mem_update_for_weights = default(prev_neural_mem_update_for_weights, weight_tie_memory_model)
-
  # mem, attn, and feedforward layers

  for layer in layers:
@@ -564,7 +551,7 @@ class MemoryAsContextTransformer(Module):
  mem = NeuralMemory(
  dim = dim,
  chunk_size = self.neural_memory_segment_len,
- model = maybe_copy(neural_memory_model),
+ model = deepcopy(neural_memory_model),
  **neural_memory_kwargs
  )

@@ -585,10 +572,7 @@ class MemoryAsContextTransformer(Module):

  self.gate_attn_output = neural_mem_gate_attn_output

- # auxiliary loss on kv recon
-
- self.has_aux_kv_recon_loss = aux_kv_recon_loss_weight > 0.
- self.aux_kv_recon_loss_weight = aux_kv_recon_loss_weight
+ # zero for maybe aux loss + device

  self.register_buffer('zero', torch.tensor(0.), persistent = False)

@@ -696,7 +680,7 @@ class MemoryAsContextTransformer(Module):

  # math

- batch, seq_len, neural_mem_segment_len, segment_len, num_longterm_mem_tokens, attn_window_size, prev_neural_mem_update_for_weights = *x.shape, self.neural_memory_segment_len, self.segment_len, self.num_longterm_mem_tokens, self.attn_window_size, self.prev_neural_mem_update_for_weights
+ batch, seq_len, neural_mem_segment_len, segment_len, num_longterm_mem_tokens, attn_window_size = *x.shape, self.neural_memory_segment_len, self.segment_len, self.num_longterm_mem_tokens, self.attn_window_size

  seq_len_with_mem = self.seq_len_with_longterm_mem(seq_len)

@@ -749,18 +733,10 @@ class MemoryAsContextTransformer(Module):
  next_kv_caches = []
  next_neural_mem_caches = []

- # weight tied neural memory
-
- neural_memory_updates = None
-
  # value residual

  value_residual = None

- # aux losses
-
- kv_recon_losses = self.zero
-
  # when inferencing, only do one token at a time

  if is_inferencing:
@@ -784,24 +760,16 @@ class MemoryAsContextTransformer(Module):
  mem_input, add_residual = mem_hyper_conn(x)

  if not is_inferencing:
- (retrieved, next_neural_mem_cache), mem_kv_aux_loss = mem(
- mem_input,
- return_aux_kv_loss = True,
- prev_layer_updates = neural_memory_updates
+ retrieved, next_neural_mem_cache = mem(
+ mem_input
  )

- kv_recon_losses = kv_recon_losses + mem_kv_aux_loss
-
  else:
  (retrieved, next_neural_mem_cache) = mem.forward_inference(
  mem_input,
  state = next(neural_mem_caches, None),
- prev_layer_updates = neural_memory_updates
  )

- if prev_neural_mem_update_for_weights:
- neural_memory_updates = next_neural_mem_cache.updates
-
  if self.gate_attn_output:
  attn_out_gates = retrieved.sigmoid()
  else:
@@ -883,14 +851,4 @@ class MemoryAsContextTransformer(Module):

  return logits, next_cache

- ar_loss = F.cross_entropy(rearrange(logits, 'b n l -> b l n'), labels)
-
- losses = ar_loss
-
- if self.has_aux_kv_recon_loss:
- losses = losses + kv_recon_losses * self.aux_kv_recon_loss_weight
-
- if not return_loss_breakdown:
- return losses
-
- return losses, (ar_loss, kv_recon_losses)
+ return F.cross_entropy(rearrange(logits, 'b n l -> b l n'), labels)
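With the auxiliary kv-reconstruction loss gone, asking the transformer for a loss now yields a single cross-entropy scalar rather than a (loss, breakdown) tuple. A minimal training-step sketch; the constructor arguments and the `return_loss` flag follow README-style usage and are assumptions of this example, not values taken from the diff:

import torch
from titans_pytorch import MemoryAsContextTransformer

transformer = MemoryAsContextTransformer(
    num_tokens = 256,              # vocab size, illustrative
    dim = 256,
    depth = 2,
    segment_len = 128,             # local attention window
    num_persist_mem_tokens = 4,
    num_longterm_mem_tokens = 16,
)

token_ids = torch.randint(0, 256, (1, 1023))

# in 0.2.10 this is one scalar autoregressive loss - there is no loss breakdown to unpack
loss = transformer(token_ids, return_loss = True)
loss.backward()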
titans_pytorch/neural_memory.py

@@ -38,7 +38,13 @@ w - num memory network weight parameters

  LinearNoBias = partial(Linear, bias = False)

- NeuralMemCache = namedtuple('NeuralMemCache', ['seq', 'cache_store_segment', 'states', 'updates'])
+ NeuralMemCache = namedtuple('NeuralMemCache', [
+ 'seq',
+ 'weights',
+ 'cache_store_segment',
+ 'states',
+ 'updates',
+ ])

  # functions
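The extra `weights` field means the batched memory-model weights now travel inside the cache, so the inference path further down no longer has to re-initialize them on every step. A small sketch of how the five-slot state is threaded; the field meanings are inferred from the hunks that follow:

# empty inference state, as forward_inference now constructs it
state = (0, None, None, None, None)

# (sequence index, batched memory weights, tokens buffered until a full chunk is stored,
#  (last_update, last_momentum) per parameter, per-chunk weight updates used for retrieval)
seq_index, weights, cache_store_seq, past_states, updates = state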
@@ -70,6 +76,12 @@ def safe_cat(inputs, dim = -2):
  def dict_get_shape(td):
  return {k: v.shape for k, v in td.items()}

+ def rearrange_dict_values(td, pattern, **kwargs):
+ return td.apply(lambda t: rearrange(t, pattern, **kwargs))
+
+ def repeat_dict_values(td, pattern, **kwargs):
+ return td.apply(lambda t: repeat(t, pattern, **kwargs))
+
  def pair(v):
  return (v, v) if not isinstance(v, tuple) else v
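These two helpers replace the `td.apply(lambda t: rearrange(...))` calls that were previously written out at each call site: they map one einops pattern over every tensor held in a TensorDict of memory-model parameters. A standalone sketch of the pattern; the shapes and the tensordict/einops imports are assumptions of this example:

import torch
from tensordict import TensorDict
from einops import rearrange, repeat

def rearrange_dict_values(td, pattern, **kwargs):
    return td.apply(lambda t: rearrange(t, pattern, **kwargs))

def repeat_dict_values(td, pattern, **kwargs):
    return td.apply(lambda t: repeat(t, pattern, **kwargs))

weights = TensorDict({'w1': torch.randn(64, 64), 'w2': torch.randn(64, 16)})

# what init_weights now does: give every parameter a leading (batch * heads) dimension
batched = repeat_dict_values(weights, '... -> bh ...', bh = 2 * 4)
print(batched['w1'].shape)   # torch.Size([8, 64, 64])

# what store_memories does: tile the weights per chunk, then fold chunks into the batch
per_chunk = repeat_dict_values(batched, 'b ... -> b n ...', n = 3)
flat = rearrange_dict_values(per_chunk, 'b n ... -> (b n) ...')
print(flat['w2'].shape)      # torch.Size([24, 64, 16])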
@@ -106,6 +118,9 @@ def softclamp_max(t, max_value):
  return ((t / half_max_value).tanh() * half_max_value) + half_max_value

  def softclamp_grad_norm(t, max_value):
+ if t.numel() == 0:
+ return t
+
  t, inverse = pack_one_with_inverse(t, 'bn *')

  norm = t.norm(dim = -1, keepdim = True)
@@ -195,6 +210,12 @@ class AssocScan(Module):
  ):
  remove_prev = default(remove_prev, exists(prev))

+ inputs, inverse_pack_weight_shape = pack_one_with_inverse(inputs, 'b n *')
+ gates, _ = pack_one_with_inverse(gates, 'b n *')
+
+ if exists(prev):
+ prev, _ = pack_one_with_inverse(prev, 'b *')
+
  if exists(prev):
  inputs, _ = pack([prev, inputs], 'b * d')
  gates = pad_at_dim(gates, (1, 0), value = 1., dim = -2)
@@ -205,7 +226,7 @@ class AssocScan(Module):
  if remove_prev:
  out = out[:, 1:]

- return out
+ return inverse_pack_weight_shape(out)

  from accelerated_scan.triton import scan as triton_scan
  from accelerated_scan.warp import scan as warp_scan
@@ -226,6 +247,7 @@ class AssocScan(Module):

  outputs = outputs[..., :seq_len]
  outputs = rearrange(outputs, 'b d n -> b n d')
+
  return outputs

  out = accelerate_scan_fn(gates, inputs)
@@ -233,7 +255,7 @@ class AssocScan(Module):
  if remove_prev:
  out = out[:, 1:]

- return out
+ return inverse_pack_weight_shape(out)

  # main neural memory
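The new `pack_one_with_inverse` calls let `AssocScan` take gates and inputs with arbitrary trailing shapes (for example whole weight matrices), flattening them to a single feature dimension before the scan and restoring the shape on return. The scan itself is the usual first-order gated recurrence, which `store_memories` runs once for the momentum update (the `eq (10)` comment) and once with `1. - decay_factor` for learned forgetting (`eq (13)`). A naive reference of that recurrence, as a sketch rather than the accelerated implementation:

import torch

def naive_assoc_scan(gates, inputs, prev = None):
    # out[t] = gates[t] * out[t-1] + inputs[t], with an optional initial state `prev`
    # gates, inputs: (batch, seq, dim) - prev: (batch, dim) or None
    h = prev if prev is not None else torch.zeros_like(inputs[:, 0])
    out = []
    for t in range(inputs.shape[1]):
        h = gates[:, t] * h + inputs[:, t]
        out.append(h)
    return torch.stack(out, dim = 1)

gates, inputs = torch.rand(2, 8, 4), torch.randn(2, 8, 4)
print(naive_assoc_scan(gates, inputs).shape)   # torch.Size([2, 8, 4])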
@@ -253,7 +275,7 @@ class NeuralMemory(Module):
  model: Module | None = None,
  store_memory_loss_fn: Callable = default_loss_fn,
  adaptive_step_transform: Callable | None = None,
- default_step_transform_max_lr = 1e-2,
+ default_step_transform_max_lr = 1.,
  per_parameter_lr_modulation = False, # allow outer network to control learning rate per weight matrix of memory network
  max_mem_layer_modulation = 1e1, # max of 10.
  attn_pool_chunks = False,
@@ -342,14 +364,13 @@ class NeuralMemory(Module):
  pred = functional_call(self.memory_model, params, inputs)
  loss = self.store_memory_loss_fn(pred, target) # simple mse loss in paper - eq (12) - |M(k) - v|²
  weighted_loss = loss * loss_weights
- return weighted_loss.sum(), weighted_loss.mean()
+ return weighted_loss.sum()

  # two functions

- grad_fn = grad(forward_and_loss, has_aux = True)
+ grad_fn = grad(forward_and_loss)

- self.per_sample_grad_fn = vmap(grad_fn, in_dims = (None, 0, 0, 0))
- self.per_sample_grad_fn_expanded_weights = vmap(grad_fn, in_dims = (0,) * 4)
+ self.per_sample_grad_fn = vmap(grad_fn, in_dims = (0, 0, 0, 0))

  # queries for retrieving from the model
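Because the memory weights now always carry a leading batch dimension, the two per-sample gradient functions collapse into one: every argument is vmapped over dim 0 and the auxiliary loss output disappears. A self-contained sketch of this torch.func pattern; the tiny linear memory model and the shapes are illustrative assumptions:

import torch
from torch import nn
from torch.func import functional_call, grad, vmap

memory_model = nn.Linear(16, 16, bias = False)

def forward_and_loss(params, inputs, loss_weights, target):
    pred = functional_call(memory_model, params, inputs)
    loss = (pred - target).pow(2).mean(dim = -1)   # stand-in for the store loss, eq (12)
    return (loss * loss_weights).sum()

grad_fn = grad(forward_and_loss)

# in_dims = (0, 0, 0, 0): weights, keys, adaptive lrs and values all carry a batch dim
per_sample_grad_fn = vmap(grad_fn, in_dims = (0, 0, 0, 0))

batch = 8
params = {name: p.detach().unsqueeze(0).repeat(batch, *([1] * p.ndim)) for name, p in memory_model.named_parameters()}
keys = torch.randn(batch, 4, 16)
adaptive_lr = torch.rand(batch, 4)
values = torch.randn(batch, 4, 16)

grads = per_sample_grad_fn(params, keys, adaptive_lr, values)
print(grads['weight'].shape)   # torch.Size([8, 16, 16])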
@@ -417,56 +438,56 @@ class NeuralMemory(Module):

  self.register_buffer('zero', torch.tensor(0.), persistent = False)

- def init_weights(self):
+ def init_weights(
+ self,
+ batch,
+ ):
  weights = TensorDict(dict(self.memory_model.named_parameters()))
+ weights = repeat_dict_values(weights, '... -> bh ...', bh = batch * self.heads)
  return weights

+ def init_momentum(
+ self,
+ batch,
+ ):
+ weights = TensorDict(dict(self.memory_model.named_parameters()))
+ zeros = weights.clone().zero_()
+ zeros = repeat_dict_values(zeros, '... -> bh ...', bh = batch * self.heads)
+ return zeros
+
  def store_memories(
  self,
  seq,
- weights: dict[str, Tensor],
+ weights: dict[str, Tensor] | None = None,
  past_state: tuple[dict[str, Tensor], dict[str, Tensor]] | None = None,
- prev_layer_updates: dict[str, Tensor] | None = None,
- return_aux_kv_loss = False,
  chunk_size = None,
  ):
- seq_len, heads, chunk_size = seq.shape[-2], self.heads, default(chunk_size, self.store_chunk_size)
-
- # handle edge case
-
- if seq_len < chunk_size:
- return TensorDict(weights).clone().zero_(), self.zero
-
- seq = self.store_norm(seq)
+ batch, seq_len, heads, chunk_size = *seq.shape[:2], self.heads, default(chunk_size, self.store_chunk_size)

  # curtail sequence by multiple of the chunk size
  # only a complete chunk of the sequence provides the memory for the next chunk

  round_down_seq_len = round_down_multiple(seq_len, chunk_size)
+ num_chunks = round_down_seq_len // chunk_size

  seq, remainder = seq[:, :round_down_seq_len], seq[:, round_down_seq_len:]

- # per sample grad function
-
- per_sample_grad_fn = self.per_sample_grad_fn
-
+ # init weights if needed
  # weights of the memory network

+ if not exists(weights):
+ weights = self.init_weights(batch)
+
  weights = TensorDict(weights)

  # allow for neural memory of a previous layer to influence surprise of current layer

- weights_for_surprise = weights
-
- if exists(prev_layer_updates):
- prev_layer_updates = TensorDict(prev_layer_updates)
-
- weights_for_surprise = weights_for_surprise + prev_layer_updates
-
- per_sample_grad_fn = self.per_sample_grad_fn_expanded_weights # the weights will now have a batch * chunk dimension
+ weights_for_surprise = repeat_dict_values(weights, 'b ... -> b n ...', n = num_chunks)

  # derive learned hparams for optimization of memory network

+ seq = self.store_norm(seq)
+
  adaptive_lr = self.to_adaptive_step(seq)
  adaptive_lr = self.adaptive_step_transform(adaptive_lr)

@@ -474,7 +495,7 @@ class NeuralMemory(Module):

  decay_factor = self.to_decay_factor(chunked_seq).sigmoid()

- need_layer_lr_mod = exists(self.to_layer_modulation)
+ need_layer_lr_mod = exists(self.to_layer_modulation) and num_chunks > 0
  has_momentum = exists(self.to_momentum)

  if has_momentum:
@@ -505,12 +526,11 @@ class NeuralMemory(Module):

  # flatten batch and time if surprise depends on previous layer memory model

- if exists(prev_layer_updates):
- weights_for_surprise = weights_for_surprise.apply(lambda t: rearrange(t, 'b n ... -> (b n) ...'))
+ weights_for_surprise = rearrange_dict_values(weights_for_surprise, 'b n ... -> (b n) ...')

  # get grads and extra auxiliary loss (for backwarding through qkv projection in base neural memory module)

- grads, aux_kv_recon_loss = per_sample_grad_fn(dict(weights_for_surprise), keys, adaptive_lr, values)
+ grads = self.per_sample_grad_fn(dict(weights_for_surprise), keys, adaptive_lr, values)

  grads = TensorDict(grads)

@@ -521,7 +541,7 @@ class NeuralMemory(Module):

  # restore batch and sequence dimension

- grads = grads.apply(lambda t: rearrange(t, '(b n) ... -> b n ...', b = batch * heads))
+ grads = rearrange_dict_values(grads, '(b n) ... -> b n ...', b = batch * heads)

  # maybe per layer modulation

@@ -535,19 +555,25 @@ class NeuralMemory(Module):
  # past states

  if not exists(past_state):
- empty_dict = {key: None for key in weights.keys()}
-
  # minibatch_init_weight corresponds to W0 in figure 7 of TTT paper

  minibatch_init_weight = weights
+ init_momentum = self.init_momentum(batch)

- if dict_get_shape(weights) == self.init_weight_shape:
- minibatch_init_weight = weights.apply(lambda t: repeat(t, '... -> b 1 (...)', b = batch * heads))
-
- past_state = (minibatch_init_weight, empty_dict)
+ past_state = (minibatch_init_weight, init_momentum)

  past_last_update, past_last_momentum = past_state

+ # early return if sequence length less than chunk size
+
+ if num_chunks == 0:
+ updates = rearrange_dict_values(weights, 'bh ... -> bh 1 ...')
+ next_store_state = NeuralMemCache(seq_len, weights, remainder, past_state, updates)
+
+ output = (updates, next_store_state)
+
+ return output
+
  # momentum + weight decay - momentum is the new contribution, as most linear RNNs have learned forgetting gates

  next_momentum = TensorDict() if has_momentum else None
@@ -558,8 +584,6 @@ class NeuralMemory(Module):

  for (param_name, surprise), (_, last_update), (_, last_momentum) in zip(surprises.items(), past_last_update.items(), past_last_momentum.items()):

- surprise, inverse_pack = pack_one_with_inverse(surprise, 'b n *')
-
  update = surprise

  # derive momentum with associative scan - eq (10)
@@ -571,62 +595,52 @@ class NeuralMemory(Module):

  # use associative scan again for learned forgetting (weight decay) - eq (13)

- update = self.assoc_scan(1. - decay_factor, update, prev = last_update)
+ update = self.assoc_scan(1. - decay_factor, update, prev = last_update, remove_prev = False)
  next_last_update[param_name] = update[:, -1]

- updates[param_name] = inverse_pack(update)
+ updates[param_name] = update

  if has_momentum:
- next_momentum[param_name] = inverse_pack(momentum)
+ next_momentum[param_name] = momentum

  # determine next state for the storing of memories

  next_state = (next_last_update, next_last_momentum)

- next_store_state = NeuralMemCache(seq_len, remainder, next_state, updates)
+ next_store_state = NeuralMemCache(seq_len, weights, remainder, next_state, updates)

  # returns

  output = (updates, next_store_state)

- if not return_aux_kv_loss:
- return output
-
- return output, aux_kv_recon_loss.mean()
+ return output

  def retrieve_memories(
  self,
  seq,
  past_weights: dict[str, Tensor],
  chunk_size = None,
- prev_layer_updates: dict[str, Tensor] | None = None
  ):
  chunk_size = default(chunk_size, self.retrieve_chunk_size)
  batch, seq_len = seq.shape[:2]

  seq = self.retrieve_norm(seq)

- assert seq_len >= chunk_size, 'must be handled outside of retrieve'
-
  needs_pad = chunk_size > 1

- if needs_pad:
- seq = pad_at_dim(seq, (1, 0), dim = 1)
- seq_len_plus_one = seq.shape[-2]
+ seq = pad_at_dim(seq, (1, 0), dim = 1)
+ seq_len_plus_one = seq.shape[-2]

- next_seq_len = round_up_multiple(seq_len_plus_one, chunk_size)
+ next_seq_len = round_up_multiple(seq_len_plus_one, chunk_size)

- padding = next_seq_len - seq_len_plus_one
- seq = pad_at_dim(seq, (0, padding), dim = 1)
+ padding = next_seq_len - seq_len_plus_one
+ seq = pad_at_dim(seq, (0, padding), dim = 1)

  # the parameters of the memory model stores the memories of the key / values
  # when the MLP has only 1 weight matrix, it is equivalent to `kv` fast weight memories from linear attention literature (recall fetching of memories is q @ (kv)) / schmidhuber's paper

  curr_weights = TensorDict(past_weights)

- if exists(prev_layer_updates):
- curr_weights = curr_weights + TensorDict(prev_layer_updates)
-
  # sequence Float['b n d'] to queries

  queries = self.to_queries(seq)
@@ -642,7 +656,7 @@ class NeuralMemory(Module):
  # fetch values from memory model

  if dict_get_shape(curr_weights) != self.init_weight_shape:
- curr_weights = curr_weights.apply(lambda t: rearrange(t, 'b n ... -> (b n) ...'))
+ curr_weights = rearrange_dict_values(curr_weights, 'b n ... -> (b n) ...')

  queries = rearrange(queries, 'b h (n c) d -> (b h n) c d', c = chunk_size)

@@ -669,8 +683,7 @@ class NeuralMemory(Module):

  # restore, pad with empty memory embed

- if needs_pad:
- values = values[:, 1:(seq_len + 1)]
+ values = values[:, 1:(seq_len + 1)]

  return values

@@ -679,15 +692,14 @@ class NeuralMemory(Module):
  self,
  token: Tensor,
  state = None,
- prev_layer_updates: dict[str, Tensor] | None = None,
  ):

  # unpack previous state

  if not exists(state):
- state = (0, None, None, None)
+ state = (0, None, None, None, None)

- seq_index, cache_store_seq, past_states, updates = state
+ seq_index, weights, cache_store_seq, past_states, updates = state

  curr_seq_len = seq_index + 1
  batch = token.shape[0]
@@ -695,10 +707,6 @@ class NeuralMemory(Module):
  if token.ndim == 2:
  token = rearrange(token, 'b d -> b 1 d')

- # get memory model weights
-
- weights = self.init_weights()
-
  # increment the sequence cache which is at most the chunk size

  cache_store_seq = safe_cat((cache_store_seq, token), dim = -2)
@@ -708,7 +716,7 @@ class NeuralMemory(Module):

  if curr_seq_len < self.chunk_size:
  retrieve = self.retrieve_memories(token, weights, chunk_size = 1)
- output = retrieve, NeuralMemCache(curr_seq_len, cache_store_seq, past_states, updates)
+ output = retrieve, NeuralMemCache(curr_seq_len, weights, cache_store_seq, past_states, updates)

  return output

@@ -719,21 +727,16 @@ class NeuralMemory(Module):

  if not exists(updates):
  updates = weights.clone().zero_()
- updates = updates.apply(lambda t: repeat(t, '... -> b 1 ...', b = batch))
+ updates = repeat_dict_values(updates, '... -> b 1 ...', b = batch)
  else:
  updates = updates.apply(lambda t: t[:, -1:])

- if exists(prev_layer_updates):
- prev_layer_updates = TensorDict(prev_layer_updates)
- prev_layer_updates = prev_layer_updates.apply(lambda t: t[:, -1:])
-
  if store_seq_cache_len == self.chunk_size:

  next_updates, store_state = self.store_memories(
  cache_store_seq,
  weights,
  past_state = past_states,
- prev_layer_updates = prev_layer_updates,
  )

  updates = next_updates
@@ -742,11 +745,11 @@ class NeuralMemory(Module):

  # retrieve

- retrieved = self.retrieve_memories(token, weights, chunk_size = 1)
+ retrieved = self.retrieve_memories(token, updates, chunk_size = 1)

  # next state tuple

- next_store_state = NeuralMemCache(curr_seq_len, cache_store_seq, next_states, updates)
+ next_store_state = NeuralMemCache(curr_seq_len, weights, cache_store_seq, next_states, updates)

  return retrieved, next_store_state

@@ -756,62 +759,30 @@ class NeuralMemory(Module):
  store_seq = None,
  mem_model_weights: dict[str, Tensor] | None = None,
  past_state: tuple[dict[str, Tensor], dict[str, Tensor]] | None = None,
- return_aux_kv_loss = False,
  chunk_size = None,
  store_chunk_size = None,
  return_next_state = False,
- prev_layer_updates: dict[str, Tensor] | None = None
  ):
  batch, seq_len = seq.shape[:2]

- if not exists(mem_model_weights):
- mem_model_weights = self.init_weights()
-
- if seq_len < self.retrieve_chunk_size:
- retrieved = self.retrieve_memories(seq, mem_model_weights, chunk_size = 1)
-
- next_store_state = NeuralMemCache(seq_len, seq, None, None)
-
- out = (retrieved, next_store_state)
-
- if not return_aux_kv_loss:
- return out
-
- return out, self.zero
-
  # store

  store_seq = default(store_seq, seq)

- (updates, next_store_state), aux_kv_recon_loss = self.store_memories(
+ updates, next_store_state = self.store_memories(
  store_seq,
  mem_model_weights,
  chunk_size = store_chunk_size,
- prev_layer_updates = prev_layer_updates,
- return_aux_kv_loss = True
  )

  # retrieve

- retrieve_chunk_size = default(chunk_size, self.retrieve_chunk_size)
-
- if retrieve_chunk_size != 1:
- if exists(prev_layer_updates):
- prev_layer_updates = prev_layer_updates.apply(lambda t: pad_at_dim(t, (1, 0), dim = 1))
-
- updates = updates.apply(lambda t: pad_at_dim(t, (1, 0), dim = 1))
-
-
  retrieved = self.retrieve_memories(
  seq,
- mem_model_weights,
+ updates,
  chunk_size = chunk_size,
- prev_layer_updates = prev_layer_updates
  )

  output = (retrieved, next_store_state)

- if not return_aux_kv_loss:
- return output
-
- return output, aux_kv_recon_loss
+ return output
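With the auxiliary loss and the `prev_layer_updates` plumbing removed, the module's surface reduces to: call it on a sequence to get the retrieved memories plus a `NeuralMemCache`, or step it token by token with `forward_inference`. A usage sketch; the constructor arguments mirror README-style usage and are assumptions of this example, and the decode loop continues from the state returned by the parallel call:

import torch
from titans_pytorch import NeuralMemory

mem = NeuralMemory(
    dim = 384,
    chunk_size = 64,   # how often the memory weights are updated along the sequence
)

seq = torch.randn(2, 1024, 384)

# parallel path: a single (retrieved, cache) pair in 0.2.10, no auxiliary loss to unpack
retrieved, mem_state = mem(seq)
print(retrieved.shape)   # same (batch, seq, dim) shape as the input

# incremental path: one token at a time, threading the cache (which now carries the weights)
state = mem_state
for _ in range(8):
    token = torch.randn(2, 384)
    out, state = mem.forward_inference(token, state = state)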
titans_pytorch-0.2.8.dist-info/METADATA → titans_pytorch-0.2.10.dist-info/METADATA

@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: titans-pytorch
- Version: 0.2.8
+ Version: 0.2.10
  Summary: Titans
  Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
  Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
titans_pytorch-0.2.10.dist-info/RECORD (added)

@@ -0,0 +1,9 @@
+ titans_pytorch/__init__.py,sha256=Y3m_ZlpEqYwp-Md1ARhNGJxq8bQp8ty1o039nZOOJo0,276
+ titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+ titans_pytorch/mac_transformer.py,sha256=dmS37yBN0j9OqoMCsojuIPfT1EXLN8ackRdZwPb8xDY,24463
+ titans_pytorch/memory_models.py,sha256=Q9SAIyAbStF5Tz0EhvRbn3yAdE3nk3xKc1ndieIe714,4671
+ titans_pytorch/neural_memory.py,sha256=kc-cV7dK3WhdqRfOCrPW91nA0F56jUK94TE1irckQ34,23487
+ titans_pytorch-0.2.10.dist-info/METADATA,sha256=k7u9eQDNAWG3QqzGqhcdN21D6LYWWRWdd5wZFb560q0,6812
+ titans_pytorch-0.2.10.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+ titans_pytorch-0.2.10.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+ titans_pytorch-0.2.10.dist-info/RECORD,,
titans_pytorch-0.2.8.dist-info/RECORD (removed)

@@ -1,9 +0,0 @@
- titans_pytorch/__init__.py,sha256=Y3m_ZlpEqYwp-Md1ARhNGJxq8bQp8ty1o039nZOOJo0,276
- titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
- titans_pytorch/mac_transformer.py,sha256=UOJAMv7nTgkefBB7M7K3U0NnFkz75tFRG5WLXRdfnLw,26039
- titans_pytorch/memory_models.py,sha256=Q9SAIyAbStF5Tz0EhvRbn3yAdE3nk3xKc1ndieIe714,4671
- titans_pytorch/neural_memory.py,sha256=WAeR-nOpy1XbBP590By1-tCgirulqPbFGut4H1B77-g,24910
- titans_pytorch-0.2.8.dist-info/METADATA,sha256=4fLUv34KqloeYMWjHBUmp-3iEw0Xq47fjRrwlkyTEsM,6811
- titans_pytorch-0.2.8.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
- titans_pytorch-0.2.8.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
- titans_pytorch-0.2.8.dist-info/RECORD,,