titans-pytorch 0.2.1__py3-none-any.whl → 0.2.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -488,7 +488,7 @@ class MemoryAsContextTransformer(Module):
  neural_memory_model: Module | None = None,
  neural_memory_kwargs: dict = dict(),
  neural_memory_layers: tuple[int, ...] | None = None,
- aux_kv_recon_loss_weight = 0.,
+ aux_kv_recon_loss_weight = 1.,
  use_flex_attn = False,
  sliding_window_attn = False,
  weight_tie_memory_model = False,
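
The only functional change in this hunk is the default auxiliary key/value reconstruction loss weight flipping from 0. to 1., so the auxiliary term is on by default. A minimal sketch of how such a weight typically folds into the training objective (the tensor names below are illustrative, not taken from the diff):

    import torch

    # stand-ins for the transformer's losses
    ar_loss = torch.tensor(2.3)          # autoregressive cross entropy
    kv_recon_losses = torch.tensor(0.7)  # summed per-layer memory reconstruction losses

    aux_kv_recon_loss_weight = 1.        # new default; 0. previously zeroed the term out

    total_loss = ar_loss + aux_kv_recon_loss_weight * kv_recon_losses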
@@ -536,11 +536,6 @@ class MemoryAsContextTransformer(Module):
  self.weight_tie_memory_model = weight_tie_memory_model
  self.prev_neural_mem_update_for_weights = default(prev_neural_mem_update_for_weights, weight_tie_memory_model)

- # value residual learning for neural memory
-
- is_first_mem = True
- self.mem_add_value_residual = neural_memory_add_value_residual
-
  # mem, attn, and feedforward layers

  for layer in layers:
@@ -570,12 +565,9 @@ class MemoryAsContextTransformer(Module):
      dim = dim,
      chunk_size = self.neural_memory_segment_len,
      model = maybe_copy(neural_memory_model),
-     accept_value_residual = not is_first_mem and neural_memory_add_value_residual,
      **neural_memory_kwargs
  )

- is_first_mem = False
-
  ff = FeedForward(dim = dim, mult = ff_mult)

  self.layers.append(ModuleList([
@@ -765,8 +757,6 @@ class MemoryAsContextTransformer(Module):

  value_residual = None

- mem_value_residual = None
-
  # aux losses

  kv_recon_losses = self.zero
@@ -794,28 +784,21 @@ class MemoryAsContextTransformer(Module):
  mem_input, add_residual = mem_hyper_conn(x)

  if not is_inferencing:
-     (retrieved, next_neural_mem_cache, next_mem_value_residual), mem_kv_aux_loss = mem(
+     (retrieved, next_neural_mem_cache), mem_kv_aux_loss = mem(
          mem_input,
          return_aux_kv_loss = True,
-         return_values = True,
-         value_residual = mem_value_residual,
          prev_layer_updates = neural_memory_updates
      )

      kv_recon_losses = kv_recon_losses + mem_kv_aux_loss

  else:
-     (retrieved, next_neural_mem_cache, next_mem_value_residual) = mem.forward_inference(
+     (retrieved, next_neural_mem_cache) = mem.forward_inference(
          mem_input,
          state = next(neural_mem_caches, None),
-         return_values = True,
-         value_residual = mem_value_residual,
          prev_layer_updates = neural_memory_updates
      )

- if self.mem_add_value_residual:
-     mem_value_residual = next_mem_value_residual
-
  if prev_neural_mem_update_for_weights:
      neural_memory_updates = next_neural_mem_cache.updates

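This hunk and the surrounding ones strip value residual learning out of the memory path: is_first_mem, mem_add_value_residual, and the threading of mem_value_residual through each layer's memory call are all gone. For reference, the removed mechanism amounted to mixing each memory layer's values toward values carried over from an earlier memory layer with a learned gate, roughly like the sketch below (illustrative shapes and names, not the package's code):

    import torch

    # illustrative shapes: (batch, heads, seq, dim_head)
    values = torch.randn(1, 4, 128, 64)           # this layer's values
    residual_values = torch.randn(1, 4, 128, 64)  # values passed along from an earlier memory layer
    mix = torch.rand(1, 4, 128, 1)                # learned gate in [0, 1], one scalar per position

    # the removed `values.lerp(value_residual, mix)` step
    mixed_values = values.lerp(residual_values, mix)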
@@ -1,5 +1,5 @@
  import torch
- from torch import nn
+ from torch import nn, cat
  import torch.nn.functional as F
  from torch.nn import Module, ModuleList, Parameter, ParameterList

@@ -67,6 +67,9 @@ def safe_cat(inputs, dim = -2):
  def identity(t):
      return t

+ def dict_get_shape(td):
+     return {k: v.shape for k, v in td.items()}
+
  def pair(v):
      return (v, v) if not isinstance(v, tuple) else v

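dict_get_shape is a tiny helper the neural memory later uses to compare a parameter dict against the memory model's initial weight shapes. Roughly how it behaves on a module's parameters (illustrative usage, not part of the diff):

    import torch
    from torch import nn

    def dict_get_shape(td):
        return {k: v.shape for k, v in td.items()}

    mlp = nn.Linear(64, 64)
    print(dict_get_shape(dict(mlp.named_parameters())))
    # {'weight': torch.Size([64, 64]), 'bias': torch.Size([64])}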
@@ -258,7 +261,6 @@ class NeuralMemory(Module):
  pre_rmsnorm = True,
  post_rmsnorm = True,
  qk_rmsnorm = False,
- accept_value_residual = False,
  max_grad_norm: float | None = None,
  use_accelerated_scan = False,
  activation: Module | None = None,
@@ -302,19 +304,34 @@ class NeuralMemory(Module):
      nn.Sigmoid()
  ) if heads > 1 else None

- # memory mlp
+ # memory model

  if not exists(model):
      model = MemoryMLP(dim_head, **default_model_kwargs)

+ # validate memory model
+
  assert not exists(next(model.buffers(), None)), 'model cannot have buffers for now'

+ test_shape = (3, 2, dim_head)
+
+ with torch.no_grad():
+     try:
+         test_input = torch.randn(test_shape)
+         mem_model_output = model(test_input)
+     except:
+         raise RuntimeError(f'memory model unable to accept a tensor of shape {test_shape}')
+
+ assert mem_model_output.shape == test_shape, 'output of memory model needs to be same shape as input'
+
  # the memory is the weights of the model

  self.memory_model = model

  self.num_memory_parameter_tensors = len(set(model.parameters()))

+ self.init_weight_shape = dict_get_shape(dict(model.named_parameters()))
+
  # the chunk size within the paper where adaptive step, momentum, weight decay are shared

  self.chunk_size = chunk_size
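The new validation block means any custom neural_memory_model must map a (3, 2, dim_head) probe tensor to an output of the same shape and must not register buffers. A minimal sketch of a model that would pass the check (this particular module is illustrative, not part of the package):

    import torch
    from torch import nn

    class TinyMemory(nn.Module):
        """Shape-preserving MLP over the last dimension, with no buffers."""
        def __init__(self, dim):
            super().__init__()
            self.w1 = nn.Parameter(torch.randn(dim, dim * 2) * 0.02)
            self.w2 = nn.Parameter(torch.randn(dim * 2, dim) * 0.02)

        def forward(self, x):
            return torch.nn.functional.gelu(x @ self.w1) @ self.w2

    model = TinyMemory(dim = 64)

    with torch.no_grad():
        out = model(torch.randn(3, 2, 64))

    assert out.shape == (3, 2, 64)  # same check the constructor now performs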
@@ -343,19 +360,6 @@ class NeuralMemory(Module):
  self.to_keys_values = Sequential(LinearNoBias(dim, dim_inner * 2), activation)
  self.store_memory_loss_fn = store_memory_loss_fn

- # value residual learning
-
- self.learned_value_residual = Sequential(
-     LinearNoBias(dim, heads),
-     Rearrange('b n h -> b h n 1'),
-     nn.Sigmoid()
- ) if accept_value_residual else None
-
- # empty memory embed
-
- self.empty_memory_embed = nn.Parameter(torch.zeros(dim))
- nn.init.normal_(self.empty_memory_embed, std = 0.02)
-
  # `chunk_size` refers to chunk size used for storing to memory model weights

  chunk_size = self.store_chunk_size
@@ -417,9 +421,6 @@ class NeuralMemory(Module):
  weights = TensorDict(dict(self.memory_model.named_parameters()))
  return weights

- def init_empty_memory_embed(self, batch, seq_len):
-     return repeat(self.empty_memory_embed, 'd -> b n d', b = batch, n = seq_len)
-
  def store_memories(
      self,
      seq,
@@ -428,10 +429,7 @@ class NeuralMemory(Module):
  prev_layer_updates: dict[str, Tensor] | None = None,
  return_aux_kv_loss = False,
  chunk_size = None,
- value_residual = None
  ):
- assert xnor(exists(value_residual), exists(self.learned_value_residual))
-
  seq_len, heads, chunk_size = seq.shape[-2], self.heads, default(chunk_size, self.store_chunk_size)

  # handle edge case
@@ -446,7 +444,7 @@ class NeuralMemory(Module):

  round_down_seq_len = round_down_multiple(seq_len, chunk_size)

- seq = seq[:, :round_down_seq_len]
+ seq, remainder = seq[:, :round_down_seq_len], seq[:, round_down_seq_len:]

  # per sample grad function

@@ -499,14 +497,6 @@ class NeuralMemory(Module):

  keys = self.k_norm(keys)

- # maybe value residual learning
-
- orig_values = values
-
- if exists(self.learned_value_residual):
-     mix = self.learned_value_residual(seq)
-     values = values.lerp(value_residual, mix)
-
  # take care of chunking

  keys, values = tuple(rearrange(t, 'b h (n c) d -> (b h n) c d', c = chunk_size) for t in (keys, values))
@@ -581,13 +571,15 @@ class NeuralMemory(Module):
  if has_momentum:
      next_momentum[param_name] = inverse_pack(momentum)

- # compute next states for inference, or titans-xl like training
+ # determine next state for the storing of memories

  next_state = (next_last_update, next_last_momentum)

+ next_store_state = NeuralMemCache(seq_len, remainder, next_state, updates)
+
  # returns

- output = (updates, next_state, orig_values)
+ output = (updates, next_store_state)

  if not return_aux_kv_loss:
      return output
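With this change store_memories itself packages the next cache, including the leftover tokens that did not fill a full chunk (the new remainder). Judging from how it is constructed and consumed in the surrounding hunks, NeuralMemCache carries roughly these four fields; the sketch below uses assumed field names inferred from the call sites, not the package's actual definition:

    from collections import namedtuple

    # assumed field names, inferred from the call sites in this diff
    NeuralMemCache = namedtuple('NeuralMemCache', [
        'seq_len',          # how many tokens have been processed so far
        'cache_store_seq',  # leftover tokens not yet written to memory (the remainder)
        'states',           # (last_update, last_momentum) of the memory weights
        'updates'           # per-chunk weight updates used for retrieval
    ])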
@@ -606,16 +598,18 @@ class NeuralMemory(Module):

  seq = self.retrieve_norm(seq)

- if seq_len < chunk_size:
-     return self.init_empty_memory_embed(batch, seq_len)
+ assert seq_len >= chunk_size, 'must be handled outside of retrieve'
+
+ needs_pad = chunk_size > 1

- seq = seq[:, (chunk_size - 1):]
- curtailed_seq_len = seq.shape[-2]
+ if needs_pad:
+     seq = pad_at_dim(seq, (1, 0), dim = 1)
+ seq_len_plus_one = seq.shape[-2]

- next_seq_len = round_up_multiple(curtailed_seq_len, chunk_size)
+ next_seq_len = round_up_multiple(seq_len_plus_one, chunk_size)

- padding = next_seq_len - curtailed_seq_len
- seq = pad_at_dim(seq, (0, padding), dim = 1)
+ padding = next_seq_len - seq_len_plus_one
+ seq = pad_at_dim(seq, (0, padding), dim = 1)

  # the parameters of the memory model stores the memories of the key / values
  # when the MLP has only 1 weight matrix, it is equivalent to `kv` fast weight memories from linear attention literature (recall fetching of memories is q @ (kv)) / schmidhuber's paper
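Retrieval no longer falls back to a learned empty_memory_embed; instead the sequence is left-padded by one position (when chunk_size > 1), apparently so that retrieval is offset by one position relative to the stored updates, and the pad is sliced off again after the memory lookup (see the later values[:, 1:(seq_len + 1)] hunk). A small sketch of that pad/slice round trip; pad_at_dim is the helper name used in the package, reimplemented here only for illustration:

    import torch
    import torch.nn.functional as F

    def pad_at_dim(t, pad, dim = -1, value = 0.):
        # pad `pad = (left, right)` positions along `dim` (stand-in for the package helper)
        dims_from_right = (t.ndim - dim - 1) if dim >= 0 else (-dim - 1)
        zeros = (0, 0) * dims_from_right
        return F.pad(t, (*zeros, *pad), value = value)

    seq = torch.randn(2, 5, 16)                  # (batch, seq_len, dim)
    seq_len = seq.shape[1]

    padded = pad_at_dim(seq, (1, 0), dim = 1)    # shift right by one position: (2, 6, 16)
    values = padded                              # imagine the memory read happening here
    values = values[:, 1:(seq_len + 1)]          # drop the pad again -> (2, 5, 16)

    assert values.shape == seq.shape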
@@ -639,7 +633,9 @@ class NeuralMemory(Module):

  # fetch values from memory model

- curr_weights = curr_weights.apply(lambda t: rearrange(t, 'b n ... -> (b n) ...'))
+ if dict_get_shape(curr_weights) != self.init_weight_shape:
+     curr_weights = curr_weights.apply(lambda t: rearrange(t, 'b n ... -> (b n) ...'))
+
  queries = rearrange(queries, 'b h (n c) d -> (b h n) c d', c = chunk_size)

  # forward functional call
@@ -665,10 +661,10 @@ class NeuralMemory(Module):

  # restore, pad with empty memory embed

- empty_memory_embeds = self.init_empty_memory_embed(values.shape[0], chunk_size - 1)
- values = torch.cat((empty_memory_embeds, values), dim = -2)
+ if needs_pad:
+     values = values[:, 1:(seq_len + 1)]

- return values[:, :seq_len]
+ return values

  @torch.no_grad()
  def forward_inference(
@@ -676,8 +672,6 @@ class NeuralMemory(Module):
  token: Tensor,
  state = None,
  prev_layer_updates: dict[str, Tensor] | None = None,
- return_values = False,
- value_residual = None,
  ):

  # unpack previous state
@@ -704,12 +698,9 @@ class NeuralMemory(Module):
  # early return empty memory, when no memories are stored for steps < first chunk size

  if curr_seq_len < self.chunk_size:
-     empty_mem = self.init_empty_memory_embed(batch, 1)
-
-     output = empty_mem, NeuralMemCache(curr_seq_len, cache_store_seq, past_states, updates)
+     retrieve = self.retrieve_memories(token, weights, chunk_size = 1)

-     if return_values:
-         output = (*output, self.zero)
+     output = retrieve, NeuralMemCache(curr_seq_len, cache_store_seq, past_states, updates)

      return output

@@ -728,20 +719,18 @@ class NeuralMemory(Module):
  prev_layer_updates = TensorDict(prev_layer_updates)
  prev_layer_updates = prev_layer_updates.apply(lambda t: t[:, -1:])

- values = None
-
  if store_seq_cache_len == self.chunk_size:

-     next_updates, next_states, values = self.store_memories(
+     next_updates, store_state = self.store_memories(
          cache_store_seq,
          weights,
          past_state = past_states,
          prev_layer_updates = prev_layer_updates,
-         value_residual = value_residual
      )

      updates = next_updates
      cache_store_seq = None
+     next_states = store_state.states

  # retrieve

@@ -749,14 +738,9 @@ class NeuralMemory(Module):

  # next state tuple

- next_state = NeuralMemCache(curr_seq_len, cache_store_seq, next_states, updates)
+ next_store_state = NeuralMemCache(curr_seq_len, cache_store_seq, next_states, updates)

- output = (retrieved, next_state)
-
- if return_values:
-     output = (*output, values)
-
- return output
+ return retrieved, next_store_state

  def forward(
      self,
@@ -767,50 +751,45 @@ class NeuralMemory(Module):
  return_aux_kv_loss = False,
  chunk_size = None,
  store_chunk_size = None,
- return_values = False,
- value_residual = None,
  return_next_state = False,
  prev_layer_updates: dict[str, Tensor] | None = None
  ):
  batch, seq_len = seq.shape[:2]

+ if not exists(mem_model_weights):
+     mem_model_weights = self.init_weights()
+
  if seq_len < self.retrieve_chunk_size:
-     out = self.init_empty_memory_embed(batch, seq_len)
+     retrieved = self.retrieve_memories(seq, mem_model_weights, chunk_size = 1)

      next_store_state = NeuralMemCache(seq_len, seq, None, None)

-     out = (out, next_store_state)
-
-     if return_values:
-         out = (*out, self.zero)
+     out = (retrieved, next_store_state)

      if not return_aux_kv_loss:
          return out

      return out, self.zero

- if not exists(mem_model_weights):
-     mem_model_weights = self.init_weights()
-
  # store

  store_seq = default(store_seq, seq)

- store_seq_len = store_seq.shape[-2]
- store_chunk_size = default(store_chunk_size, chunk_size, self.store_chunk_size)
- remainder = store_seq_len % store_chunk_size
-
- (updates, next_state, values), aux_kv_recon_loss = self.store_memories(
+ (updates, next_store_state), aux_kv_recon_loss = self.store_memories(
      store_seq,
      mem_model_weights,
      chunk_size = store_chunk_size,
      prev_layer_updates = prev_layer_updates,
-     value_residual = value_residual,
      return_aux_kv_loss = True
  )

  # retrieve

+ if exists(prev_layer_updates):
+     prev_layer_updates = prev_layer_updates.apply(lambda t: pad_at_dim(t, (1, 0), dim = 1))
+
+ updates = updates.apply(lambda t: pad_at_dim(t, (1, 0), dim = 1))
+
  retrieved = self.retrieve_memories(
      seq,
      mem_model_weights + updates,
@@ -818,21 +797,8 @@ class NeuralMemory(Module):
      prev_layer_updates = prev_layer_updates
  )

- # determine state for the storing of memories
- # for transformer-xl like training with neural memory as well as inferencing with initial prompt
-
- cache_store_seq = None
-
- if remainder > 0:
-     cache_store_seq = store_seq[:, -remainder:]
-
- next_store_state = NeuralMemCache(seq_len, cache_store_seq, next_state, updates)
-
  output = (retrieved, next_store_state)

- if return_values:
-     output = (*output, values)
-

  if not return_aux_kv_loss:
      return output
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: titans-pytorch
- Version: 0.2.1
+ Version: 0.2.5
  Summary: Titans
  Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
  Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -0,0 +1,9 @@
+ titans_pytorch/__init__.py,sha256=Y3m_ZlpEqYwp-Md1ARhNGJxq8bQp8ty1o039nZOOJo0,276
+ titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+ titans_pytorch/mac_transformer.py,sha256=UOJAMv7nTgkefBB7M7K3U0NnFkz75tFRG5WLXRdfnLw,26039
+ titans_pytorch/memory_models.py,sha256=Ew28waD9gf1wn-5Nkdc676u1I92IqzaOAw-tv0JXMwc,3777
+ titans_pytorch/neural_memory.py,sha256=YiBsMiqYn-Hva4yhxfaqkGV857vZIASxi5Z0TT0FC10,24606
+ titans_pytorch-0.2.5.dist-info/METADATA,sha256=x3RePuTDf3rUT3vtvge1X3Ry18Y3tV_swCgycbtSCjQ,6819
+ titans_pytorch-0.2.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+ titans_pytorch-0.2.5.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+ titans_pytorch-0.2.5.dist-info/RECORD,,
@@ -1,9 +0,0 @@
- titans_pytorch/__init__.py,sha256=Y3m_ZlpEqYwp-Md1ARhNGJxq8bQp8ty1o039nZOOJo0,276
- titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
- titans_pytorch/mac_transformer.py,sha256=kqW90mpbFf1ZJ_mMkd6v9EQ5J__TwKMPy5cjHJF_26A,26742
- titans_pytorch/memory_models.py,sha256=LI9T36XB6YXIvvGWRw0ZMDlGpRC6KIv03OPzME2VAaU,3772
- titans_pytorch/neural_memory.py,sha256=vmKPOAlXBPXBnYPODrg_reWaIcr1xwtfQmuptGS6e5A,25559
- titans_pytorch-0.2.1.dist-info/METADATA,sha256=HPdcQb4SlT-eLFzOYLMwGInEKegL4M4yIpKWt1a6DTs,6819
- titans_pytorch-0.2.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
- titans_pytorch-0.2.1.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
- titans_pytorch-0.2.1.dist-info/RECORD,,