titans-pytorch 0.1.10__py3-none-any.whl → 0.1.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
titans_pytorch/titans.py CHANGED
@@ -99,7 +99,23 @@ class MultiheadRMSNorm(Module):
     def forward(self, x):
         return self.rmsnorm(x) * (self.gamma + 1.)
 
-# attention pool
+# chunk pooling
+
+class AveragePool(Module):
+    def __init__(
+        self,
+        chunk_size
+    ):
+        super().__init__()
+        self.chunk_size = chunk_size
+
+    def forward(
+        self,
+        x,
+        chunk_size = None
+    ):
+        chunk_size = default(chunk_size, self.chunk_size)
+        return reduce(x, 'b (n c) d -> b n d', 'mean', c = chunk_size)
 
 class AttentionPool(Module):
     def __init__(
@@ -111,7 +127,7 @@ class AttentionPool(Module):
         taken from Enformer https://www.nature.com/articles/s41592-021-01252-x , in turn taken from somewhere else
         """
         super().__init__()
-        self.split_chunks = Rearrange('b (n c) d -> b n c d', c = chunk_size)
+        self.chunk_size = chunk_size
         self.to_attn_logits = nn.Linear(dim, dim)
 
         # default to average pool
@@ -121,9 +137,13 @@ class AttentionPool(Module):
 
     def forward(
         self,
-        x
+        x,
+        chunk_size = None
    ):
-        x = self.split_chunks(x)
+        chunk_size = default(chunk_size, self.chunk_size)
+
+        x = rearrange(x, 'b (n c) d -> b n c d', c = chunk_size)
+
         attn_logits = self.to_attn_logits(x)
 
         attn = attn_logits.softmax(dim = -2)
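
For orientation, here is a minimal, self-contained sketch (not part of the package) of what the chunk pooling introduced above does: the new `AveragePool` collapses every chunk of `chunk_size` tokens to its mean vector, exactly as the `reduce` pattern in its `forward` describes. The tensor shapes and chunk size below are illustrative.

```python
# illustrative sketch only -- mirrors the einops reduce used by the new AveragePool
import torch
from einops import reduce

x = torch.randn(2, 64, 128)   # (batch, seq, dim)
chunk_size = 16               # hypothetical chunk size; seq length must be divisible by it

# mean-pool every chunk of 16 tokens down to a single representation per chunk
pooled = reduce(x, 'b (n c) d -> b n d', 'mean', c = chunk_size)
assert pooled.shape == (2, 4, 128)
```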
@@ -394,14 +414,13 @@ class NeuralMemory(Module):
         assert not (attn_pool_chunks and chunk_size == 1), '`attn_pool_chunks` cannot be set to True if `chunk_size` is set to 1'
 
         if not attn_pool_chunks:
-            chunk_reduce_module = Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size)
+            self.reduce_to_chunk_rep = AveragePool(chunk_size = chunk_size)
         else:
-            chunk_reduce_module = AttentionPool(dim, chunk_size = chunk_size)
+            self.reduce_to_chunk_rep = AttentionPool(dim, chunk_size = chunk_size)
 
         # learned adaptive learning rate and momentum
 
         self.to_momentum = Sequential(
-            chunk_reduce_module,
             LinearNoBias(dim, heads),
             Rearrange('b n h -> (b h) n 1')
         )
@@ -419,7 +438,6 @@ class NeuralMemory(Module):
         # per layer learning rate modulation
 
         self.to_layer_modulation = Sequential(
-            chunk_reduce_module,
             LinearNoBias(dim, heads * self.num_memory_parameter_tensors),
             Rearrange('b n (h w) -> w (b h) n', h = heads),
             nn.Sigmoid()
@@ -434,7 +452,6 @@ class NeuralMemory(Module):
         # weight decay factor
 
         self.to_decay_factor = Sequential(
-            chunk_reduce_module,
             LinearNoBias(dim, heads),
             Rearrange('b n h -> (b h) n 1')
         )
@@ -460,9 +477,10 @@ class NeuralMemory(Module):
         self,
         seq,
         past_state: tuple[dict[str, Tensor], dict[str, Tensor]],
-        return_aux_kv_loss = False
+        return_aux_kv_loss = False,
+        chunk_size = None
     ):
-        seq_len, chunk_size = seq.shape[-2], self.store_chunk_size
+        seq_len, chunk_size = seq.shape[-2], default(chunk_size, self.store_chunk_size)
 
         # handle edge case
 
@@ -488,18 +506,20 @@ class NeuralMemory(Module):
 
         curr_weights = curr_weights + past_weights
 
-        # pack batch and sequence dimension
+        # derive learned hparams for optimization of memory network
 
         adaptive_lr = self.to_adaptive_step(seq)
         adaptive_lr = self.adaptive_step_transform(adaptive_lr)
 
-        adaptive_momentum = self.to_momentum(seq).sigmoid()
-        decay_factor = self.to_decay_factor(seq).sigmoid()
+        chunked_seq = self.reduce_to_chunk_rep(seq, chunk_size = chunk_size)
+
+        adaptive_momentum = self.to_momentum(chunked_seq).sigmoid()
+        decay_factor = self.to_decay_factor(chunked_seq).sigmoid()
 
         need_layer_lr_mod = exists(self.to_layer_modulation)
 
         if need_layer_lr_mod:
-            layer_lr_mod = self.to_layer_modulation(seq) * self.max_mem_layer_modulation
+            layer_lr_mod = self.to_layer_modulation(chunked_seq) * self.max_mem_layer_modulation
 
         # keys and values
 
@@ -606,8 +626,9 @@ class NeuralMemory(Module):
         self,
         seq,
         past_weights: dict[str, Tensor] | None = None,
+        chunk_size = None
     ):
-        chunk_size = self.retrieve_chunk_size
+        chunk_size = default(chunk_size, self.retrieve_chunk_size)
         batch, seq_len = seq.shape[:2]
 
         seq = self.retrieve_norm(seq)
@@ -680,7 +701,9 @@ class NeuralMemory(Module):
         seq,
         store_seq = None,
         past_state: tuple[dict[str, Tensor], dict[str, Tensor]] | None = None,
-        return_aux_kv_loss = False
+        return_aux_kv_loss = False,
+        chunk_size = None,
+        store_chunk_size = None
     ):
         batch, seq_len = seq.shape[:2]
 
@@ -699,12 +722,13 @@ class NeuralMemory(Module):
             past_state = self.init_weights_and_momentum()
 
         store_seq = default(store_seq, seq)
+        store_chunk_size = default(store_chunk_size, chunk_size)
 
-        updates, aux_kv_recon_loss = self.store_memories(store_seq, past_state, return_aux_kv_loss = True)
+        updates, aux_kv_recon_loss = self.store_memories(store_seq, past_state, chunk_size = store_chunk_size, return_aux_kv_loss = True)
 
         past_weights, _ = past_state
 
-        retrieved = self.retrieve_memories(seq, past_weights + updates)
+        retrieved = self.retrieve_memories(seq, past_weights + updates, chunk_size = chunk_size)
 
         if not return_aux_kv_loss:
             return retrieved
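
Taken together, the titans.py changes expose the chunk size as a per-call override instead of a value fixed at construction: `store_memories`, `retrieve_memories`, and `forward` all accept an optional `chunk_size` (and `forward` additionally a `store_chunk_size`), falling back to the sizes set at init. A hedged usage sketch follows; the constructor arguments mirror the project README's example, the override values are illustrative, and it assumes the sequence length is divisible by the chosen chunk sizes and that storage and retrieval use the same chunking.

```python
import torch
from titans_pytorch import NeuralMemory

# illustrative hyperparameters, following the README example
mem = NeuralMemory(dim = 384, chunk_size = 64)

seq = torch.randn(2, 1024, 384)

# as of this version, the chunk sizes can be overridden per forward call;
# omitting the keywords keeps the behavior of the values set at construction
retrieved = mem(seq, chunk_size = 32, store_chunk_size = 32)
```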
{titans_pytorch-0.1.10.dist-info → titans_pytorch-0.1.12.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.1.10
+Version: 0.1.12
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -35,7 +35,7 @@ Classifier: Programming Language :: Python :: 3.9
 Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
 Requires-Python: >=3.9
 Requires-Dist: accelerated-scan>=0.2.0
-Requires-Dist: axial-positional-embedding>=0.3.6
+Requires-Dist: axial-positional-embedding>=0.3.7
 Requires-Dist: einops>=0.8.0
 Requires-Dist: einx>=0.3.0
 Requires-Dist: hyper-connections>=0.1.8
@@ -137,12 +137,13 @@ $ python train_mac.py
 ```
 
 ```bibtex
-@software{Kyrylov_Accelerated_Scan_2024,
-    author = {Kyrylov, Volodymyr},
-    doi = {10.5281/zenodo.10600962},
-    title = {Accelerated Scan},
-    version = {0.1.2},
-    year = {2024}
+@article{Sun2024LearningT,
+    title = {Learning to (Learn at Test Time): RNNs with Expressive Hidden States},
+    author = {Yu Sun and Xinhao Li and Karan Dalal and Jiarui Xu and Arjun Vikram and Genghan Zhang and Yann Dubois and Xinlei Chen and Xiaolong Wang and Oluwasanmi Koyejo and Tatsunori Hashimoto and Carlos Guestrin},
+    journal = {ArXiv},
+    year = {2024},
+    volume = {abs/2407.04620},
+    url = {https://api.semanticscholar.org/CorpusID:271039606}
 }
 ```
 
@@ -154,3 +155,44 @@ $ python train_mac.py
     url = {https://api.semanticscholar.org/CorpusID:274598177}
 }
 ```
+
+```bibtex
+@inproceedings{Nguyen2024TurningUT,
+    title = {Turning Up the Heat: Min-p Sampling for Creative and Coherent LLM Outputs},
+    author = {Minh Nguyen and Andrew Baker and Clement Neo and Allen Roush and Andreas Kirsch and Ravid Shwartz-Ziv},
+    year = {2024},
+    url = {https://api.semanticscholar.org/CorpusID:270870613}
+}
+```
+
+```bibtex
+@article{Zhu2024HyperConnections,
+    title = {Hyper-Connections},
+    author = {Defa Zhu and Hongzhi Huang and Zihao Huang and Yutao Zeng and Yunyao Mao and Banggu Wu and Qiyang Min and Xun Zhou},
+    journal = {ArXiv},
+    year = {2024},
+    volume = {abs/2409.19606},
+    url = {https://api.semanticscholar.org/CorpusID:272987528}
+}
+```
+
+```bibtex
+@article{Zhou2024ValueRL,
+    title = {Value Residual Learning For Alleviating Attention Concentration In Transformers},
+    author = {Zhanchao Zhou and Tianyi Wu and Zhiyun Jiang and Zhenzhong Lan},
+    journal = {ArXiv},
+    year = {2024},
+    volume = {abs/2410.17897},
+    url = {https://api.semanticscholar.org/CorpusID:273532030}
+}
+```
+
+```bibtex
+@software{Kyrylov_Accelerated_Scan_2024,
+    author = {Kyrylov, Volodymyr},
+    doi = {10.5281/zenodo.10600962},
+    title = {Accelerated Scan},
+    version = {0.1.2},
+    year = {2024}
+}
+```
titans_pytorch-0.1.12.dist-info/RECORD ADDED
@@ -0,0 +1,8 @@
+titans_pytorch/__init__.py,sha256=u0tta_KqhOdfzCEDWT9P4_jejJEK2q1XxhsEzB5MnQU,223
+titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+titans_pytorch/mac_transformer.py,sha256=zxknstaI_Uz47Y8WvZ3S7geJ-TNdqKV5Rvj0Jlw8njs,19271
+titans_pytorch/titans.py,sha256=eDTqAIDZjSLd34t8M-dCaqVf_s0wZ9jhVIOfXF7E9ts,21887
+titans_pytorch-0.1.12.dist-info/METADATA,sha256=dL8HpHt6V5gN8p8px7sc2IgJGqXthE7rULKIrRFCwF8,6340
+titans_pytorch-0.1.12.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.1.12.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.1.12.dist-info/RECORD,,
titans_pytorch-0.1.10.dist-info/RECORD DELETED
@@ -1,8 +0,0 @@
-titans_pytorch/__init__.py,sha256=u0tta_KqhOdfzCEDWT9P4_jejJEK2q1XxhsEzB5MnQU,223
-titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/mac_transformer.py,sha256=zxknstaI_Uz47Y8WvZ3S7geJ-TNdqKV5Rvj0Jlw8njs,19271
-titans_pytorch/titans.py,sha256=gZvYk1j6aBMp0uE6l1a2GH_4ea9W2uXKytJb3CDPTlk,21162
-titans_pytorch-0.1.10.dist-info/METADATA,sha256=o2D4Zau9GLBZmsj2qzq7agWckPnBJhDtIeTj2cMgy7Q,4769
-titans_pytorch-0.1.10.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.1.10.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.1.10.dist-info/RECORD,,