titans-pytorch 0.3.10__py3-none-any.whl → 0.3.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

titans_pytorch/mac_transformer.py

@@ -62,6 +62,7 @@ from rotary_embedding_torch import RotaryEmbedding
 # hyper connections / attend from x-transformers, which handles different queries and key lengths better
 
 from x_transformers.attend import Attend
+
 from hyper_connections import get_init_and_expand_reduce_stream_functions
 
 # proposed neural memory
@@ -515,7 +516,7 @@ class MemoryAsContextTransformer(Module):
 
         # hyper conection
 
-        init_hyper_conn, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(num_residual_streams, disable = num_residual_streams == 1)
+        init_hyper_conn, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(num_residual_streams, dim = dim, add_stream_embed = True, disable = num_residual_streams == 1)
 
         self.layers = ModuleList([])
 
@@ -553,7 +554,7 @@ class MemoryAsContextTransformer(Module):
             mem_hyper_conn = None
 
             if layer in neural_memory_layers:
-                mem_hyper_conn = init_hyper_conn(dim = dim, add_branch_out_to_residual = not neural_mem_gate_attn_output)
+                mem_hyper_conn = init_hyper_conn(add_branch_out_to_residual = not neural_mem_gate_attn_output)
 
                 mem = NeuralMemory(
                     dim = dim,
@@ -571,8 +572,8 @@ class MemoryAsContextTransformer(Module):
             self.layers.append(ModuleList([
                 mem_hyper_conn,
                 mem,
-                init_hyper_conn(dim = dim, branch = attn),
-                init_hyper_conn(dim = dim, branch = ff)
+                init_hyper_conn(branch = attn),
+                init_hyper_conn(branch = ff)
             ]))
 
         self.norm = nn.RMSNorm(dim)
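
The three mac_transformer.py hunks above track an API change in hyper-connections 0.1.10: the model dimension (plus a new stream embedding) is now given once to the factory, so the per-layer init_hyper_conn calls drop their dim = dim argument. A minimal sketch of the new call pattern, assuming the 0.1.10 signatures shown in this diff (the placeholder branch module is illustrative only):

    import torch.nn as nn
    from hyper_connections import get_init_and_expand_reduce_stream_functions

    dim = 512
    num_residual_streams = 4

    # dim and the stream embedding are configured once at the factory
    init_hyper_conn, expand_streams, reduce_streams = get_init_and_expand_reduce_stream_functions(
        num_residual_streams,
        dim = dim,
        add_stream_embed = True,
        disable = num_residual_streams == 1
    )

    # per-layer wrappers no longer repeat dim = dim
    attn = nn.Identity()  # stand-in for an attention block
    attn_with_residual = init_hyper_conn(branch = attn)
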

titans_pytorch/neural_memory.py

@@ -7,7 +7,7 @@ from itertools import zip_longest
 from collections import namedtuple
 
 import torch
-from torch import nn, cat, tensor, Tensor
+from torch import nn, stack, cat, tensor, Tensor
 import torch.nn.functional as F
 from torch.nn import Linear, Module, Parameter, ParameterList, ParameterDict
 from torch.func import functional_call, vmap, grad
@@ -39,7 +39,7 @@ o - momentum orders
 
 LinearNoBias = partial(Linear, bias = False)
 
-NeuralMemCache = namedtuple('NeuralMemCache', [
+NeuralMemState = namedtuple('NeuralMemState', [
     'seq_index',
     'weights',
     'cache_store_segment',
@@ -230,6 +230,7 @@ class NeuralMemory(Module):
         momentum = True,
         momentum_order = 1,
         learned_momentum_combine = False,
+        learned_combine_include_zeroth = False,
         pre_rmsnorm = True,
         post_rmsnorm = False,
         qk_rmsnorm = False,
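
The neural_memory.py hunks add a learned_combine_include_zeroth option. A hedged usage sketch of the new flag; keyword arguments not visible in this diff (chunk_size, and the example shapes) follow the project README rather than 0.3.12 itself:

    import torch
    from titans_pytorch import NeuralMemory

    mem = NeuralMemory(
        dim = 384,
        chunk_size = 64,
        momentum = True,
        momentum_order = 2,                     # keep first and second order momentum
        learned_momentum_combine = True,        # softmax-weighted combination of the orders
        learned_combine_include_zeroth = True   # new in 0.3.12: also fold in the raw surprise
    )

    seq = torch.randn(2, 1024, 384)
    retrieved, state = mem(seq)
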
@@ -399,12 +400,17 @@ class NeuralMemory(Module):
             assert momentum
             assert momentum_order > 1, 'only second order momentum allowed for now, but may allow learned combination of zeroth'
 
+            if learned_combine_include_zeroth:
+                momentum_order += 1
+
             self.to_learned_momentum_combine = Sequential(
                 nn.Linear(dim, heads * momentum_order),
                 nn.Softmax(dim = -1),
                 Rearrange('b n (h o) -> o (b h) n', h = heads)
             )
 
+        self.learned_combine_include_zeroth = learned_combine_include_zeroth
+
         # per layer learning rate modulation
 
         self.to_layer_modulation = Sequential(
@@ -629,7 +635,7 @@ class NeuralMemory(Module):
 
         if num_chunks == 0:
             updates = rearrange_dict_values(weights, 'bh ... -> bh 1 ...')
-            next_store_state = NeuralMemCache(next_seq_len_index, weights, remainder, past_state, updates)
+            next_store_state = NeuralMemState(next_seq_len_index, weights, remainder, past_state, updates)
 
             output = (updates, next_store_state)
 
@@ -662,10 +668,14 @@ class NeuralMemory(Module):
 
                 momentums.append(momentum)
 
-            momentums = torch.stack(momentums)
+            momentums = stack(momentums)
 
             next_last_momentum[param_name] = momentums[:, :, -1] # momentums shape is Float['o bh n 1']
 
+            if learned_combine and self.learned_combine_include_zeroth:
+                # add the original surprise if learned combination of momentums
+                momentums = cat((rearrange(surprise, '... -> 1 ...'), momentums), dim = 0)
+
             if not learned_combine:
                 update = momentums[-1]
             else:
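
What the new branch does, as a toy shape check: the zeroth-order surprise is prepended as one more "order" alongside the momentums, and the to_learned_momentum_combine projection (widened by one in the __init__ hunk above) produces a matching extra weight. Tensor names and the final weighted sum below are illustrative, not the package's internals:

    import torch
    from einops import rearrange

    o, bh, n = 2, 8, 16                  # momentum orders, batch * heads, chunks
    surprise  = torch.randn(bh, n)       # zeroth order: the raw surprise signal
    momentums = torch.randn(o, bh, n)    # first and second order momentum

    # with learned_combine_include_zeroth the surprise is prepended as an extra order
    momentums = torch.cat((rearrange(surprise, '... -> 1 ...'), momentums), dim = 0)

    # stand-in for the learned softmax weights, normalised over orders for simplicity
    weights = torch.softmax(torch.randn(o + 1, bh, n), dim = 0)
    update  = (weights * momentums).sum(dim = 0)   # convex combination across o + 1 orders
    assert update.shape == (bh, n)
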
@@ -682,13 +692,11 @@ class NeuralMemory(Module):
 
         next_state = (next_last_update, next_last_momentum)
 
-        next_store_state = NeuralMemCache(next_seq_len_index, weights, remainder, next_state, updates)
-
-        # returns
+        next_store_state = NeuralMemState(next_seq_len_index, weights, remainder, next_state, updates)
 
-        output = (updates, next_store_state)
+        # return updates to neural memory at all chunked timesteps + neural mem cache / state to be fed back
 
-        return output
+        return updates, next_store_state
 
     def retrieve_memories(
         self,
@@ -785,7 +793,7 @@ class NeuralMemory(Module):
         self,
         seq,
         store_seq = None,
-        state: NeuralMemCache | None = None,
+        state: NeuralMemState | None = None,
         prev_weights = None
     ):
         if seq.ndim == 2:
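
The renamed NeuralMemState is what forward hands back for the next call. A hedged sketch of threading it through successive calls, following the signature above (chunk_size and the split size are illustrative):

    import torch
    from titans_pytorch import NeuralMemory

    mem = NeuralMemory(dim = 384, chunk_size = 64)

    state = None
    for chunk in torch.randn(2, 1024, 384).split(256, dim = 1):
        retrieved, state = mem(chunk, state = state)   # NeuralMemState fed back in
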

titans_pytorch-0.3.10.dist-info/METADATA → titans_pytorch-0.3.12.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.3.10
+Version: 0.3.12
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -38,7 +38,7 @@ Requires-Dist: accelerated-scan>=0.2.0
 Requires-Dist: axial-positional-embedding>=0.3.10
 Requires-Dist: einops>=0.8.0
 Requires-Dist: einx>=0.3.0
-Requires-Dist: hyper-connections>=0.1.9
+Requires-Dist: hyper-connections>=0.1.10
 Requires-Dist: ninja
 Requires-Dist: rotary-embedding-torch
 Requires-Dist: tensordict

titans_pytorch-0.3.12.dist-info/RECORD

@@ -0,0 +1,9 @@
+titans_pytorch/__init__.py,sha256=AyEUlcXWpnqrvyeihRAXWIfQlzLA4NhBjOqQU4edL-4,297
+titans_pytorch/associative_scan.py,sha256=esaLbukFlgvy2aqopsqBy6KEcZ64B3rsNhG8moKdPSc,5159
+titans_pytorch/mac_transformer.py,sha256=EyqA53HBqvAr4UNZUs37LR6IltyEfA7FKEV54YzVYlg,24945
+titans_pytorch/memory_models.py,sha256=wnH9i9kUSoVZhEWUlj8LpBSbB400L9kLt1zP8CO45QQ,5835
+titans_pytorch/neural_memory.py,sha256=VmUAS1xOM0ZfearWIzQrX_P7HI69viuwrg9M7BQByeE,29349
+titans_pytorch-0.3.12.dist-info/METADATA,sha256=02OsMYNITFLjnKJgis8eUHxwcdH2aVbA_D-QK24TYbg,6817
+titans_pytorch-0.3.12.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.3.12.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.3.12.dist-info/RECORD,,

titans_pytorch-0.3.10.dist-info/RECORD

@@ -1,9 +0,0 @@
-titans_pytorch/__init__.py,sha256=AyEUlcXWpnqrvyeihRAXWIfQlzLA4NhBjOqQU4edL-4,297
-titans_pytorch/associative_scan.py,sha256=esaLbukFlgvy2aqopsqBy6KEcZ64B3rsNhG8moKdPSc,5159
-titans_pytorch/mac_transformer.py,sha256=5rO4GQxSyFWWEc3pc3xNyG0sK5EXE7MmxKI-_kEMl2M,24941
-titans_pytorch/memory_models.py,sha256=wnH9i9kUSoVZhEWUlj8LpBSbB400L9kLt1zP8CO45QQ,5835
-titans_pytorch/neural_memory.py,sha256=BeOnq41gjZeq-XJFjkHE44F9dLzsg9mm36EBYZ4wHMA,28814
-titans_pytorch-0.3.10.dist-info/METADATA,sha256=sA_Dx_x5RMcpz5-vUPDHuz__tHYfKzs4W_BgY4CHPdk,6816
-titans_pytorch-0.3.10.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.3.10.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.3.10.dist-info/RECORD,,