titans-pytorch 0.3.9__py3-none-any.whl → 0.3.10__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
titans_pytorch/neural_memory.py

@@ -3,6 +3,7 @@ from typing import Callable
 
 import math
 from functools import partial
+from itertools import zip_longest
 from collections import namedtuple
 
 import torch
@@ -21,16 +22,19 @@ from titans_pytorch.memory_models import(
 )
 
 import einx
-from einops import rearrange, repeat, reduce, pack, unpack
+from einops import einsum, rearrange, repeat, reduce, pack, unpack
 from einops.layers.torch import Rearrange, Reduce
 
 """
 ein notation:
 b - batch
+h - heads
+bh - batch and heads
 n - sequence
 d - feature dimension
 c - intra-chunk
 w - num memory network weight parameters
+o - momentum orders
 """
 
 LinearNoBias = partial(Linear, bias = False)
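The ein notation gains `h`, `bh`, and `o` because momentum tensors now carry a leading momentum-order axis, with heads folded into batch. A minimal sketch of that layout (the sizes below are illustrative, not from the package):

    import torch
    from einops import rearrange

    b, h, n, o = 2, 4, 128, 2       # batch, heads, sequence, momentum orders (assumed sizes)
    t = torch.randn(b, n, h * o)    # e.g. output of a Linear projecting to heads * momentum_order

    # 'b n (h o) -> o (b h) n 1': one value per momentum order, heads folded into batch
    gates = rearrange(t, 'b n (h o) -> o (b h) n 1', o = o)
    assert gates.shape == (o, b * h, n, 1)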
@@ -224,6 +228,8 @@ class NeuralMemory(Module):
         per_head_learned_parameters = True,
         attn_pool_chunks = False,
         momentum = True,
+        momentum_order = 1,
+        learned_momentum_combine = False,
         pre_rmsnorm = True,
         post_rmsnorm = False,
         qk_rmsnorm = False,
@@ -367,12 +373,7 @@ class NeuralMemory(Module):
         else:
             self.reduce_to_chunk_rep = AttentionPool(dim, chunk_size = chunk_size)
 
-        # learned adaptive learning rate and momentum
-
-        self.to_momentum = Sequential(
-            nn.Linear(dim, heads),
-            Rearrange('b n h -> (b h) n 1')
-        ) if momentum else None
+        # learned adaptive learning rate
 
         self.to_adaptive_step = Sequential(
             nn.Linear(dim, heads),
@@ -384,6 +385,26 @@ class NeuralMemory(Module):
 
         self.adaptive_step_transform = adaptive_step_transform
 
+        # momentum related
+
+        self.to_momentum = Sequential(
+            nn.Linear(dim, heads * momentum_order),
+            Rearrange('b n (h o) -> o (b h) n 1', o = momentum_order)
+        ) if momentum else None
+
+        self.momentum_order = momentum_order
+        self.to_learned_momentum_combine = None
+
+        if learned_momentum_combine:
+            assert momentum
+            assert momentum_order > 1, 'only second order momentum allowed for now, but may allow learned combination of zeroth'
+
+            self.to_learned_momentum_combine = Sequential(
+                nn.Linear(dim, heads * momentum_order),
+                nn.Softmax(dim = -1),
+                Rearrange('b n (h o) -> o (b h) n', h = heads)
+            )
+
         # per layer learning rate modulation
 
         self.to_layer_modulation = Sequential(
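For orientation, a minimal sketch of what the two new projections produce, using plain `nn.Sequential` and illustrative sizes in place of the package's own `Sequential` helper:

    import torch
    from torch import nn
    from einops.layers.torch import Rearrange

    dim, heads, momentum_order = 64, 4, 2   # assumed sizes, not from the package
    b, n = 2, 16

    to_momentum = nn.Sequential(
        nn.Linear(dim, heads * momentum_order),
        Rearrange('b n (h o) -> o (b h) n 1', o = momentum_order)
    )

    to_learned_momentum_combine = nn.Sequential(
        nn.Linear(dim, heads * momentum_order),
        nn.Softmax(dim = -1),
        Rearrange('b n (h o) -> o (b h) n', h = heads)
    )

    x = torch.randn(b, n, dim)
    gates = to_momentum(x).sigmoid()          # (o, b*h, n, 1): one momentum gate per order
    weights = to_learned_momentum_combine(x)  # (o, b*h, n): weights for combining the orders

With `momentum_order = 1` the leading order axis has size one and learned combination is disallowed by the assert, so the default configuration reduces to the original first-order behavior.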
@@ -463,9 +484,9 @@ class NeuralMemory(Module):
         zeros = self.memory_model_parameter_dict.clone().zero_()
 
         if self.per_head_learned_parameters:
-            zeros = repeat_dict_values(zeros, 'h ... -> (b h) ...', b = batch)
+            zeros = repeat_dict_values(zeros, 'h ... -> o (b h) ...', b = batch, o = self.momentum_order)
         else:
-            zeros = repeat_dict_values(zeros, '... -> bh ...', bh = batch * self.heads)
+            zeros = repeat_dict_values(zeros, '... -> o bh ...', bh = batch * self.heads, o = self.momentum_order)
 
         return zeros
 
@@ -518,6 +539,11 @@ class NeuralMemory(Module):
         if has_momentum:
             adaptive_momentum = self.to_momentum(chunked_seq).sigmoid()
 
+            learned_combine = exists(self.to_learned_momentum_combine)
+
+            if learned_combine:
+                combine_momentums = self.to_learned_momentum_combine(chunked_seq)
+
         if need_layer_lr_mod:
             layer_lr_mod = self.to_layer_modulation(chunked_seq) * self.max_mem_layer_modulation
 
@@ -585,7 +611,7 @@ class NeuralMemory(Module):
 
         # negative gradients, adaptive lr already applied as loss weight
 
-        surprises = grads.apply(lambda t: -t)
+        surprises = grads.mul(-1)
 
         # past states
 
@@ -611,7 +637,6 @@ class NeuralMemory(Module):
 
         # momentum + weight decay - momentum is the new contribution, as most linear RNNs have learned forgetting gates
 
-        next_momentum = TensorDict() if has_momentum else None
         updates = TensorDict()
 
         next_last_update = TensorDict()
@@ -624,20 +649,34 @@ class NeuralMemory(Module):
             # derive momentum with associative scan - eq (10)
 
             if has_momentum:
+                momentum = surprise
+
+                momentums = [] # stores all momentum orders starting with first, to generalize to Nth order momentum
+
                 last_momentum = past_last_momentum[param_name]
-                update = self.assoc_scan(adaptive_momentum, surprise, prev = last_momentum) # momentum is S / surprise in the paper
-                momentum = update
-                next_last_momentum[param_name] = momentum[:, -1]
+
+                # go from first order momentum all the way to the Nth
+
+                for one_adaptive_momentum, one_last_momentum in zip_longest(adaptive_momentum, last_momentum):
+                    momentum = self.assoc_scan(one_adaptive_momentum, momentum, prev = one_last_momentum) # momentum is S / surprise in the paper
+
+                    momentums.append(momentum)
+
+                momentums = torch.stack(momentums)
+
+                next_last_momentum[param_name] = momentums[:, :, -1] # momentums shape is Float['o bh n 1']
+
+                if not learned_combine:
+                    update = momentums[-1]
+                else:
+                    update = einsum(combine_momentums, momentums, 'o b n, o b n ... -> b n ...')
 
             # use associative scan again for learned forgetting (weight decay) - eq (13)
 
             update = self.assoc_scan(1. - decay_factor, update, prev = last_update, remove_prev = False)
-            next_last_update[param_name] = update[:, -1]
 
             updates[param_name] = update
-
-            if has_momentum:
-                next_momentum[param_name] = momentum
+            next_last_update[param_name] = update[:, -1]
 
             # determine next state for the storing of memories
 
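The rewritten block generalizes the single momentum scan of 0.3.9 into N chained scans: the first order scans the surprise (the negative gradients), and each higher order scans the output of the order below it. A standalone sketch of that recurrence, with made-up shapes and a naive sequential scan standing in for the package's associative scan:

    import torch
    from itertools import zip_longest

    def scan(gates, inputs, prev = None):
        # naive stand-in for an associative scan: m_t = g_t * m_{t-1} + x_t
        m = prev if prev is not None else torch.zeros_like(inputs[:, 0])
        out = []
        for g, x in zip(gates.unbind(dim = 1), inputs.unbind(dim = 1)):
            m = g * m + x
            out.append(m)
        return torch.stack(out, dim = 1)

    bh, n, order = 8, 16, 2                           # (batch * heads), sequence, momentum orders (assumed)
    surprise = torch.randn(bh, n, 1)                  # negative gradients ("surprise")
    adaptive_momentum = torch.rand(order, bh, n, 1)   # per-order data-dependent gates
    last_momentum = torch.zeros(order, bh, 1)         # carried state, one slot per order

    momentum = surprise
    momentums = []

    # first order scans the surprise; each higher order scans the order below it
    for gate, prev in zip_longest(adaptive_momentum, last_momentum):
        momentum = scan(gate, momentum, prev = prev)
        momentums.append(momentum)

    momentums = torch.stack(momentums)                # (order, bh, n, 1)
    update = momentums[-1]                            # highest order, when not learning a combination

When `learned_momentum_combine` is enabled, the final update is instead a weighted sum over the stacked orders, which is what the `einsum(combine_momentums, momentums, 'o b n, o b n ... -> b n ...')` line above computes.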
titans_pytorch-0.3.9.dist-info/METADATA → titans_pytorch-0.3.10.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.3.9
+Version: 0.3.10
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
titans_pytorch-0.3.10.dist-info/RECORD (added)

@@ -0,0 +1,9 @@
+titans_pytorch/__init__.py,sha256=AyEUlcXWpnqrvyeihRAXWIfQlzLA4NhBjOqQU4edL-4,297
+titans_pytorch/associative_scan.py,sha256=esaLbukFlgvy2aqopsqBy6KEcZ64B3rsNhG8moKdPSc,5159
+titans_pytorch/mac_transformer.py,sha256=5rO4GQxSyFWWEc3pc3xNyG0sK5EXE7MmxKI-_kEMl2M,24941
+titans_pytorch/memory_models.py,sha256=wnH9i9kUSoVZhEWUlj8LpBSbB400L9kLt1zP8CO45QQ,5835
+titans_pytorch/neural_memory.py,sha256=BeOnq41gjZeq-XJFjkHE44F9dLzsg9mm36EBYZ4wHMA,28814
+titans_pytorch-0.3.10.dist-info/METADATA,sha256=sA_Dx_x5RMcpz5-vUPDHuz__tHYfKzs4W_BgY4CHPdk,6816
+titans_pytorch-0.3.10.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.3.10.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.3.10.dist-info/RECORD,,
titans_pytorch-0.3.9.dist-info/RECORD (removed)

@@ -1,9 +0,0 @@
-titans_pytorch/__init__.py,sha256=AyEUlcXWpnqrvyeihRAXWIfQlzLA4NhBjOqQU4edL-4,297
-titans_pytorch/associative_scan.py,sha256=esaLbukFlgvy2aqopsqBy6KEcZ64B3rsNhG8moKdPSc,5159
-titans_pytorch/mac_transformer.py,sha256=5rO4GQxSyFWWEc3pc3xNyG0sK5EXE7MmxKI-_kEMl2M,24941
-titans_pytorch/memory_models.py,sha256=wnH9i9kUSoVZhEWUlj8LpBSbB400L9kLt1zP8CO45QQ,5835
-titans_pytorch/neural_memory.py,sha256=2Ffq5fob6_vJVUs1jIai3BJhfsfypKffJEb0QwRRdMk,27325
-titans_pytorch-0.3.9.dist-info/METADATA,sha256=EUuuqHl8jPUPIG8m7xV5esN_2yDNaPdD-H8qFeDxWGo,6815
-titans_pytorch-0.3.9.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.3.9.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.3.9.dist-info/RECORD,,