titans-pytorch 0.2.0__tar.gz → 0.2.4__tar.gz

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.2.0
+Version: 0.2.4
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -1,6 +1,6 @@
 [project]
 name = "titans-pytorch"
-version = "0.2.0"
+version = "0.2.4"
 description = "Titans"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -491,7 +491,8 @@ class MemoryAsContextTransformer(Module):
         aux_kv_recon_loss_weight = 0.,
         use_flex_attn = False,
         sliding_window_attn = False,
-        weight_tie_memory_model = False
+        weight_tie_memory_model = False,
+        prev_neural_mem_update_for_weights = None
     ):
         super().__init__()
 
@@ -533,11 +534,7 @@ class MemoryAsContextTransformer(Module):
         assert exists(neural_memory_model), '`neural_memory_model` must be explicitly set'
 
         self.weight_tie_memory_model = weight_tie_memory_model
-
-        # value residual learning for neural memory
-
-        is_first_mem = True
-        self.mem_add_value_residual = neural_memory_add_value_residual
+        self.prev_neural_mem_update_for_weights = default(prev_neural_mem_update_for_weights, weight_tie_memory_model)
 
         # mem, attn, and feedforward layers
 
@@ -568,12 +565,9 @@ class MemoryAsContextTransformer(Module):
                     dim = dim,
                     chunk_size = self.neural_memory_segment_len,
                     model = maybe_copy(neural_memory_model),
-                    accept_value_residual = not is_first_mem and neural_memory_add_value_residual,
                     **neural_memory_kwargs
                 )
 
-                is_first_mem = False
-
             ff = FeedForward(dim = dim, mult = ff_mult)
 
             self.layers.append(ModuleList([
@@ -702,7 +696,7 @@ class MemoryAsContextTransformer(Module):
 
         # math
 
-        batch, seq_len, neural_mem_segment_len, segment_len, num_longterm_mem_tokens, attn_window_size, weight_tie_memory_model = *x.shape, self.neural_memory_segment_len, self.segment_len, self.num_longterm_mem_tokens, self.attn_window_size, self.weight_tie_memory_model
+        batch, seq_len, neural_mem_segment_len, segment_len, num_longterm_mem_tokens, attn_window_size, prev_neural_mem_update_for_weights = *x.shape, self.neural_memory_segment_len, self.segment_len, self.num_longterm_mem_tokens, self.attn_window_size, self.prev_neural_mem_update_for_weights
 
         seq_len_with_mem = self.seq_len_with_longterm_mem(seq_len)
 
@@ -763,8 +757,6 @@ class MemoryAsContextTransformer(Module):
 
         value_residual = None
 
-        mem_value_residual = None
-
         # aux losses
 
         kv_recon_losses = self.zero
@@ -792,29 +784,22 @@ class MemoryAsContextTransformer(Module):
                 mem_input, add_residual = mem_hyper_conn(x)
 
                 if not is_inferencing:
-                    (retrieved, next_neural_mem_cache, next_mem_value_residual), mem_kv_aux_loss = mem(
+                    (retrieved, next_neural_mem_cache), mem_kv_aux_loss = mem(
                         mem_input,
                         return_aux_kv_loss = True,
-                        return_values = True,
-                        value_residual = mem_value_residual,
                         prev_layer_updates = neural_memory_updates
                     )
 
                     kv_recon_losses = kv_recon_losses + mem_kv_aux_loss
 
                 else:
-                    (retrieved, next_neural_mem_cache, next_mem_value_residual) = mem.forward_inference(
+                    (retrieved, next_neural_mem_cache) = mem.forward_inference(
                         mem_input,
                         state = next(neural_mem_caches, None),
-                        return_values = True,
-                        value_residual = mem_value_residual,
                         prev_layer_updates = neural_memory_updates
                     )
 
-                if self.mem_add_value_residual:
-                    mem_value_residual = next_mem_value_residual
-
-                if weight_tie_memory_model:
+                if prev_neural_mem_update_for_weights:
                     neural_memory_updates = next_neural_mem_cache.updates
 
                 if self.gate_attn_output:
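Net effect of the transformer-side hunks above: the per-layer value-residual plumbing for the neural memory is removed, replaced by a single `prev_neural_mem_update_for_weights` flag that, when left unset, inherits `weight_tie_memory_model`. A minimal runnable sketch of that fallback, assuming the repo's `default(v, d)` helper returns `v` when given and `d` otherwise:

    def default(v, d):
        # the repo's helper, reproduced here for illustration
        return v if v is not None else d

    # unset -> inherit the weight-tying flag, so weight-tied memory models
    # keep feeding each layer's weight updates into the next, as before
    for weight_tie in (False, True):
        assert default(None, weight_tie) == weight_tie

    # an explicit value always overrides the fallback
    assert default(False, True) is False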
@@ -67,6 +67,9 @@ def safe_cat(inputs, dim = -2):
 def identity(t):
     return t
 
+def dict_get_shape(td):
+    return {k: v.shape for k, v in td.items()}
+
 def pair(v):
     return (v, v) if not isinstance(v, tuple) else v
 
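`dict_get_shape` simply maps a dict of tensors to their shapes; the retrieval path further down compares it against `self.init_weight_shape` to decide whether the memory weights have picked up leading batch/chunk dimensions. A small standalone illustration:

    import torch
    from torch import nn

    def dict_get_shape(td):
        return {k: v.shape for k, v in td.items()}

    model = nn.Linear(4, 4, bias = False)
    init_shape = dict_get_shape(dict(model.named_parameters()))
    # {'weight': torch.Size([4, 4])}

    # per-chunk updates add leading dims, so the shapes stop matching
    # and retrieve_memories knows it must flatten them first
    batched = {'weight': torch.randn(2, 3, 4, 4)}
    assert dict_get_shape(batched) != init_shape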
@@ -258,7 +261,6 @@ class NeuralMemory(Module):
         pre_rmsnorm = True,
         post_rmsnorm = True,
         qk_rmsnorm = False,
-        accept_value_residual = False,
         max_grad_norm: float | None = None,
         use_accelerated_scan = False,
         activation: Module | None = None,
@@ -315,6 +317,8 @@ class NeuralMemory(Module):
 
         self.num_memory_parameter_tensors = len(set(model.parameters()))
 
+        self.init_weight_shape = dict_get_shape(dict(model.named_parameters()))
+
         # the chunk size within the paper where adaptive step, momentum, weight decay are shared
 
         self.chunk_size = chunk_size
@@ -343,19 +347,6 @@ class NeuralMemory(Module):
         self.to_keys_values = Sequential(LinearNoBias(dim, dim_inner * 2), activation)
         self.store_memory_loss_fn = store_memory_loss_fn
 
-        # value residual learning
-
-        self.learned_value_residual = Sequential(
-            LinearNoBias(dim, heads),
-            Rearrange('b n h -> b h n 1'),
-            nn.Sigmoid()
-        ) if accept_value_residual else None
-
-        # empty memory embed
-
-        self.empty_memory_embed = nn.Parameter(torch.zeros(dim))
-        nn.init.normal_(self.empty_memory_embed, std = 0.02)
-
         # `chunk_size` refers to chunk size used for storing to memory model weights
 
         chunk_size = self.store_chunk_size
@@ -417,9 +408,6 @@ class NeuralMemory(Module):
         weights = TensorDict(dict(self.memory_model.named_parameters()))
         return weights
 
-    def init_empty_memory_embed(self, batch, seq_len):
-        return repeat(self.empty_memory_embed, 'd -> b n d', b = batch, n = seq_len)
-
     def store_memories(
         self,
         seq,
@@ -428,10 +416,7 @@ class NeuralMemory(Module):
         prev_layer_updates: dict[str, Tensor] | None = None,
         return_aux_kv_loss = False,
         chunk_size = None,
-        value_residual = None
     ):
-        assert xnor(exists(value_residual), exists(self.learned_value_residual))
-
         seq_len, heads, chunk_size = seq.shape[-2], self.heads, default(chunk_size, self.store_chunk_size)
 
         # handle edge case
@@ -446,7 +431,7 @@ class NeuralMemory(Module):
 
         round_down_seq_len = round_down_multiple(seq_len, chunk_size)
 
-        seq = seq[:, :round_down_seq_len]
+        seq, remainder = seq[:, :round_down_seq_len], seq[:, round_down_seq_len:]
 
         # per sample grad function
 
@@ -499,14 +484,6 @@ class NeuralMemory(Module):
 
         keys = self.k_norm(keys)
 
-        # maybe value residual learning
-
-        orig_values = values
-
-        if exists(self.learned_value_residual):
-            mix = self.learned_value_residual(seq)
-            values = values.lerp(value_residual, mix)
-
         # take care of chunking
 
         keys, values = tuple(rearrange(t, 'b h (n c) d -> (b h n) c d', c = chunk_size) for t in (keys, values))
@@ -581,13 +558,15 @@ class NeuralMemory(Module):
             if has_momentum:
                 next_momentum[param_name] = inverse_pack(momentum)
 
-        # compute next states for inference, or titans-xl like training
+        # determine next state for the storing of memories
 
         next_state = (next_last_update, next_last_momentum)
 
+        next_store_state = NeuralMemCache(seq_len, remainder, next_state, updates)
+
         # returns
 
-        output = (updates, next_state, orig_values)
+        output = (updates, next_store_state)
 
         if not return_aux_kv_loss:
             return output
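With this hunk, `store_memories` itself splits the store sequence into whole chunks plus a leftover tail (the `remainder` from the earlier hunk) and bundles everything into the `NeuralMemCache` it returns, instead of `forward` recomputing the remainder afterwards. A hedged sketch of the split, using a hypothetical `split_at_multiple` stand-in:

    import torch

    def split_at_multiple(seq, chunk_size):
        # mirrors the new store path: full chunks are written to memory,
        # the tail is cached so inference can finish the chunk later
        round_down = (seq.shape[1] // chunk_size) * chunk_size
        return seq[:, :round_down], seq[:, round_down:]

    seq = torch.randn(1, 10, 8)                  # (batch, len, dim)
    main, remainder = split_at_multiple(seq, 4)
    assert main.shape[1] == 8 and remainder.shape[1] == 2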
@@ -606,16 +585,18 @@ class NeuralMemory(Module):
 
         seq = self.retrieve_norm(seq)
 
-        if seq_len < chunk_size:
-            return self.init_empty_memory_embed(batch, seq_len)
+        assert seq_len >= chunk_size, 'must be handled outside of retrieve'
+
+        needs_pad = chunk_size > 1
 
-        seq = seq[:, (chunk_size - 1):]
-        curtailed_seq_len = seq.shape[-2]
+        if needs_pad:
+            seq = pad_at_dim(seq, (1, 0), dim = 1)
+        seq_len_plus_one = seq.shape[-2]
 
-        next_seq_len = round_up_multiple(curtailed_seq_len, chunk_size)
+        next_seq_len = round_up_multiple(seq_len_plus_one, chunk_size)
 
-        padding = next_seq_len - curtailed_seq_len
-        seq = pad_at_dim(seq, (0, padding), dim = 1)
+        padding = next_seq_len - seq_len_plus_one
+        seq = pad_at_dim(seq, (0, padding), dim = 1)
 
         # the parameters of the memory model stores the memories of the key / values
         # when the MLP has only 1 weight matrix, it is equivalent to `kv` fast weight memories from linear attention literature (recall fetching of memories is q @ (kv)) / schmidhuber's paper
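Retrieval now left-pads the normed sequence by one position and rounds up to a chunk multiple, rather than curtailing the first `chunk_size - 1` tokens and back-filling with the (now removed) learned empty-memory embed; a later hunk slices `values[:, 1:(seq_len + 1)]` to undo the shift. A worked example of the padding arithmetic, with minimal stand-ins for the repo's `pad_at_dim` and `round_up_multiple` helpers:

    import torch
    import torch.nn.functional as F

    def round_up_multiple(n, mult):
        return ((n + mult - 1) // mult) * mult

    def pad_at_dim(t, pad, dim = -1):
        # pad = (left, right) along `dim`; simplified stand-in
        dims_from_right = (t.ndim - dim - 1) if dim >= 0 else (-dim - 1)
        return F.pad(t, (0, 0) * dims_from_right + pad)

    seq = torch.randn(1, 10, 8)                    # (batch, seq_len, dim)
    chunk_size = 4

    seq = pad_at_dim(seq, (1, 0), dim = 1)         # 10 -> 11, shift right by one
    seq_len_plus_one = seq.shape[-2]
    next_seq_len = round_up_multiple(seq_len_plus_one, chunk_size)   # 12
    seq = pad_at_dim(seq, (0, next_seq_len - seq_len_plus_one), dim = 1)
    assert seq.shape[-2] == 12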
@@ -639,7 +620,9 @@ class NeuralMemory(Module):
 
         # fetch values from memory model
 
-        curr_weights = curr_weights.apply(lambda t: rearrange(t, 'b n ... -> (b n) ...'))
+        if dict_get_shape(curr_weights) != self.init_weight_shape:
+            curr_weights = curr_weights.apply(lambda t: rearrange(t, 'b n ... -> (b n) ...'))
+
         queries = rearrange(queries, 'b h (n c) d -> (b h n) c d', c = chunk_size)
 
         # forward functional call
@@ -665,10 +648,10 @@ class NeuralMemory(Module):
 
         # restore, pad with empty memory embed
 
-        empty_memory_embeds = self.init_empty_memory_embed(values.shape[0], chunk_size - 1)
-        values = torch.cat((empty_memory_embeds, values), dim = -2)
+        if needs_pad:
+            values = values[:, 1:(seq_len + 1)]
 
-        return values[:, :seq_len]
+        return values
 
     @torch.no_grad()
     def forward_inference(
@@ -676,8 +659,6 @@ class NeuralMemory(Module):
         token: Tensor,
         state = None,
         prev_layer_updates: dict[str, Tensor] | None = None,
-        return_values = False,
-        value_residual = None,
     ):
 
         # unpack previous state
@@ -704,12 +685,9 @@ class NeuralMemory(Module):
         # early return empty memory, when no memories are stored for steps < first chunk size
 
         if curr_seq_len < self.chunk_size:
-            empty_mem = self.init_empty_memory_embed(batch, 1)
-
-            output = empty_mem, NeuralMemCache(curr_seq_len, cache_store_seq, past_states, updates)
+            retrieve = self.retrieve_memories(token, weights, chunk_size = 1)
 
-            if return_values:
-                output = (*output, self.zero)
+            output = retrieve, NeuralMemCache(curr_seq_len, cache_store_seq, past_states, updates)
 
             return output
 
@@ -728,20 +706,18 @@ class NeuralMemory(Module):
             prev_layer_updates = TensorDict(prev_layer_updates)
             prev_layer_updates = prev_layer_updates.apply(lambda t: t[:, -1:])
 
-        values = None
-
         if store_seq_cache_len == self.chunk_size:
 
-            next_updates, next_states, values = self.store_memories(
+            next_updates, store_state = self.store_memories(
                 cache_store_seq,
                 weights,
                 past_state = past_states,
                 prev_layer_updates = prev_layer_updates,
-                value_residual = value_residual
             )
 
             updates = next_updates
             cache_store_seq = None
+            next_states = store_state.states
 
         # retrieve
 
@@ -749,14 +725,9 @@ class NeuralMemory(Module):
 
         # next state tuple
 
-        next_state = NeuralMemCache(curr_seq_len, cache_store_seq, next_states, updates)
-
-        output = (retrieved, next_state)
-
-        if return_values:
-            output = (*output, values)
+        next_store_state = NeuralMemCache(curr_seq_len, cache_store_seq, next_states, updates)
 
-        return output
+        return retrieved, next_store_state
 
     def forward(
         self,
@@ -767,50 +738,45 @@ class NeuralMemory(Module):
         return_aux_kv_loss = False,
         chunk_size = None,
         store_chunk_size = None,
-        return_values = False,
-        value_residual = None,
         return_next_state = False,
         prev_layer_updates: dict[str, Tensor] | None = None
     ):
         batch, seq_len = seq.shape[:2]
 
+        if not exists(mem_model_weights):
+            mem_model_weights = self.init_weights()
+
         if seq_len < self.retrieve_chunk_size:
-            out = self.init_empty_memory_embed(batch, seq_len)
+            retrieved = self.retrieve_memories(seq, mem_model_weights, chunk_size = 1)
 
             next_store_state = NeuralMemCache(seq_len, seq, None, None)
 
-            out = (out, next_store_state)
-
-            if return_values:
-                out = (*out, self.zero)
+            out = (retrieved, next_store_state)
 
             if not return_aux_kv_loss:
                 return out
 
             return out, self.zero
 
-        if not exists(mem_model_weights):
-            mem_model_weights = self.init_weights()
-
         # store
 
         store_seq = default(store_seq, seq)
 
-        store_seq_len = store_seq.shape[-2]
-        store_chunk_size = default(store_chunk_size, chunk_size, self.store_chunk_size)
-        remainder = store_seq_len % store_chunk_size
-
-        (updates, next_state, values), aux_kv_recon_loss = self.store_memories(
+        (updates, next_store_state), aux_kv_recon_loss = self.store_memories(
             store_seq,
             mem_model_weights,
             chunk_size = store_chunk_size,
             prev_layer_updates = prev_layer_updates,
-            value_residual = value_residual,
             return_aux_kv_loss = True
         )
 
         # retrieve
 
+        if exists(prev_layer_updates):
+            prev_layer_updates = prev_layer_updates.apply(lambda t: pad_at_dim(t, (1, 0), dim = 1))
+
+        updates = updates.apply(lambda t: pad_at_dim(t, (1, 0), dim = 1))
+
         retrieved = self.retrieve_memories(
             seq,
             mem_model_weights + updates,
@@ -818,21 +784,8 @@ class NeuralMemory(Module):
             prev_layer_updates = prev_layer_updates
         )
 
-        # determine state for the storing of memories
-        # for transformer-xl like training with neural memory as well as inferencing with initial prompt
-
-        cache_store_seq = None
-
-        if remainder > 0:
-            cache_store_seq = store_seq[:, -remainder:]
-
-        next_store_state = NeuralMemCache(seq_len, cache_store_seq, next_state, updates)
-
         output = (retrieved, next_store_state)
 
-        if return_values:
-            output = (*output, values)
-
        if not return_aux_kv_loss:
             return output
 
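After all of the above, `NeuralMemory`'s call surface simplifies: no `return_values` / `value_residual`, weights are lazily initialized up front, and the result is always a `(retrieved, next_store_state)` pair (plus the auxiliary loss when `return_aux_kv_loss = True`). A usage sketch against 0.2.4, assuming the constructor arguments shown in the hunks above:

    import torch
    from titans_pytorch import NeuralMemory, MemoryMLP

    mem = NeuralMemory(
        dim = 64,
        chunk_size = 16,
        model = MemoryMLP(dim = 64, depth = 2)
    )

    seq = torch.randn(2, 64, 64)

    retrieved, next_store_state = mem(seq)    # new two-tuple return
    assert retrieved.shape == seq.shape

    # next_store_state is a NeuralMemCache of
    # (seq_len, unprocessed remainder, optimizer state, weight updates),
    # threaded through forward_inference during decoding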
@@ -34,7 +34,6 @@ NEURAL_MEM_LAYERS = (2, 4, 6) # layers 2, 4, 6 have neural memory,
 NEURAL_MEM_GATE_ATTN_OUTPUT = False
 NEURAL_MEM_MOMENTUM = True
 NEURAL_MEM_QK_NORM = False
-NEURAL_MEM_ADD_VALUE_RESIDUAL = False
 WINDOW_SIZE = 32
 NEURAL_MEM_SEGMENT_LEN = WINDOW_SIZE // 2 # set smaller for more granularity for learning rate / momentum etc
 SLIDING_WINDOWS = True
@@ -91,7 +90,6 @@ model = MemoryAsContextTransformer(
     use_flex_attn = USE_FLEX_ATTN,
     sliding_window_attn = SLIDING_WINDOWS,
     weight_tie_memory_model = WEIGHT_TIE_MEMORY_MODEL,
-    neural_memory_add_value_residual = NEURAL_MEM_ADD_VALUE_RESIDUAL,
     neural_memory_model = MemoryMLP(
         dim = 64,
         depth = NEURAL_MEMORY_DEPTH
@@ -164,6 +162,6 @@ for i in tqdm.tqdm(range(NUM_BATCHES), mininterval = 10., desc = 'training'):
         prime = decode_tokens(inp)
         print(f'%s \n\n %s', (prime, '*' * 100))
 
-        sample = model.sample(inp[None, ...], GENERATE_LENGTH, use_cache = True)
+        sample = model.sample(inp[None, ...], GENERATE_LENGTH, use_cache = USE_FAST_INFERENCE)
         output_str = decode_tokens(sample[0])
         print(output_str)