titans-pytorch 0.1.11__tar.gz → 0.1.14__tar.gz

This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: titans-pytorch
- Version: 0.1.11
+ Version: 0.1.14
  Summary: Titans
  Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
  Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -137,12 +137,13 @@ $ python train_mac.py
  ```

  ```bibtex
- @software{Kyrylov_Accelerated_Scan_2024,
- author = {Kyrylov, Volodymyr},
- doi = {10.5281/zenodo.10600962},
- title = {Accelerated Scan},
- version = {0.1.2},
- year = {2024}
+ @article{Sun2024LearningT,
+ title = {Learning to (Learn at Test Time): RNNs with Expressive Hidden States},
+ author = {Yu Sun and Xinhao Li and Karan Dalal and Jiarui Xu and Arjun Vikram and Genghan Zhang and Yann Dubois and Xinlei Chen and Xiaolong Wang and Oluwasanmi Koyejo and Tatsunori Hashimoto and Carlos Guestrin},
+ journal = {ArXiv},
+ year = {2024},
+ volume = {abs/2407.04620},
+ url = {https://api.semanticscholar.org/CorpusID:271039606}
  }
  ```

@@ -154,3 +155,44 @@ $ python train_mac.py
  url = {https://api.semanticscholar.org/CorpusID:274598177}
  }
  ```
+
+ ```bibtex
+ @inproceedings{Nguyen2024TurningUT,
+ title = {Turning Up the Heat: Min-p Sampling for Creative and Coherent LLM Outputs},
+ author = {Minh Nguyen and Andrew Baker and Clement Neo and Allen Roush and Andreas Kirsch and Ravid Shwartz-Ziv},
+ year = {2024},
+ url = {https://api.semanticscholar.org/CorpusID:270870613}
+ }
+ ```
+
+ ```bibtex
+ @article{Zhu2024HyperConnections,
+ title = {Hyper-Connections},
+ author = {Defa Zhu and Hongzhi Huang and Zihao Huang and Yutao Zeng and Yunyao Mao and Banggu Wu and Qiyang Min and Xun Zhou},
+ journal = {ArXiv},
+ year = {2024},
+ volume = {abs/2409.19606},
+ url = {https://api.semanticscholar.org/CorpusID:272987528}
+ }
+ ```
+
+ ```bibtex
+ @article{Zhou2024ValueRL,
+ title = {Value Residual Learning For Alleviating Attention Concentration In Transformers},
+ author = {Zhanchao Zhou and Tianyi Wu and Zhiyun Jiang and Zhenzhong Lan},
+ journal = {ArXiv},
+ year = {2024},
+ volume = {abs/2410.17897},
+ url = {https://api.semanticscholar.org/CorpusID:273532030}
+ }
+ ```
+
+ ```bibtex
+ @software{Kyrylov_Accelerated_Scan_2024,
+ author = {Kyrylov, Volodymyr},
+ doi = {10.5281/zenodo.10600962},
+ title = {Accelerated Scan},
+ version = {0.1.2},
+ year = {2024}
+ }
+ ```
@@ -83,12 +83,13 @@ $ python train_mac.py
  ```

  ```bibtex
- @software{Kyrylov_Accelerated_Scan_2024,
- author = {Kyrylov, Volodymyr},
- doi = {10.5281/zenodo.10600962},
- title = {Accelerated Scan},
- version = {0.1.2},
- year = {2024}
+ @article{Sun2024LearningT,
+ title = {Learning to (Learn at Test Time): RNNs with Expressive Hidden States},
+ author = {Yu Sun and Xinhao Li and Karan Dalal and Jiarui Xu and Arjun Vikram and Genghan Zhang and Yann Dubois and Xinlei Chen and Xiaolong Wang and Oluwasanmi Koyejo and Tatsunori Hashimoto and Carlos Guestrin},
+ journal = {ArXiv},
+ year = {2024},
+ volume = {abs/2407.04620},
+ url = {https://api.semanticscholar.org/CorpusID:271039606}
  }
  ```

@@ -100,3 +101,44 @@ $ python train_mac.py
  url = {https://api.semanticscholar.org/CorpusID:274598177}
  }
  ```
+
+ ```bibtex
+ @inproceedings{Nguyen2024TurningUT,
+ title = {Turning Up the Heat: Min-p Sampling for Creative and Coherent LLM Outputs},
+ author = {Minh Nguyen and Andrew Baker and Clement Neo and Allen Roush and Andreas Kirsch and Ravid Shwartz-Ziv},
+ year = {2024},
+ url = {https://api.semanticscholar.org/CorpusID:270870613}
+ }
+ ```
+
+ ```bibtex
+ @article{Zhu2024HyperConnections,
+ title = {Hyper-Connections},
+ author = {Defa Zhu and Hongzhi Huang and Zihao Huang and Yutao Zeng and Yunyao Mao and Banggu Wu and Qiyang Min and Xun Zhou},
+ journal = {ArXiv},
+ year = {2024},
+ volume = {abs/2409.19606},
+ url = {https://api.semanticscholar.org/CorpusID:272987528}
+ }
+ ```
+
+ ```bibtex
+ @article{Zhou2024ValueRL,
+ title = {Value Residual Learning For Alleviating Attention Concentration In Transformers},
+ author = {Zhanchao Zhou and Tianyi Wu and Zhiyun Jiang and Zhenzhong Lan},
+ journal = {ArXiv},
+ year = {2024},
+ volume = {abs/2410.17897},
+ url = {https://api.semanticscholar.org/CorpusID:273532030}
+ }
+ ```
+
+ ```bibtex
+ @software{Kyrylov_Accelerated_Scan_2024,
+ author = {Kyrylov, Volodymyr},
+ doi = {10.5281/zenodo.10600962},
+ title = {Accelerated Scan},
+ version = {0.1.2},
+ year = {2024}
+ }
+ ```
@@ -1,6 +1,6 @@
  [project]
  name = "titans-pytorch"
- version = "0.1.11"
+ version = "0.1.14"
  description = "Titans"
  authors = [
      { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -12,6 +12,7 @@ def exists(v):
  @pytest.mark.parametrize('silu', (False, True))
  @pytest.mark.parametrize('learned_mem_model_weights', (False, True))
  @pytest.mark.parametrize('attn_pool_chunks', (False, True))
+ @pytest.mark.parametrize('momentum', (False, True))
  @pytest.mark.parametrize('max_grad_norm', (None, 2.))
  @pytest.mark.parametrize('per_parameter_lr_modulation', (False, True))
  def test_titans(
@@ -19,6 +20,7 @@ def test_titans(
      silu,
      learned_mem_model_weights,
      attn_pool_chunks,
+     momentum,
      max_grad_norm,
      per_parameter_lr_modulation
  ):
@@ -28,6 +30,7 @@ def test_titans(
          activation = nn.SiLU() if silu else None,
          attn_pool_chunks = attn_pool_chunks,
          max_grad_norm = max_grad_norm,
+         momentum = momentum,
          per_parameter_lr_modulation = per_parameter_lr_modulation,
          learned_mem_model_weights = learned_mem_model_weights
      )
@@ -66,6 +69,19 @@ def test_retrieve_store_diff_seq():

      assert retrieve_seq.shape == retrieved.shape

+ def test_overriding_chunk_size():
+     mem = NeuralMemory(
+         dim = 384,
+         chunk_size = 64,
+     )
+
+     seq = torch.randn(2, 128 * 16, 384)
+     store_seq = torch.randn(2, 128 * 8, 384)
+
+     retrieved = mem(seq, store_seq, chunk_size = 16, store_chunk_size = 8)
+
+     assert seq.shape == retrieved.shape
+
  @pytest.mark.parametrize('seq_len', (1023, 17))
  @pytest.mark.parametrize('num_persist_mem_tokens', (0, 16))
  @pytest.mark.parametrize('num_longterm_mem_tokens', (0, 16))
@@ -99,7 +99,23 @@ class MultiheadRMSNorm(Module):
      def forward(self, x):
          return self.rmsnorm(x) * (self.gamma + 1.)

- # attention pool
+ # chunk pooling
+
+ class AveragePool(Module):
+     def __init__(
+         self,
+         chunk_size
+     ):
+         super().__init__()
+         self.chunk_size = chunk_size
+
+     def forward(
+         self,
+         x,
+         chunk_size = None
+     ):
+         chunk_size = default(chunk_size, self.chunk_size)
+         return reduce(x, 'b (n c) d -> b n d', 'mean', c = chunk_size)

  class AttentionPool(Module):
      def __init__(
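The new `AveragePool` simply mean-pools each fixed-size chunk with einops; a minimal sketch of that reduction on a toy tensor (shapes and the chunk size of 4 are arbitrary, chosen only for illustration):

```python
# sketch of the chunked mean pooling performed by the new AveragePool
# (toy shapes; chunk_size = 4 is arbitrary and only for illustration)
import torch
from einops import reduce

x = torch.randn(2, 16, 384)                              # (batch, seq, dim), seq divisible by chunk_size
pooled = reduce(x, 'b (n c) d -> b n d', 'mean', c = 4)  # mean over each chunk of 4 tokens
assert pooled.shape == (2, 4, 384)                       # one pooled vector per chunk
```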
@@ -111,7 +127,7 @@ class AttentionPool(Module):
          taken from Enformer https://www.nature.com/articles/s41592-021-01252-x , in turn taken from somewhere else
          """
          super().__init__()
-         self.split_chunks = Rearrange('b (n c) d -> b n c d', c = chunk_size)
+         self.chunk_size = chunk_size
          self.to_attn_logits = nn.Linear(dim, dim)

          # default to average pool
@@ -121,9 +137,13 @@ class AttentionPool(Module):
      def forward(
          self,
-         x
+         x,
+         chunk_size = None
      ):
-         x = self.split_chunks(x)
+         chunk_size = default(chunk_size, self.chunk_size)
+
+         x = rearrange(x, 'b (n c) d -> b n c d', c = chunk_size)
+
          attn_logits = self.to_attn_logits(x)

          attn = attn_logits.softmax(dim = -2)
@@ -303,6 +323,7 @@ class NeuralMemory(Module):
          per_parameter_lr_modulation = False, # allow outer network to control learning rate per weight matrix of memory network
          max_mem_layer_modulation = 1e1, # max of 10.
          attn_pool_chunks = False,
+         momentum = True,
          pre_rmsnorm = True,
          post_rmsnorm = True,
          learned_mem_model_weights = True,
@@ -394,17 +415,16 @@ class NeuralMemory(Module):
          assert not (attn_pool_chunks and chunk_size == 1), '`attn_pool_chunks` cannot be set to True if `chunk_size` is set to 1'

          if not attn_pool_chunks:
-             chunk_reduce_module = Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size)
+             self.reduce_to_chunk_rep = AveragePool(chunk_size = chunk_size)
          else:
-             chunk_reduce_module = AttentionPool(dim, chunk_size = chunk_size)
+             self.reduce_to_chunk_rep = AttentionPool(dim, chunk_size = chunk_size)

          # learned adaptive learning rate and momentum

          self.to_momentum = Sequential(
-             chunk_reduce_module,
              LinearNoBias(dim, heads),
              Rearrange('b n h -> (b h) n 1')
-         )
+         ) if momentum else None

          self.to_adaptive_step = Sequential(
              LinearNoBias(dim, heads),
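Since `momentum` is now a constructor flag, the momentum branch can be disabled entirely; a hedged sketch of the constructor call, using only keyword arguments visible in this diff (dims arbitrary):

```python
# sketch, assuming the constructor kwargs shown in this diff;
# with momentum = False, self.to_momentum is None and the momentum scan is skipped
from titans_pytorch import NeuralMemory

mem = NeuralMemory(
    dim = 384,
    chunk_size = 64,
    attn_pool_chunks = True,   # chunk reps via AttentionPool instead of AveragePool
    momentum = False           # new flag in 0.1.14, defaults to True
)
```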
@@ -419,7 +439,6 @@ class NeuralMemory(Module):
          # per layer learning rate modulation

          self.to_layer_modulation = Sequential(
-             chunk_reduce_module,
              LinearNoBias(dim, heads * self.num_memory_parameter_tensors),
              Rearrange('b n (h w) -> w (b h) n', h = heads),
              nn.Sigmoid()
@@ -434,7 +453,6 @@ class NeuralMemory(Module):
          # weight decay factor

          self.to_decay_factor = Sequential(
-             chunk_reduce_module,
              LinearNoBias(dim, heads),
              Rearrange('b n h -> (b h) n 1')
          )
@@ -445,12 +463,15 @@ class NeuralMemory(Module):

          self.register_buffer('zero', torch.tensor(0.), persistent = False)

-     def init_weights_and_momentum(self):
+     def init_weights_and_momentum(self, zero_weights = False):
          params = TensorDict(dict(self.memory_model.named_parameters()))

-         init_weights = params.clone().zero_()
+         init_weights = params
          init_momentum = params.clone().zero_()

+         if zero_weights:
+             init_weights = params.clone().zero_()
+
          return init_weights, init_momentum

      def init_empty_memory_embed(self, batch, seq_len):
@@ -460,9 +481,10 @@ class NeuralMemory(Module):
          self,
          seq,
          past_state: tuple[dict[str, Tensor], dict[str, Tensor]],
-         return_aux_kv_loss = False
+         return_aux_kv_loss = False,
+         chunk_size = None
      ):
-         seq_len, chunk_size = seq.shape[-2], self.store_chunk_size
+         seq_len, chunk_size = seq.shape[-2], default(chunk_size, self.store_chunk_size)

          # handle edge case

@@ -479,27 +501,28 @@ class NeuralMemory(Module):

          seq = seq[:, :round_down_seq_len]

-         # curr weights + past weights, in the case that the initial weights are learned
-
-         curr_weights = TensorDict(dict(self.memory_model.named_parameters()))
+         # get the weights of the memory network

          past_state = tuple(TensorDict(d) for d in past_state)
-         past_weights, past_momentum = past_state
-
-         curr_weights = curr_weights + past_weights
+         curr_weights, past_momentum = past_state

-         # pack batch and sequence dimension
+         # derive learned hparams for optimization of memory network

          adaptive_lr = self.to_adaptive_step(seq)
          adaptive_lr = self.adaptive_step_transform(adaptive_lr)

-         adaptive_momentum = self.to_momentum(seq).sigmoid()
-         decay_factor = self.to_decay_factor(seq).sigmoid()
+         chunked_seq = self.reduce_to_chunk_rep(seq, chunk_size = chunk_size)
+
+         decay_factor = self.to_decay_factor(chunked_seq).sigmoid()

          need_layer_lr_mod = exists(self.to_layer_modulation)
+         has_momentum = exists(self.to_momentum)
+
+         if has_momentum:
+             adaptive_momentum = self.to_momentum(chunked_seq).sigmoid()

          if need_layer_lr_mod:
-             layer_lr_mod = self.to_layer_modulation(seq) * self.max_mem_layer_modulation
+             layer_lr_mod = self.to_layer_modulation(chunked_seq) * self.max_mem_layer_modulation

          # keys and values

@@ -575,23 +598,29 @@ class NeuralMemory(Module):

          # momentum + weight decay - momentum is the new contribution, as most linear RNNs have learned forgetting gates

-         next_momentum = TensorDict()
+         next_momentum = TensorDict() if has_momentum else None
          updates = TensorDict()

          for param_name, surprise in surprises.items():

              surprise, inverse_pack = pack_one_with_inverse(surprise, 'b n *')

+             update = surprise
+
              # derive momentum with associative scan - eq (10)

-             momentum = scan_fn(adaptive_momentum, surprise) # momentum is S / surprise in the paper
+             if has_momentum:
+                 update = scan_fn(adaptive_momentum, surprise) # momentum is S / surprise in the paper
+                 momentum = update

              # use associative scan again for learned forgetting (weight decay) - eq (13)

-             update = scan_fn(1. - decay_factor, momentum)
+             update = scan_fn(1. - decay_factor, update)

              updates[param_name] = inverse_pack(update)
-             next_momentum[param_name] = inverse_pack(momentum)
+
+             if has_momentum:
+                 next_momentum[param_name] = inverse_pack(momentum)

          # compute the next weight per batch

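For readers matching this hunk to eq (10) and eq (13), here is a reference sketch of the two stacked recurrences as plain loops, assuming `scan_fn(gates, inputs)` realizes the linear recurrence x_t = gates_t * x_{t-1} + inputs_t with a zero initial state (the library computes the same thing with an associative scan):

```python
# illustrative only: naive loop form of the momentum scan (eq 10) followed by
# the learned forgetting / weight decay scan (eq 13); scan_fn semantics are assumed
import torch

def naive_scan(gates, inputs):
    out, prev = [], torch.zeros_like(inputs[:, 0])
    for t in range(inputs.shape[1]):
        prev = gates[:, t] * prev + inputs[:, t]   # x_t = a_t * x_{t-1} + b_t
        out.append(prev)
    return torch.stack(out, dim = 1)

surprise          = torch.randn(2, 8, 1)   # per-chunk surprise (negative gradient signal)
adaptive_momentum = torch.rand(2, 8, 1)    # as from to_momentum(chunked_seq).sigmoid()
decay_factor      = torch.rand(2, 8, 1)    # as from to_decay_factor(chunked_seq).sigmoid()

momentum = naive_scan(adaptive_momentum, surprise)   # eq (10)
update   = naive_scan(1. - decay_factor, momentum)   # eq (13)
```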
@@ -606,8 +635,9 @@ class NeuralMemory(Module):
          self,
          seq,
          past_weights: dict[str, Tensor] | None = None,
+         chunk_size = None
      ):
-         chunk_size = self.retrieve_chunk_size
+         chunk_size = default(chunk_size, self.retrieve_chunk_size)
          batch, seq_len = seq.shape[:2]

          seq = self.retrieve_norm(seq)
@@ -680,7 +710,9 @@ class NeuralMemory(Module):
          seq,
          store_seq = None,
          past_state: tuple[dict[str, Tensor], dict[str, Tensor]] | None = None,
-         return_aux_kv_loss = False
+         return_aux_kv_loss = False,
+         chunk_size = None,
+         store_chunk_size = None
      ):
          batch, seq_len = seq.shape[:2]

@@ -699,12 +731,13 @@ class NeuralMemory(Module):
              past_state = self.init_weights_and_momentum()

          store_seq = default(store_seq, seq)
+         store_chunk_size = default(store_chunk_size, chunk_size)

-         updates, aux_kv_recon_loss = self.store_memories(store_seq, past_state, return_aux_kv_loss = True)
+         updates, aux_kv_recon_loss = self.store_memories(store_seq, past_state, chunk_size = store_chunk_size, return_aux_kv_loss = True)

          past_weights, _ = past_state

-         retrieved = self.retrieve_memories(seq, past_weights + updates)
+         retrieved = self.retrieve_memories(seq, past_weights + updates, chunk_size = chunk_size)

          if not return_aux_kv_loss:
              return retrieved
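Taken together, retrieval and storage can now run with different, per-call chunk sizes; a minimal usage sketch mirroring `test_overriding_chunk_size` above:

```python
# sketch of the new per-call chunk size override (mirrors the new test);
# the constructor chunk_size of 64 is overridden for this one call
import torch
from titans_pytorch import NeuralMemory

mem = NeuralMemory(dim = 384, chunk_size = 64)

seq       = torch.randn(2, 128 * 16, 384)   # sequence to retrieve against
store_seq = torch.randn(2, 128 * 8, 384)    # separate sequence to store

retrieved = mem(seq, store_seq, chunk_size = 16, store_chunk_size = 8)
assert retrieved.shape == seq.shape
```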
@@ -31,6 +31,7 @@ NUM_PERSIST_MEM = 4
  NUM_LONGTERM_MEM = 4
  NEURAL_MEM_LAYERS = (2, 4)
  NEURAL_MEM_GATE_ATTN_OUTPUT = True
+ NEURAL_MEM_MOMENTUM = True
  WINDOW_SIZE = 32
  NEURAL_MEM_SEGMENT_LEN = WINDOW_SIZE // 2 # set smaller for more granularity for learning rate / momentum etc
  SLIDING_WINDOWS = True
@@ -88,6 +89,7 @@ model = MemoryAsContextTransformer(
      dim_head = 64,
      heads = 4,
      attn_pool_chunks = STORE_ATTN_POOL_CHUNKS,
+     momentum = NEURAL_MEM_MOMENTUM,
      use_accelerated_scan = USE_ACCELERATED_SCAN,
      learned_mem_model_weights = LEARNED_MEM_MODEL_WEIGHTS,
      default_model_kwargs = dict(