titans-pytorch 0.3.9__py3-none-any.whl → 0.3.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
titans_pytorch/mac_transformer.py

@@ -62,6 +62,7 @@ from rotary_embedding_torch import RotaryEmbedding
  # hyper connections / attend from x-transformers, which handles different queries and key lengths better

  from x_transformers.attend import Attend
+
  from hyper_connections import get_init_and_expand_reduce_stream_functions

  # proposed neural memory
@@ -515,7 +516,7 @@ class MemoryAsContextTransformer(Module):

  # hyper conection

- init_hyper_conn, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(num_residual_streams, disable = num_residual_streams == 1)
+ init_hyper_conn, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(num_residual_streams, dim = dim, add_stream_embed = True, disable = num_residual_streams == 1)

  self.layers = ModuleList([])

@@ -553,7 +554,7 @@ class MemoryAsContextTransformer(Module):
  mem_hyper_conn = None

  if layer in neural_memory_layers:
- mem_hyper_conn = init_hyper_conn(dim = dim, add_branch_out_to_residual = not neural_mem_gate_attn_output)
+ mem_hyper_conn = init_hyper_conn(add_branch_out_to_residual = not neural_mem_gate_attn_output)

  mem = NeuralMemory(
  dim = dim,
@@ -571,8 +572,8 @@ class MemoryAsContextTransformer(Module):
  self.layers.append(ModuleList([
  mem_hyper_conn,
  mem,
- init_hyper_conn(dim = dim, branch = attn),
- init_hyper_conn(dim = dim, branch = ff)
+ init_hyper_conn(branch = attn),
+ init_hyper_conn(branch = ff)
  ]))

  self.norm = nn.RMSNorm(dim)
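Taken together, the mac_transformer.py hunks move `dim` (plus the new `add_stream_embed` flag) from each per-layer `init_hyper_conn(...)` call into the single `get_init_and_expand_reduce_stream_functions(...)` factory call, so the per-branch wrappers no longer repeat `dim = dim`. A minimal sketch of the new call pattern, assuming hyper-connections >= 0.1.10 and using a plain `nn.Linear` as a hypothetical stand-in branch:

    import torch
    from torch import nn
    from hyper_connections import get_init_and_expand_reduce_stream_functions

    dim, num_residual_streams = 64, 4

    # dim and add_stream_embed are supplied once, at the factory (as in the hunk above)
    init_hyper_conn, expand_streams, reduce_streams = get_init_and_expand_reduce_stream_functions(
        num_residual_streams,
        dim = dim,
        add_stream_embed = True,
        disable = num_residual_streams == 1
    )

    # per-branch wrappers no longer pass dim
    block = init_hyper_conn(branch = nn.Linear(dim, dim))

    x = torch.randn(2, 128, dim)
    x = expand_streams(x)   # duplicate the residual into multiple streams
    x = block(x)            # branch output mixed back into the residual streams
    x = reduce_streams(x)   # collapse the streams back to (2, 128, dim)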
titans_pytorch/neural_memory.py

@@ -3,6 +3,7 @@ from typing import Callable

  import math
  from functools import partial
+ from itertools import zip_longest
  from collections import namedtuple

  import torch
@@ -21,21 +22,24 @@ from titans_pytorch.memory_models import(
  )

  import einx
- from einops import rearrange, repeat, reduce, pack, unpack
+ from einops import einsum, rearrange, repeat, reduce, pack, unpack
  from einops.layers.torch import Rearrange, Reduce

  """
  ein notation:
  b - batch
+ h - heads
+ bh - batch and heads
  n - sequence
  d - feature dimension
  c - intra-chunk
  w - num memory network weight parameters
+ o - momentum orders
  """

  LinearNoBias = partial(Linear, bias = False)

- NeuralMemCache = namedtuple('NeuralMemCache', [
+ NeuralMemState = namedtuple('NeuralMemState', [
  'seq_index',
  'weights',
  'cache_store_segment',
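The `NeuralMemCache` namedtuple is renamed to `NeuralMemState`; the fields visible in this hunk (`seq_index`, `weights`, `cache_store_segment`, ...) are untouched. Downstream code written against 0.3.9 that referenced the old name can bridge the rename with a one-line alias; a hypothetical compatibility shim, not part of the package:

    # hypothetical shim for code still importing the pre-0.3.11 name
    from titans_pytorch.neural_memory import NeuralMemState

    NeuralMemCache = NeuralMemState  # same namedtuple under the old name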
@@ -224,6 +228,8 @@ class NeuralMemory(Module):
  per_head_learned_parameters = True,
  attn_pool_chunks = False,
  momentum = True,
+ momentum_order = 1,
+ learned_momentum_combine = False,
  pre_rmsnorm = True,
  post_rmsnorm = False,
  qk_rmsnorm = False,
@@ -367,12 +373,7 @@ class NeuralMemory(Module):
  else:
  self.reduce_to_chunk_rep = AttentionPool(dim, chunk_size = chunk_size)

- # learned adaptive learning rate and momentum
-
- self.to_momentum = Sequential(
- nn.Linear(dim, heads),
- Rearrange('b n h -> (b h) n 1')
- ) if momentum else None
+ # learned adaptive learning rate

  self.to_adaptive_step = Sequential(
  nn.Linear(dim, heads),
@@ -384,6 +385,26 @@ class NeuralMemory(Module):

  self.adaptive_step_transform = adaptive_step_transform

+ # momentum related
+
+ self.to_momentum = Sequential(
+ nn.Linear(dim, heads * momentum_order),
+ Rearrange('b n (h o) -> o (b h) n 1', o = momentum_order)
+ ) if momentum else None
+
+ self.momentum_order = momentum_order
+ self.to_learned_momentum_combine = None
+
+ if learned_momentum_combine:
+ assert momentum
+ assert momentum_order > 1, 'only second order momentum allowed for now, but may allow learned combination of zeroth'
+
+ self.to_learned_momentum_combine = Sequential(
+ nn.Linear(dim, heads * momentum_order),
+ nn.Softmax(dim = -1),
+ Rearrange('b n (h o) -> o (b h) n', h = heads)
+ )
+
  # per layer learning rate modulation

  self.to_layer_modulation = Sequential(
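With `momentum_order`, the momentum projection now emits one sigmoid gate per order, and the optional `to_learned_momentum_combine` head produces softmax weights over those orders. A standalone shape check of the two heads defined above, plus the matching leading order axis that the zero-initialized momentum state gains in the `repeat_dict_values` hunk further down (a sketch using plain `nn.Sequential` and hypothetical sizes, not the full module):

    import torch
    from torch import nn
    from einops import repeat
    from einops.layers.torch import Rearrange

    dim, heads, momentum_order = 64, 4, 2
    batch, chunks = 2, 8  # hypothetical chunked-sequence shape

    to_momentum = nn.Sequential(
        nn.Linear(dim, heads * momentum_order),
        Rearrange('b n (h o) -> o (b h) n 1', o = momentum_order),
    )

    to_learned_momentum_combine = nn.Sequential(
        nn.Linear(dim, heads * momentum_order),
        nn.Softmax(dim = -1),
        Rearrange('b n (h o) -> o (b h) n', h = heads),
    )

    chunked_seq = torch.randn(batch, chunks, dim)

    adaptive_momentum = to_momentum(chunked_seq).sigmoid()
    combine_momentums = to_learned_momentum_combine(chunked_seq)

    print(adaptive_momentum.shape)  # torch.Size([2, 8, 8, 1])  ->  (o, b * h, n, 1)
    print(combine_momentums.shape)  # torch.Size([2, 8, 8])     ->  (o, b * h, n)

    # the initial momentum state gets the same leading order axis (cf. the init hunk below),
    # here illustrated on a bare tensor instead of the package's TensorDict of parameters
    param = torch.zeros(heads, 16, 16)  # hypothetical per-head memory weight
    init_momentum = repeat(param, 'h ... -> o (b h) ...', b = batch, o = momentum_order)
    print(init_momentum.shape)      # torch.Size([2, 8, 16, 16]) ->  (o, b * h, ...)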
@@ -463,9 +484,9 @@ class NeuralMemory(Module):
  zeros = self.memory_model_parameter_dict.clone().zero_()

  if self.per_head_learned_parameters:
- zeros = repeat_dict_values(zeros, 'h ... -> (b h) ...', b = batch)
+ zeros = repeat_dict_values(zeros, 'h ... -> o (b h) ...', b = batch, o = self.momentum_order)
  else:
- zeros = repeat_dict_values(zeros, '... -> bh ...', bh = batch * self.heads)
+ zeros = repeat_dict_values(zeros, '... -> o bh ...', bh = batch * self.heads, o = self.momentum_order)

  return zeros

@@ -518,6 +539,11 @@ class NeuralMemory(Module):
  if has_momentum:
  adaptive_momentum = self.to_momentum(chunked_seq).sigmoid()

+ learned_combine = exists(self.to_learned_momentum_combine)
+
+ if learned_combine:
+ combine_momentums = self.to_learned_momentum_combine(chunked_seq)
+
  if need_layer_lr_mod:
  layer_lr_mod = self.to_layer_modulation(chunked_seq) * self.max_mem_layer_modulation

@@ -585,7 +611,7 @@

  # negative gradients, adaptive lr already applied as loss weight

- surprises = grads.apply(lambda t: -t)
+ surprises = grads.mul(-1)

  # past states
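Negating the gradients with `TensorDict.mul(-1)` gives the same result as the old `apply(lambda t: -t)`, just through tensordict's built-in pointwise arithmetic. A quick equivalence check, assuming the tensordict package is available:

    import torch
    from tensordict import TensorDict

    grads = TensorDict({'w1': torch.randn(3, 4), 'w2': torch.randn(3)}, batch_size = [])

    surprises_old = grads.apply(lambda t: -t)  # 0.3.9 behavior
    surprises_new = grads.mul(-1)              # 0.3.11 behavior

    assert all(torch.equal(surprises_old[k], surprises_new[k]) for k in grads.keys())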
@@ -603,7 +629,7 @@

  if num_chunks == 0:
  updates = rearrange_dict_values(weights, 'bh ... -> bh 1 ...')
- next_store_state = NeuralMemCache(next_seq_len_index, weights, remainder, past_state, updates)
+ next_store_state = NeuralMemState(next_seq_len_index, weights, remainder, past_state, updates)

  output = (updates, next_store_state)

@@ -611,7 +637,6 @@

  # momentum + weight decay - momentum is the new contribution, as most linear RNNs have learned forgetting gates

- next_momentum = TensorDict() if has_momentum else None
  updates = TensorDict()

  next_last_update = TensorDict()
@@ -624,32 +649,44 @@
  # derive momentum with associative scan - eq (10)

  if has_momentum:
+ momentum = surprise
+
+ momentums = [] # stores all momentum orders starting with first, to generalize to Nth order momentum
+
  last_momentum = past_last_momentum[param_name]
- update = self.assoc_scan(adaptive_momentum, surprise, prev = last_momentum) # momentum is S / surprise in the paper
- momentum = update
- next_last_momentum[param_name] = momentum[:, -1]
+
+ # go from first order momentum all the way to the Nth
+
+ for one_adaptive_momentum, one_last_momentum in zip_longest(adaptive_momentum, last_momentum):
+ momentum = self.assoc_scan(one_adaptive_momentum, momentum, prev = one_last_momentum) # momentum is S / surprise in the paper
+
+ momentums.append(momentum)
+
+ momentums = torch.stack(momentums)
+
+ next_last_momentum[param_name] = momentums[:, :, -1] # momentums shape is Float['o bh n 1']
+
+ if not learned_combine:
+ update = momentums[-1]
+ else:
+ update = einsum(combine_momentums, momentums, 'o b n, o b n ... -> b n ...')

  # use associative scan again for learned forgetting (weight decay) - eq (13)

  update = self.assoc_scan(1. - decay_factor, update, prev = last_update, remove_prev = False)
- next_last_update[param_name] = update[:, -1]

  updates[param_name] = update
-
- if has_momentum:
- next_momentum[param_name] = momentum
+ next_last_update[param_name] = update[:, -1]

  # determine next state for the storing of memories

  next_state = (next_last_update, next_last_momentum)

- next_store_state = NeuralMemCache(next_seq_len_index, weights, remainder, next_state, updates)
-
- # returns
+ next_store_state = NeuralMemState(next_seq_len_index, weights, remainder, next_state, updates)

- output = (updates, next_store_state)
+ # return updates to neural memory at all chunked timesteps + neural mem cache / state to be fed back

- return output
+ return updates, next_store_state

  def retrieve_memories(
  self,
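The rewritten block runs the associative scan once per momentum order: the first scan smooths the raw surprise, the next scan smooths the first-order momentum, and so on, with `zip_longest` supplying `prev = None` for any order that has no past state. The per-order results are then either taken at the highest order or mixed with the learned softmax weights. A small reference sketch of the same recurrence with an explicit loop in place of `assoc_scan`, using hypothetical shapes (for intuition only, not the package's scan kernel):

    import torch
    from einops import einsum

    def momentum_scan(gate, x, prev = None):
        # naive first-order linear recurrence: m_t = gate_t * m_{t-1} + x_t
        m = prev if prev is not None else torch.zeros_like(x[:, 0])
        out = []
        for t in range(x.shape[1]):
            m = gate[:, t] * m + x[:, t]
            out.append(m)
        return torch.stack(out, dim = 1)

    bh, n, momentum_order = 8, 16, 2
    surprise = torch.randn(bh, n, 4, 4)                          # hypothetical per-parameter surprise (negated grads)
    adaptive_momentum = torch.rand(momentum_order, bh, n, 1, 1)  # per-order gates in [0, 1]
    combine_momentums = torch.softmax(torch.randn(momentum_order, bh, n), dim = 0)

    momentum, momentums = surprise, []
    for one_adaptive_momentum in adaptive_momentum:
        momentum = momentum_scan(one_adaptive_momentum, momentum)  # order k feeds order k + 1
        momentums.append(momentum)

    momentums = torch.stack(momentums)  # (o, bh, n, ...)

    update_without_combine = momentums[-1]  # default: the highest-order momentum becomes the update
    update_with_combine = einsum(combine_momentums, momentums, 'o b n, o b n ... -> b n ...')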
@@ -746,7 +783,7 @@ class NeuralMemory(Module):
  self,
  seq,
  store_seq = None,
- state: NeuralMemCache | None = None,
+ state: NeuralMemState | None = None,
  prev_weights = None
  ):
  if seq.ndim == 2:
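At the public interface the only change is the annotation on `state`, but callers that carry memory state across calls now receive and pass back a `NeuralMemState`. A hedged usage sketch, assuming (as in prior versions) that `NeuralMemory.forward` returns the retrieved memories together with the next state:

    import torch
    from titans_pytorch import NeuralMemory

    mem = NeuralMemory(
        dim = 384,
        chunk_size = 64,
        momentum_order = 2,               # new in this release
        learned_momentum_combine = True   # new in this release
    )

    state = None
    for segment in torch.randn(4, 2, 64, 384):          # four consecutive segments of one long sequence
        retrieved, state = mem(segment, state = state)   # state is a NeuralMemState threaded between calls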
titans_pytorch-0.3.11.dist-info/METADATA

@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: titans-pytorch
- Version: 0.3.9
+ Version: 0.3.11
  Summary: Titans
  Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
  Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -38,7 +38,7 @@ Requires-Dist: accelerated-scan>=0.2.0
  Requires-Dist: axial-positional-embedding>=0.3.10
  Requires-Dist: einops>=0.8.0
  Requires-Dist: einx>=0.3.0
- Requires-Dist: hyper-connections>=0.1.9
+ Requires-Dist: hyper-connections>=0.1.10
  Requires-Dist: ninja
  Requires-Dist: rotary-embedding-torch
  Requires-Dist: tensordict
titans_pytorch-0.3.11.dist-info/RECORD (added)

@@ -0,0 +1,9 @@
+ titans_pytorch/__init__.py,sha256=AyEUlcXWpnqrvyeihRAXWIfQlzLA4NhBjOqQU4edL-4,297
+ titans_pytorch/associative_scan.py,sha256=esaLbukFlgvy2aqopsqBy6KEcZ64B3rsNhG8moKdPSc,5159
+ titans_pytorch/mac_transformer.py,sha256=EyqA53HBqvAr4UNZUs37LR6IltyEfA7FKEV54YzVYlg,24945
+ titans_pytorch/memory_models.py,sha256=wnH9i9kUSoVZhEWUlj8LpBSbB400L9kLt1zP8CO45QQ,5835
+ titans_pytorch/neural_memory.py,sha256=7YglrQaDpKS2hbpBBwx7PmqhJdjyvFEPZDt_QXmnUMM,28878
+ titans_pytorch-0.3.11.dist-info/METADATA,sha256=xAEvavDiCj__5Bl_5UXaG__BycdUB2DzHOud-nwsn1c,6817
+ titans_pytorch-0.3.11.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+ titans_pytorch-0.3.11.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+ titans_pytorch-0.3.11.dist-info/RECORD,,

titans_pytorch-0.3.9.dist-info/RECORD (removed)

@@ -1,9 +0,0 @@
- titans_pytorch/__init__.py,sha256=AyEUlcXWpnqrvyeihRAXWIfQlzLA4NhBjOqQU4edL-4,297
- titans_pytorch/associative_scan.py,sha256=esaLbukFlgvy2aqopsqBy6KEcZ64B3rsNhG8moKdPSc,5159
- titans_pytorch/mac_transformer.py,sha256=5rO4GQxSyFWWEc3pc3xNyG0sK5EXE7MmxKI-_kEMl2M,24941
- titans_pytorch/memory_models.py,sha256=wnH9i9kUSoVZhEWUlj8LpBSbB400L9kLt1zP8CO45QQ,5835
- titans_pytorch/neural_memory.py,sha256=2Ffq5fob6_vJVUs1jIai3BJhfsfypKffJEb0QwRRdMk,27325
- titans_pytorch-0.3.9.dist-info/METADATA,sha256=EUuuqHl8jPUPIG8m7xV5esN_2yDNaPdD-H8qFeDxWGo,6815
- titans_pytorch-0.3.9.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
- titans_pytorch-0.3.9.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
- titans_pytorch-0.3.9.dist-info/RECORD,,