titans-pytorch 0.3.9__tar.gz → 0.3.11__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.3.9
+Version: 0.3.11
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -38,7 +38,7 @@ Requires-Dist: accelerated-scan>=0.2.0
 Requires-Dist: axial-positional-embedding>=0.3.10
 Requires-Dist: einops>=0.8.0
 Requires-Dist: einx>=0.3.0
-Requires-Dist: hyper-connections>=0.1.9
+Requires-Dist: hyper-connections>=0.1.10
 Requires-Dist: ninja
 Requires-Dist: rotary-embedding-torch
 Requires-Dist: tensordict
@@ -1,6 +1,6 @@
 [project]
 name = "titans-pytorch"
-version = "0.3.9"
+version = "0.3.11"
 description = "Titans"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -29,7 +29,7 @@ dependencies = [
     "axial_positional_embedding>=0.3.10",
     "einops>=0.8.0",
     "einx>=0.3.0",
-    "hyper-connections>=0.1.9",
+    "hyper-connections>=0.1.10",
     "Ninja",
     "rotary-embedding-torch",
     "tensordict",
@@ -60,6 +60,26 @@ def test_titans(

     assert seq.shape == retrieved.shape

+@pytest.mark.parametrize('learned_momentum_combine', (False, True))
+def test_titans_second_order_momentum(
+    learned_momentum_combine
+):
+
+    mem = NeuralMemory(
+        dim = 384,
+        dim_head = 64,
+        heads = 2,
+        chunk_size = 1,
+        batch_size = 2,
+        momentum_order = 2,
+        learned_momentum_combine = learned_momentum_combine
+    )
+
+    seq = torch.randn(2, 5, 384)
+
+    parallel_retrieved, state = mem(seq)
+    assert seq.shape == parallel_retrieved.shape
+
 def test_titans_attn_memory():
     from titans_pytorch.memory_models import MemoryAttention
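
For orientation, the options exercised by this new test can be driven directly from a standalone NeuralMemory; a minimal usage sketch (hyperparameters are illustrative and mirror the test above):

import torch
from titans_pytorch import NeuralMemory

mem = NeuralMemory(
    dim = 384,
    dim_head = 64,
    heads = 2,
    chunk_size = 1,
    batch_size = 2,
    momentum_order = 2,              # keep first and second order surprise momentum
    learned_momentum_combine = True  # mix the momentum orders with learned, softmaxed weights
)

seq = torch.randn(2, 5, 384)

retrieved, state = mem(seq)          # same interface as before: retrieved sequence plus neural memory state
assert retrieved.shape == seq.shape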
 
@@ -318,12 +338,16 @@ def test_flex(

     assert torch.allclose(out_flex, out_non_flex, atol = 1e-5)

-@torch_default_dtype(torch.float64)
-def test_assoc_scan():
+@pytest.mark.parametrize('use_accelerated', (True, False))
+def test_assoc_scan(
+    use_accelerated
+):
     from titans_pytorch.neural_memory import AssocScan
-    torch.set_default_dtype(torch.float64)

-    scan = AssocScan()
+    if use_accelerated and not torch.cuda.is_available():
+        pytest.skip()
+
+    scan = AssocScan(use_accelerated = use_accelerated)

     seq_len = 128
     mid_point = seq_len // 2
@@ -331,6 +355,10 @@ def test_assoc_scan():
     gates = torch.randn(2, seq_len, 16).sigmoid()
     inputs = torch.randn(2, seq_len, 16)

+    if use_accelerated:
+        gates = gates.cuda()
+        inputs = inputs.cuda()
+
     output = scan(gates, inputs)

     gates1, gates2 = gates[:, :mid_point], gates[:, mid_point:]
@@ -341,4 +369,4 @@ def test_assoc_scan():
     second_half = scan(gates2, inputs2, prev = first_half[:, -1])
     assert second_half.shape == inputs2.shape

-    assert torch.allclose(output[:, -1], second_half[:, -1], atol = 1e-6)
+    assert torch.allclose(output[:, -1], second_half[:, -1], atol = 1e-5)
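
For context on what this test checks: the continuity assertion assumes AssocScan realizes the gated linear recurrence out_t = gates_t * out_{t-1} + inputs_t, so scanning the second half from the midpoint state must land on the same final value. A minimal loop-based reference sketch under that assumption (names here are illustrative, not part of the package):

import torch

def reference_scan(gates, inputs, prev = None):
    # naive O(n) equivalent of the parallel associative scan,
    # assuming out_t = gates_t * out_{t-1} + inputs_t
    out = torch.zeros_like(inputs[:, 0]) if prev is None else prev
    outs = []
    for t in range(inputs.shape[1]):
        out = gates[:, t] * out + inputs[:, t]
        outs.append(out)
    return torch.stack(outs, dim = 1)

gates = torch.randn(2, 128, 16).sigmoid()
inputs = torch.randn(2, 128, 16)

full = reference_scan(gates, inputs)
second_half = reference_scan(gates[:, 64:], inputs[:, 64:], prev = full[:, 63])

assert torch.allclose(full[:, -1], second_half[:, -1], atol = 1e-5)
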
@@ -62,6 +62,7 @@ from rotary_embedding_torch import RotaryEmbedding
 # hyper connections / attend from x-transformers, which handles different queries and key lengths better

 from x_transformers.attend import Attend
+
 from hyper_connections import get_init_and_expand_reduce_stream_functions

 # proposed neural memory
@@ -515,7 +516,7 @@ class MemoryAsContextTransformer(Module):

         # hyper conection

-        init_hyper_conn, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(num_residual_streams, disable = num_residual_streams == 1)
+        init_hyper_conn, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(num_residual_streams, dim = dim, add_stream_embed = True, disable = num_residual_streams == 1)

         self.layers = ModuleList([])

@@ -553,7 +554,7 @@ class MemoryAsContextTransformer(Module):
             mem_hyper_conn = None

             if layer in neural_memory_layers:
-                mem_hyper_conn = init_hyper_conn(dim = dim, add_branch_out_to_residual = not neural_mem_gate_attn_output)
+                mem_hyper_conn = init_hyper_conn(add_branch_out_to_residual = not neural_mem_gate_attn_output)

                 mem = NeuralMemory(
                     dim = dim,
@@ -571,8 +572,8 @@ class MemoryAsContextTransformer(Module):
             self.layers.append(ModuleList([
                 mem_hyper_conn,
                 mem,
-                init_hyper_conn(dim = dim, branch = attn),
-                init_hyper_conn(dim = dim, branch = ff)
+                init_hyper_conn(branch = attn),
+                init_hyper_conn(branch = ff)
             ]))

         self.norm = nn.RMSNorm(dim)
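
Read together with the hyper-connections>=0.1.10 requirement bump in the dependency hunks above, these hunks appear to track an upstream API change: the stream dimension (and the new add_stream_embed flag) is now supplied once to get_init_and_expand_reduce_stream_functions, so the per-layer init_hyper_conn(...) calls no longer need to repeat dim = dim.
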
@@ -3,6 +3,7 @@ from typing import Callable

 import math
 from functools import partial
+from itertools import zip_longest
 from collections import namedtuple

 import torch
@@ -21,21 +22,24 @@ from titans_pytorch.memory_models import(
 )

 import einx
-from einops import rearrange, repeat, reduce, pack, unpack
+from einops import einsum, rearrange, repeat, reduce, pack, unpack
 from einops.layers.torch import Rearrange, Reduce

 """
 ein notation:
 b - batch
+h - heads
+bh - batch and heads
 n - sequence
 d - feature dimension
 c - intra-chunk
 w - num memory network weight parameters
+o - momentum orders
 """

 LinearNoBias = partial(Linear, bias = False)

-NeuralMemCache = namedtuple('NeuralMemCache', [
+NeuralMemState = namedtuple('NeuralMemState', [
     'seq_index',
     'weights',
     'cache_store_segment',
@@ -224,6 +228,8 @@ class NeuralMemory(Module):
         per_head_learned_parameters = True,
         attn_pool_chunks = False,
         momentum = True,
+        momentum_order = 1,
+        learned_momentum_combine = False,
         pre_rmsnorm = True,
         post_rmsnorm = False,
         qk_rmsnorm = False,
@@ -367,12 +373,7 @@ class NeuralMemory(Module):
         else:
             self.reduce_to_chunk_rep = AttentionPool(dim, chunk_size = chunk_size)

-        # learned adaptive learning rate and momentum
-
-        self.to_momentum = Sequential(
-            nn.Linear(dim, heads),
-            Rearrange('b n h -> (b h) n 1')
-        ) if momentum else None
+        # learned adaptive learning rate

         self.to_adaptive_step = Sequential(
             nn.Linear(dim, heads),
@@ -384,6 +385,26 @@ class NeuralMemory(Module):

         self.adaptive_step_transform = adaptive_step_transform

+        # momentum related
+
+        self.to_momentum = Sequential(
+            nn.Linear(dim, heads * momentum_order),
+            Rearrange('b n (h o) -> o (b h) n 1', o = momentum_order)
+        ) if momentum else None
+
+        self.momentum_order = momentum_order
+        self.to_learned_momentum_combine = None
+
+        if learned_momentum_combine:
+            assert momentum
+            assert momentum_order > 1, 'only second order momentum allowed for now, but may allow learned combination of zeroth'
+
+            self.to_learned_momentum_combine = Sequential(
+                nn.Linear(dim, heads * momentum_order),
+                nn.Softmax(dim = -1),
+                Rearrange('b n (h o) -> o (b h) n', h = heads)
+            )
+
         # per layer learning rate modulation

         self.to_layer_modulation = Sequential(
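
A quick shape check of the two new projections may help when reading the store path below; this is a standalone sketch with nn.Sequential standing in for the Sequential helper used here, and dim, heads, momentum_order mirroring the new test:

import torch
from torch import nn
from einops.layers.torch import Rearrange

dim, heads, momentum_order = 384, 2, 2
b, n = 2, 5   # batch and chunked sequence length, illustrative

to_momentum = nn.Sequential(
    nn.Linear(dim, heads * momentum_order),
    Rearrange('b n (h o) -> o (b h) n 1', o = momentum_order)
)

to_learned_momentum_combine = nn.Sequential(
    nn.Linear(dim, heads * momentum_order),
    nn.Softmax(dim = -1),
    Rearrange('b n (h o) -> o (b h) n', h = heads)
)

chunked_seq = torch.randn(b, n, dim)

# one decay gate per momentum order, head and timestep (sigmoid applied at the call site)
assert to_momentum(chunked_seq).shape == (momentum_order, b * heads, n, 1)

# per-timestep combine weights, one per momentum order (softmax taken over the raw projection)
assert to_learned_momentum_combine(chunked_seq).shape == (momentum_order, b * heads, n)
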
@@ -463,9 +484,9 @@ class NeuralMemory(Module):
         zeros = self.memory_model_parameter_dict.clone().zero_()

         if self.per_head_learned_parameters:
-            zeros = repeat_dict_values(zeros, 'h ... -> (b h) ...', b = batch)
+            zeros = repeat_dict_values(zeros, 'h ... -> o (b h) ...', b = batch, o = self.momentum_order)
         else:
-            zeros = repeat_dict_values(zeros, '... -> bh ...', bh = batch * self.heads)
+            zeros = repeat_dict_values(zeros, '... -> o bh ...', bh = batch * self.heads, o = self.momentum_order)

         return zeros

@@ -518,6 +539,11 @@ class NeuralMemory(Module):
         if has_momentum:
             adaptive_momentum = self.to_momentum(chunked_seq).sigmoid()

+        learned_combine = exists(self.to_learned_momentum_combine)
+
+        if learned_combine:
+            combine_momentums = self.to_learned_momentum_combine(chunked_seq)
+
         if need_layer_lr_mod:
             layer_lr_mod = self.to_layer_modulation(chunked_seq) * self.max_mem_layer_modulation

@@ -585,7 +611,7 @@ class NeuralMemory(Module):

         # negative gradients, adaptive lr already applied as loss weight

-        surprises = grads.apply(lambda t: -t)
+        surprises = grads.mul(-1)

         # past states

@@ -603,7 +629,7 @@ class NeuralMemory(Module):

         if num_chunks == 0:
             updates = rearrange_dict_values(weights, 'bh ... -> bh 1 ...')
-            next_store_state = NeuralMemCache(next_seq_len_index, weights, remainder, past_state, updates)
+            next_store_state = NeuralMemState(next_seq_len_index, weights, remainder, past_state, updates)

             output = (updates, next_store_state)

@@ -611,7 +637,6 @@ class NeuralMemory(Module):

         # momentum + weight decay - momentum is the new contribution, as most linear RNNs have learned forgetting gates

-        next_momentum = TensorDict() if has_momentum else None
         updates = TensorDict()

         next_last_update = TensorDict()
@@ -624,32 +649,44 @@ class NeuralMemory(Module):
             # derive momentum with associative scan - eq (10)

             if has_momentum:
+                momentum = surprise
+
+                momentums = [] # stores all momentum orders starting with first, to generalize to Nth order momentum
+
                 last_momentum = past_last_momentum[param_name]
-                update = self.assoc_scan(adaptive_momentum, surprise, prev = last_momentum) # momentum is S / surprise in the paper
-                momentum = update
-                next_last_momentum[param_name] = momentum[:, -1]
+
+                # go from first order momentum all the way to the Nth
+
+                for one_adaptive_momentum, one_last_momentum in zip_longest(adaptive_momentum, last_momentum):
+                    momentum = self.assoc_scan(one_adaptive_momentum, momentum, prev = one_last_momentum) # momentum is S / surprise in the paper
+
+                    momentums.append(momentum)
+
+                momentums = torch.stack(momentums)
+
+                next_last_momentum[param_name] = momentums[:, :, -1] # momentums shape is Float['o bh n 1']
+
+                if not learned_combine:
+                    update = momentums[-1]
+                else:
+                    update = einsum(combine_momentums, momentums, 'o b n, o b n ... -> b n ...')

             # use associative scan again for learned forgetting (weight decay) - eq (13)

             update = self.assoc_scan(1. - decay_factor, update, prev = last_update, remove_prev = False)
-            next_last_update[param_name] = update[:, -1]

             updates[param_name] = update
-
-            if has_momentum:
-                next_momentum[param_name] = momentum
+            next_last_update[param_name] = update[:, -1]

         # determine next state for the storing of memories

         next_state = (next_last_update, next_last_momentum)

-        next_store_state = NeuralMemCache(next_seq_len_index, weights, remainder, next_state, updates)
-
-        # returns
+        next_store_state = NeuralMemState(next_seq_len_index, weights, remainder, next_state, updates)

-        output = (updates, next_store_state)
+        # return updates to neural memory at all chunked timesteps + neural mem cache / state to be fed back

-        return output
+        return updates, next_store_state

     def retrieve_memories(
         self,
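
To make the new control flow above easier to follow in isolation, here is a minimal, self-contained sketch of the chained momentum recurrence (order k scans the output of order k-1, starting from the raw surprise); the scan stand-in and tensor shapes are illustrative only:

import torch

def scan(gates, inputs, prev):
    # stand-in for AssocScan: out_t = gates_t * out_{t-1} + inputs_t
    out, outs = prev, []
    for t in range(inputs.shape[1]):
        out = gates[:, t] * out + inputs[:, t]
        outs.append(out)
    return torch.stack(outs, dim = 1)

bh, n, order = 4, 8, 2                            # (batch * heads), chunk timesteps, momentum orders
surprise = torch.randn(bh, n, 1)                  # negative gradient of the memory loss
adaptive_momentum = torch.rand(order, bh, n, 1)   # one data-dependent gate per order
last_momentum = torch.zeros(order, bh, 1)         # carried momentum state, one per order

momentum, momentums = surprise, []

for one_gate, one_prev in zip(adaptive_momentum, last_momentum):
    momentum = scan(one_gate, momentum, prev = one_prev)   # first order scans the surprise, second scans the first, ...
    momentums.append(momentum)

momentums = torch.stack(momentums)                # (order, bh, n, 1)

update = momentums[-1]                            # highest order by default; learned_momentum_combine instead mixes
                                                  # all orders with the softmaxed combine weights via einsum
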
@@ -746,7 +783,7 @@ class NeuralMemory(Module):
         self,
         seq,
         store_seq = None,
-        state: NeuralMemCache | None = None,
+        state: NeuralMemState | None = None,
         prev_weights = None
     ):
         if seq.ndim == 2:
@@ -37,6 +37,7 @@ NUM_LONGTERM_MEM = 4
 NEURAL_MEM_LAYERS = (2, 4, 6) # layers 2, 4, 6 have neural memory, can add more
 NEURAL_MEM_GATE_ATTN_OUTPUT = False
 NEURAL_MEM_MOMENTUM = True
+NEURAL_MEM_MOMENTUM_ORDER = 1
 NEURAL_MEM_QK_NORM = True
 NEURAL_MEM_MAX_LR = 1e-1
 USE_MEM_ATTENTION_MODEL = False
@@ -115,6 +116,7 @@ model = MemoryAsContextTransformer(
         attn_pool_chunks = STORE_ATTN_POOL_CHUNKS,
         qk_rmsnorm = NEURAL_MEM_QK_NORM,
         momentum = NEURAL_MEM_MOMENTUM,
+        momentum_order = NEURAL_MEM_MOMENTUM_ORDER,
         default_step_transform_max_lr = NEURAL_MEM_MAX_LR,
         use_accelerated_scan = USE_ACCELERATED_SCAN,
         per_parameter_lr_modulation = MEMORY_MODEL_PER_LAYER_LEARNED_LR
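
Note that NEURAL_MEM_MOMENTUM_ORDER defaults to 1 here; with learned_momentum_combine left off, the store path above should reduce to the single first-order momentum scan of the previous version, so the training script's default behavior is expected to be unchanged.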