titans-pytorch 0.3.9__tar.gz → 0.3.10__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.3.9
+Version: 0.3.10
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -1,6 +1,6 @@
 [project]
 name = "titans-pytorch"
-version = "0.3.9"
+version = "0.3.10"
 description = "Titans"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -60,6 +60,26 @@ def test_titans(
 
     assert seq.shape == retrieved.shape
 
+@pytest.mark.parametrize('learned_momentum_combine', (False, True))
+def test_titans_second_order_momentum(
+    learned_momentum_combine
+):
+
+    mem = NeuralMemory(
+        dim = 384,
+        dim_head = 64,
+        heads = 2,
+        chunk_size = 1,
+        batch_size = 2,
+        momentum_order = 2,
+        learned_momentum_combine = learned_momentum_combine
+    )
+
+    seq = torch.randn(2, 5, 384)
+
+    parallel_retrieved, state = mem(seq)
+    assert seq.shape == parallel_retrieved.shape
+
 def test_titans_attn_memory():
     from titans_pytorch.memory_models import MemoryAttention
 
@@ -318,12 +338,16 @@ def test_flex(
 
     assert torch.allclose(out_flex, out_non_flex, atol = 1e-5)
 
-@torch_default_dtype(torch.float64)
-def test_assoc_scan():
+@pytest.mark.parametrize('use_accelerated', (True, False))
+def test_assoc_scan(
+    use_accelerated
+):
     from titans_pytorch.neural_memory import AssocScan
-    torch.set_default_dtype(torch.float64)
 
-    scan = AssocScan()
+    if use_accelerated and not torch.cuda.is_available():
+        pytest.skip()
+
+    scan = AssocScan(use_accelerated = use_accelerated)
 
     seq_len = 128
     mid_point = seq_len // 2
@@ -331,6 +355,10 @@ def test_assoc_scan():
     gates = torch.randn(2, seq_len, 16).sigmoid()
     inputs = torch.randn(2, seq_len, 16)
 
+    if use_accelerated:
+        gates = gates.cuda()
+        inputs = inputs.cuda()
+
     output = scan(gates, inputs)
 
     gates1, gates2 = gates[:, :mid_point], gates[:, mid_point:]
@@ -341,4 +369,4 @@ def test_assoc_scan():
     second_half = scan(gates2, inputs2, prev = first_half[:, -1])
     assert second_half.shape == inputs2.shape
 
-    assert torch.allclose(output[:, -1], second_half[:, -1], atol = 1e-6)
+    assert torch.allclose(output[:, -1], second_half[:, -1], atol = 1e-5)
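
For context on what this test checks: AssocScan appears to implement the usual gated accumulation out_t = gates_t * out_{t-1} + inputs_t, so scanning the second half of a sequence with `prev` set to the last state of the first half should reproduce the tail of a full scan. Below is a minimal reference sketch of that property, using a plain PyTorch loop rather than the library's AssocScan, and assuming that recurrence.

import torch

def reference_scan(gates, inputs, prev = None):
    # naive O(n) loop for out_t = gates_t * out_{t-1} + inputs_t
    state = prev if prev is not None else torch.zeros_like(inputs[:, 0])
    outs = []
    for g, x in zip(gates.unbind(dim = 1), inputs.unbind(dim = 1)):
        state = g * state + x
        outs.append(state)
    return torch.stack(outs, dim = 1)

gates = torch.randn(2, 128, 16).sigmoid()
inputs = torch.randn(2, 128, 16)

full = reference_scan(gates, inputs)

# scanning the halves separately while chaining the carried state matches the full scan
first = reference_scan(gates[:, :64], inputs[:, :64])
second = reference_scan(gates[:, 64:], inputs[:, 64:], prev = first[:, -1])

assert torch.allclose(full[:, -1], second[:, -1], atol = 1e-5)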
@@ -3,6 +3,7 @@ from typing import Callable
 
 import math
 from functools import partial
+from itertools import zip_longest
 from collections import namedtuple
 
 import torch
@@ -21,16 +22,19 @@ from titans_pytorch.memory_models import(
 )
 
 import einx
-from einops import rearrange, repeat, reduce, pack, unpack
+from einops import einsum, rearrange, repeat, reduce, pack, unpack
 from einops.layers.torch import Rearrange, Reduce
 
 """
 ein notation:
 b - batch
+h - heads
+bh - batch and heads
 n - sequence
 d - feature dimension
 c - intra-chunk
 w - num memory network weight parameters
+o - momentum orders
 """
 
 LinearNoBias = partial(Linear, bias = False)
@@ -224,6 +228,8 @@ class NeuralMemory(Module):
         per_head_learned_parameters = True,
         attn_pool_chunks = False,
         momentum = True,
+        momentum_order = 1,
+        learned_momentum_combine = False,
         pre_rmsnorm = True,
         post_rmsnorm = False,
         qk_rmsnorm = False,
@@ -367,12 +373,7 @@ class NeuralMemory(Module):
         else:
             self.reduce_to_chunk_rep = AttentionPool(dim, chunk_size = chunk_size)
 
-        # learned adaptive learning rate and momentum
-
-        self.to_momentum = Sequential(
-            nn.Linear(dim, heads),
-            Rearrange('b n h -> (b h) n 1')
-        ) if momentum else None
+        # learned adaptive learning rate
 
         self.to_adaptive_step = Sequential(
             nn.Linear(dim, heads),
@@ -384,6 +385,26 @@ class NeuralMemory(Module):
 
         self.adaptive_step_transform = adaptive_step_transform
 
+        # momentum related
+
+        self.to_momentum = Sequential(
+            nn.Linear(dim, heads * momentum_order),
+            Rearrange('b n (h o) -> o (b h) n 1', o = momentum_order)
+        ) if momentum else None
+
+        self.momentum_order = momentum_order
+        self.to_learned_momentum_combine = None
+
+        if learned_momentum_combine:
+            assert momentum
+            assert momentum_order > 1, 'only second order momentum allowed for now, but may allow learned combination of zeroth'
+
+            self.to_learned_momentum_combine = Sequential(
+                nn.Linear(dim, heads * momentum_order),
+                nn.Softmax(dim = -1),
+                Rearrange('b n (h o) -> o (b h) n', h = heads)
+            )
+
         # per layer learning rate modulation
 
         self.to_layer_modulation = Sequential(
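
To make the new tensor layout concrete: `to_momentum` now projects to `heads * momentum_order` gates per position and moves the order axis to the front, so downstream code can iterate over momentum orders. A standalone shape check follows, illustrative only, using plain `nn.Sequential` in place of the library's `Sequential` helper.

import torch
from torch import nn
from einops.layers.torch import Rearrange

dim, heads, momentum_order = 384, 2, 2

to_momentum = nn.Sequential(
    nn.Linear(dim, heads * momentum_order),
    Rearrange('b n (h o) -> o (b h) n 1', o = momentum_order)
)

x = torch.randn(3, 5, dim)        # (batch, seq, dim)
gates = to_momentum(x).sigmoid()

print(gates.shape)                # torch.Size([2, 6, 5, 1]): (order, batch * heads, seq, 1)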
@@ -463,9 +484,9 @@ class NeuralMemory(Module):
         zeros = self.memory_model_parameter_dict.clone().zero_()
 
         if self.per_head_learned_parameters:
-            zeros = repeat_dict_values(zeros, 'h ... -> (b h) ...', b = batch)
+            zeros = repeat_dict_values(zeros, 'h ... -> o (b h) ...', b = batch, o = self.momentum_order)
         else:
-            zeros = repeat_dict_values(zeros, '... -> bh ...', bh = batch * self.heads)
+            zeros = repeat_dict_values(zeros, '... -> o bh ...', bh = batch * self.heads, o = self.momentum_order)
 
         return zeros
 
@@ -518,6 +539,11 @@ class NeuralMemory(Module):
         if has_momentum:
             adaptive_momentum = self.to_momentum(chunked_seq).sigmoid()
 
+            learned_combine = exists(self.to_learned_momentum_combine)
+
+            if learned_combine:
+                combine_momentums = self.to_learned_momentum_combine(chunked_seq)
+
         if need_layer_lr_mod:
             layer_lr_mod = self.to_layer_modulation(chunked_seq) * self.max_mem_layer_modulation
 
@@ -585,7 +611,7 @@ class NeuralMemory(Module):
 
         # negative gradients, adaptive lr already applied as loss weight
 
-        surprises = grads.apply(lambda t: -t)
+        surprises = grads.mul(-1)
 
         # past states
 
@@ -611,7 +637,6 @@ class NeuralMemory(Module):
 
         # momentum + weight decay - momentum is the new contribution, as most linear RNNs have learned forgetting gates
 
-        next_momentum = TensorDict() if has_momentum else None
         updates = TensorDict()
 
         next_last_update = TensorDict()
@@ -624,20 +649,34 @@ class NeuralMemory(Module):
             # derive momentum with associative scan - eq (10)
 
             if has_momentum:
+                momentum = surprise
+
+                momentums = [] # stores all momentum orders starting with first, to generalize to Nth order momentum
+
                 last_momentum = past_last_momentum[param_name]
-                update = self.assoc_scan(adaptive_momentum, surprise, prev = last_momentum) # momentum is S / surprise in the paper
-                momentum = update
-                next_last_momentum[param_name] = momentum[:, -1]
+
+                # go from first order momentum all the way to the Nth
+
+                for one_adaptive_momentum, one_last_momentum in zip_longest(adaptive_momentum, last_momentum):
+                    momentum = self.assoc_scan(one_adaptive_momentum, momentum, prev = one_last_momentum) # momentum is S / surprise in the paper
+
+                    momentums.append(momentum)
+
+                momentums = torch.stack(momentums)
+
+                next_last_momentum[param_name] = momentums[:, :, -1] # momentums shape is Float['o bh n 1']
+
+                if not learned_combine:
+                    update = momentums[-1]
+                else:
+                    update = einsum(combine_momentums, momentums, 'o b n, o b n ... -> b n ...')
 
             # use associative scan again for learned forgetting (weight decay) - eq (13)
 
             update = self.assoc_scan(1. - decay_factor, update, prev = last_update, remove_prev = False)
-            next_last_update[param_name] = update[:, -1]
 
             updates[param_name] = update
-
-            if has_momentum:
-                next_momentum[param_name] = momentum
+            next_last_update[param_name] = update[:, -1]
 
             # determine next state for the storing of memories
 
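
Putting the storage-loop change in plain terms: the surprise is scanned once to get first-order momentum, that result is scanned again for second order (and so on up to `momentum_order`), and the per-order momentums are then either taken at the highest order or blended with the learned softmax weights. Below is a toy sketch of that control flow, with a sequential loop standing in for the parallel associative scan and made-up shapes; the names and shapes are illustrative assumptions, not the library API.

import torch
from itertools import zip_longest

def gated_scan(gate, x, prev):
    # stand-in for the associative scan: state_t = gate_t * state_{t-1} + x_t
    state = prev
    outs = []
    for g, v in zip(gate.unbind(dim = 1), x.unbind(dim = 1)):
        state = g * state + v
        outs.append(state)
    return torch.stack(outs, dim = 1)

order, bh, n, d = 2, 4, 6, 16

surprise = torch.randn(bh, n, d)
adaptive_momentum = torch.randn(order, bh, n, 1).sigmoid()   # one gate per momentum order
last_momentum = torch.zeros(order, bh, d)                    # carried state per order

momentum = surprise
momentums = []

# chain the scans: the output of order k feeds the scan for order k + 1
for gate, prev in zip_longest(adaptive_momentum, last_momentum):
    momentum = gated_scan(gate, momentum, prev)
    momentums.append(momentum)

momentums = torch.stack(momentums)                           # (order, bh, n, d)

# either keep the highest order, or combine all orders with softmax weights
highest_order_update = momentums[-1]
combine = torch.softmax(torch.randn(order, bh, n), dim = 0)
combined_update = torch.einsum('obn,obnd->bnd', combine, momentums)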