titans-pytorch 0.1.12__tar.gz → 0.1.14__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.1.12
+Version: 0.1.14
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -1,6 +1,6 @@
 [project]
 name = "titans-pytorch"
-version = "0.1.12"
+version = "0.1.14"
 description = "Titans"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -12,6 +12,7 @@ def exists(v):
 @pytest.mark.parametrize('silu', (False, True))
 @pytest.mark.parametrize('learned_mem_model_weights', (False, True))
 @pytest.mark.parametrize('attn_pool_chunks', (False, True))
+@pytest.mark.parametrize('momentum', (False, True))
 @pytest.mark.parametrize('max_grad_norm', (None, 2.))
 @pytest.mark.parametrize('per_parameter_lr_modulation', (False, True))
 def test_titans(
@@ -19,6 +20,7 @@ def test_titans(
     silu,
     learned_mem_model_weights,
     attn_pool_chunks,
+    momentum,
     max_grad_norm,
     per_parameter_lr_modulation
 ):
@@ -28,6 +30,7 @@ def test_titans(
         activation = nn.SiLU() if silu else None,
         attn_pool_chunks = attn_pool_chunks,
         max_grad_norm = max_grad_norm,
+        momentum = momentum,
         per_parameter_lr_modulation = per_parameter_lr_modulation,
         learned_mem_model_weights = learned_mem_model_weights
     )
@@ -323,6 +323,7 @@ class NeuralMemory(Module):
         per_parameter_lr_modulation = False, # allow outer network to control learning rate per weight matrix of memory network
         max_mem_layer_modulation = 1e1, # max of 10.
         attn_pool_chunks = False,
+        momentum = True,
         pre_rmsnorm = True,
         post_rmsnorm = True,
         learned_mem_model_weights = True,
@@ -423,7 +424,7 @@ class NeuralMemory(Module):
         self.to_momentum = Sequential(
             LinearNoBias(dim, heads),
             Rearrange('b n h -> (b h) n 1')
-        )
+        ) if momentum else None

         self.to_adaptive_step = Sequential(
             LinearNoBias(dim, heads),
@@ -462,12 +463,15 @@ class NeuralMemory(Module):

         self.register_buffer('zero', torch.tensor(0.), persistent = False)

-    def init_weights_and_momentum(self):
+    def init_weights_and_momentum(self, zero_weights = False):
         params = TensorDict(dict(self.memory_model.named_parameters()))

-        init_weights = params.clone().zero_()
+        init_weights = params
         init_momentum = params.clone().zero_()

+        if zero_weights:
+            init_weights = params.clone().zero_()
+
         return init_weights, init_momentum

     def init_empty_memory_embed(self, batch, seq_len):
@@ -497,14 +501,10 @@ class NeuralMemory(Module):

         seq = seq[:, :round_down_seq_len]

-        # curr weights + past weights, in the case that the initial weights are learned
-
-        curr_weights = TensorDict(dict(self.memory_model.named_parameters()))
+        # get the weights of the memory network

         past_state = tuple(TensorDict(d) for d in past_state)
-        past_weights, past_momentum = past_state
-
-        curr_weights = curr_weights + past_weights
+        curr_weights, past_momentum = past_state

         # derive learned hparams for optimization of memory network

@@ -513,10 +513,13 @@ class NeuralMemory(Module):

         chunked_seq = self.reduce_to_chunk_rep(seq, chunk_size = chunk_size)

-        adaptive_momentum = self.to_momentum(chunked_seq).sigmoid()
         decay_factor = self.to_decay_factor(chunked_seq).sigmoid()

         need_layer_lr_mod = exists(self.to_layer_modulation)
+        has_momentum = exists(self.to_momentum)
+
+        if has_momentum:
+            adaptive_momentum = self.to_momentum(chunked_seq).sigmoid()

         if need_layer_lr_mod:
             layer_lr_mod = self.to_layer_modulation(chunked_seq) * self.max_mem_layer_modulation
@@ -595,23 +598,29 @@ class NeuralMemory(Module):

         # momentum + weight decay - momentum is the new contribution, as most linear RNNs have learned forgetting gates

-        next_momentum = TensorDict()
+        next_momentum = TensorDict() if has_momentum else None
         updates = TensorDict()

         for param_name, surprise in surprises.items():

             surprise, inverse_pack = pack_one_with_inverse(surprise, 'b n *')

+            update = surprise
+
             # derive momentum with associative scan - eq (10)

-            momentum = scan_fn(adaptive_momentum, surprise) # momentum is S / surprise in the paper
+            if has_momentum:
+                update = scan_fn(adaptive_momentum, surprise) # momentum is S / surprise in the paper
+                momentum = update

             # use associative scan again for learned forgetting (weight decay) - eq (13)

-            update = scan_fn(1. - decay_factor, momentum)
+            update = scan_fn(1. - decay_factor, update)

             updates[param_name] = inverse_pack(update)
-            next_momentum[param_name] = inverse_pack(momentum)
+
+            if has_momentum:
+                next_momentum[param_name] = inverse_pack(momentum)

             # compute the next weight per batch

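Taken together, the changes above make the first-order momentum of eq (10) optional: when the module is constructed with momentum = False, self.to_momentum is None and the surprise is fed straight into the learned forgetting scan of eq (13). The sketch below is a minimal illustration of that control flow, not the repository's implementation; it substitutes a plain sequential loop for the associative scan_fn and uses made-up tensor shapes.

import torch

def assoc_scan(gates, inputs):
    # sequential stand-in for the associative scan: out[t] = gates[t] * out[t-1] + inputs[t]
    out, prev = [], torch.zeros_like(inputs[0])
    for g, x in zip(gates, inputs):
        prev = g * prev + x
        out.append(prev)
    return torch.stack(out)

def memory_update(surprise, decay_factor, adaptive_momentum = None):
    # mirrors the gating introduced in this diff: the momentum scan (eq 10) runs only
    # when a momentum gate exists, while the forgetting scan (eq 13) always runs
    update = surprise
    if adaptive_momentum is not None:
        update = assoc_scan(adaptive_momentum, surprise)
    return assoc_scan(1. - decay_factor, update)

n, d = 8, 4                                      # hypothetical chunk count / flattened param dim
surprise = torch.randn(n, d)
decay = torch.rand(n, 1)
momentum_gate = torch.rand(n, 1)

with_momentum = memory_update(surprise, decay, momentum_gate)
without_momentum = memory_update(surprise, decay)   # the momentum = False path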
@@ -31,6 +31,7 @@ NUM_PERSIST_MEM = 4
 NUM_LONGTERM_MEM = 4
 NEURAL_MEM_LAYERS = (2, 4)
 NEURAL_MEM_GATE_ATTN_OUTPUT = True
+NEURAL_MEM_MOMENTUM = True
 WINDOW_SIZE = 32
 NEURAL_MEM_SEGMENT_LEN = WINDOW_SIZE // 2 # set smaller for more granularity for learning rate / momentum etc
 SLIDING_WINDOWS = True
@@ -88,6 +89,7 @@ model = MemoryAsContextTransformer(
    dim_head = 64,
    heads = 4,
    attn_pool_chunks = STORE_ATTN_POOL_CHUNKS,
+    momentum = NEURAL_MEM_MOMENTUM,
    use_accelerated_scan = USE_ACCELERATED_SCAN,
    learned_mem_model_weights = LEARNED_MEM_MODEL_WEIGHTS,
    default_model_kwargs = dict(
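For reference, the new flag can also be exercised on NeuralMemory directly, which is what the updated test sweeps over. Aside from momentum, the constructor values and the call pattern below are illustrative assumptions following the package's basic NeuralMemory usage, not anything prescribed by this release.

import torch
from titans_pytorch import NeuralMemory

mem = NeuralMemory(
    dim = 384,
    chunk_size = 64,
    momentum = False      # skip the eq (10) momentum pathway made optional in this diff
)

seq = torch.randn(2, 1024, 384)
retrieved = mem(seq)

assert retrieved.shape == seq.shape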