titans-pytorch 0.1.11__py3-none-any.whl → 0.1.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- titans_pytorch/titans.py +43 -19
- {titans_pytorch-0.1.11.dist-info → titans_pytorch-0.1.12.dist-info}/METADATA +49 -7
- titans_pytorch-0.1.12.dist-info/RECORD +8 -0
- titans_pytorch-0.1.11.dist-info/RECORD +0 -8
- {titans_pytorch-0.1.11.dist-info → titans_pytorch-0.1.12.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.1.11.dist-info → titans_pytorch-0.1.12.dist-info}/licenses/LICENSE +0 -0
titans_pytorch/titans.py
CHANGED
@@ -99,7 +99,23 @@ class MultiheadRMSNorm(Module):
     def forward(self, x):
         return self.rmsnorm(x) * (self.gamma + 1.)
 
-#
+# chunk pooling
+
+class AveragePool(Module):
+    def __init__(
+        self,
+        chunk_size
+    ):
+        super().__init__()
+        self.chunk_size = chunk_size
+
+    def forward(
+        self,
+        x,
+        chunk_size = None
+    ):
+        chunk_size = default(chunk_size, self.chunk_size)
+        return reduce(x, 'b (n c) d -> b n d', 'mean', c = chunk_size)
 
 class AttentionPool(Module):
     def __init__(
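The new AveragePool mean-pools each fixed-size chunk of the sequence down to a single vector. A minimal standalone sketch of the same reduction (the tensor sizes below are illustrative, not taken from the package):

```python
import torch
from einops import reduce

x = torch.randn(2, 8, 4)   # (batch, seq_len, dim), seq_len divisible by chunk_size
chunk_size = 4

# mean over every chunk of 4 tokens: (2, 8, 4) -> (2, 2, 4)
pooled = reduce(x, 'b (n c) d -> b n d', 'mean', c = chunk_size)
print(pooled.shape)        # torch.Size([2, 2, 4])
```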
@@ -111,7 +127,7 @@ class AttentionPool(Module):
         taken from Enformer https://www.nature.com/articles/s41592-021-01252-x , in turn taken from somewhere else
         """
         super().__init__()
-        self.
+        self.chunk_size = chunk_size
         self.to_attn_logits = nn.Linear(dim, dim)
 
         # default to average pool
@@ -121,9 +137,13 @@ class AttentionPool(Module):
 
     def forward(
         self,
-        x
+        x,
+        chunk_size = None
     ):
-
+        chunk_size = default(chunk_size, self.chunk_size)
+
+        x = rearrange(x, 'b (n c) d -> b n c d', c = chunk_size)
+
         attn_logits = self.to_attn_logits(x)
 
         attn = attn_logits.softmax(dim = -2)
@@ -394,14 +414,13 @@ class NeuralMemory(Module):
         assert not (attn_pool_chunks and chunk_size == 1), '`attn_pool_chunks` cannot be set to True if `chunk_size` is set to 1'
 
         if not attn_pool_chunks:
-
+            self.reduce_to_chunk_rep = AveragePool(chunk_size = chunk_size)
         else:
-
+            self.reduce_to_chunk_rep = AttentionPool(dim, chunk_size = chunk_size)
 
         # learned adaptive learning rate and momentum
 
         self.to_momentum = Sequential(
-            chunk_reduce_module,
             LinearNoBias(dim, heads),
             Rearrange('b n h -> (b h) n 1')
         )
@@ -419,7 +438,6 @@ class NeuralMemory(Module):
         # per layer learning rate modulation
 
         self.to_layer_modulation = Sequential(
-            chunk_reduce_module,
             LinearNoBias(dim, heads * self.num_memory_parameter_tensors),
             Rearrange('b n (h w) -> w (b h) n', h = heads),
             nn.Sigmoid()
@@ -434,7 +452,6 @@ class NeuralMemory(Module):
         # weight decay factor
 
         self.to_decay_factor = Sequential(
-            chunk_reduce_module,
             LinearNoBias(dim, heads),
             Rearrange('b n h -> (b h) n 1')
         )
@@ -460,9 +477,10 @@ class NeuralMemory(Module):
         self,
         seq,
         past_state: tuple[dict[str, Tensor], dict[str, Tensor]],
-        return_aux_kv_loss = False
+        return_aux_kv_loss = False,
+        chunk_size = None
     ):
-        seq_len, chunk_size = seq.shape[-2], self.store_chunk_size
+        seq_len, chunk_size = seq.shape[-2], default(chunk_size, self.store_chunk_size)
 
         # handle edge case
 
@@ -488,18 +506,20 @@ class NeuralMemory(Module):
 
         curr_weights = curr_weights + past_weights
 
-        #
+        # derive learned hparams for optimization of memory network
 
         adaptive_lr = self.to_adaptive_step(seq)
         adaptive_lr = self.adaptive_step_transform(adaptive_lr)
 
-
-
+        chunked_seq = self.reduce_to_chunk_rep(seq, chunk_size = chunk_size)
+
+        adaptive_momentum = self.to_momentum(chunked_seq).sigmoid()
+        decay_factor = self.to_decay_factor(chunked_seq).sigmoid()
 
         need_layer_lr_mod = exists(self.to_layer_modulation)
 
         if need_layer_lr_mod:
-            layer_lr_mod = self.to_layer_modulation(
+            layer_lr_mod = self.to_layer_modulation(chunked_seq) * self.max_mem_layer_modulation
 
         # keys and values
 
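With this hunk the chunk pooling is applied once and shared: the sequence is first reduced to one representation per chunk, and the momentum, decay and per-layer modulation heads all read from that pooled tensor instead of each carrying their own chunk_reduce_module. A rough shape sketch of that flow, using made-up sizes (batch 2, seq_len 512, dim 384, heads 4, chunk_size 64) and plain torch/einops modules rather than the package's own:

```python
import torch
from torch import nn
from einops import reduce
from einops.layers.torch import Rearrange

b, n, d, h, c = 2, 512, 384, 4, 64
seq = torch.randn(b, n, d)

# one pooled representation per chunk: (2, 512, 384) -> (2, 8, 384)
chunked_seq = reduce(seq, 'b (n c) d -> b n d', 'mean', c = c)

# each hparam head maps the pooled tensor to one value per chunk and head
to_momentum = nn.Sequential(
    nn.Linear(d, h, bias = False),
    Rearrange('b n h -> (b h) n 1')
)
adaptive_momentum = to_momentum(chunked_seq).sigmoid()
print(chunked_seq.shape, adaptive_momentum.shape)   # (2, 8, 384) (8, 8, 1)
```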
@@ -606,8 +626,9 @@ class NeuralMemory(Module):
         self,
         seq,
         past_weights: dict[str, Tensor] | None = None,
+        chunk_size = None
     ):
-        chunk_size = self.retrieve_chunk_size
+        chunk_size = default(chunk_size, self.retrieve_chunk_size)
         batch, seq_len = seq.shape[:2]
 
         seq = self.retrieve_norm(seq)
@@ -680,7 +701,9 @@ class NeuralMemory(Module):
         seq,
         store_seq = None,
         past_state: tuple[dict[str, Tensor], dict[str, Tensor]] | None = None,
-        return_aux_kv_loss = False
+        return_aux_kv_loss = False,
+        chunk_size = None,
+        store_chunk_size = None
     ):
         batch, seq_len = seq.shape[:2]
 
@@ -699,12 +722,13 @@ class NeuralMemory(Module):
         past_state = self.init_weights_and_momentum()
 
         store_seq = default(store_seq, seq)
+        store_chunk_size = default(store_chunk_size, chunk_size)
 
-        updates, aux_kv_recon_loss = self.store_memories(store_seq, past_state, return_aux_kv_loss = True)
+        updates, aux_kv_recon_loss = self.store_memories(store_seq, past_state, chunk_size = store_chunk_size, return_aux_kv_loss = True)
 
         past_weights, _ = past_state
 
-        retrieved = self.retrieve_memories(seq, past_weights + updates)
+        retrieved = self.retrieve_memories(seq, past_weights + updates, chunk_size = chunk_size)
 
         if not return_aux_kv_loss:
             return retrieved
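Taken together, the titans.py changes let the storage and retrieval chunk sizes be overridden per call rather than fixed at construction. A hedged usage sketch: the constructor arguments are assumptions based on the package README, and only the chunk_size / store_chunk_size keywords come from this diff:

```python
import torch
from titans_pytorch import NeuralMemory

mem = NeuralMemory(
    dim = 384,
    chunk_size = 64   # default, used when no override is passed
)

seq = torch.randn(2, 1024, 384)

retrieved = mem(seq)                                           # configured chunk sizes
retrieved = mem(seq, chunk_size = 32, store_chunk_size = 64)   # new in 0.1.12: per-call override
```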
{titans_pytorch-0.1.11.dist-info → titans_pytorch-0.1.12.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.1.11
+Version: 0.1.12
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -137,12 +137,13 @@ $ python train_mac.py
 ```
 
 ```bibtex
-@
-
-
-
-
-
+@article{Sun2024LearningT,
+    title = {Learning to (Learn at Test Time): RNNs with Expressive Hidden States},
+    author = {Yu Sun and Xinhao Li and Karan Dalal and Jiarui Xu and Arjun Vikram and Genghan Zhang and Yann Dubois and Xinlei Chen and Xiaolong Wang and Oluwasanmi Koyejo and Tatsunori Hashimoto and Carlos Guestrin},
+    journal = {ArXiv},
+    year = {2024},
+    volume = {abs/2407.04620},
+    url = {https://api.semanticscholar.org/CorpusID:271039606}
 }
 ```
 
@@ -154,3 +155,44 @@ $ python train_mac.py
     url = {https://api.semanticscholar.org/CorpusID:274598177}
 }
 ```
+
+```bibtex
+@inproceedings{Nguyen2024TurningUT,
+    title = {Turning Up the Heat: Min-p Sampling for Creative and Coherent LLM Outputs},
+    author = {Minh Nguyen and Andrew Baker and Clement Neo and Allen Roush and Andreas Kirsch and Ravid Shwartz-Ziv},
+    year = {2024},
+    url = {https://api.semanticscholar.org/CorpusID:270870613}
+}
+```
+
+```bibtex
+@article{Zhu2024HyperConnections,
+    title = {Hyper-Connections},
+    author = {Defa Zhu and Hongzhi Huang and Zihao Huang and Yutao Zeng and Yunyao Mao and Banggu Wu and Qiyang Min and Xun Zhou},
+    journal = {ArXiv},
+    year = {2024},
+    volume = {abs/2409.19606},
+    url = {https://api.semanticscholar.org/CorpusID:272987528}
+}
+```
+
+```bibtex
+@article{Zhou2024ValueRL,
+    title = {Value Residual Learning For Alleviating Attention Concentration In Transformers},
+    author = {Zhanchao Zhou and Tianyi Wu and Zhiyun Jiang and Zhenzhong Lan},
+    journal = {ArXiv},
+    year = {2024},
+    volume = {abs/2410.17897},
+    url = {https://api.semanticscholar.org/CorpusID:273532030}
+}
+```
+
+```bibtex
+@software{Kyrylov_Accelerated_Scan_2024,
+    author = {Kyrylov, Volodymyr},
+    doi = {10.5281/zenodo.10600962},
+    title = {Accelerated Scan},
+    version = {0.1.2},
+    year = {2024}
+}
+```
titans_pytorch-0.1.12.dist-info/RECORD
ADDED
@@ -0,0 +1,8 @@
+titans_pytorch/__init__.py,sha256=u0tta_KqhOdfzCEDWT9P4_jejJEK2q1XxhsEzB5MnQU,223
+titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+titans_pytorch/mac_transformer.py,sha256=zxknstaI_Uz47Y8WvZ3S7geJ-TNdqKV5Rvj0Jlw8njs,19271
+titans_pytorch/titans.py,sha256=eDTqAIDZjSLd34t8M-dCaqVf_s0wZ9jhVIOfXF7E9ts,21887
+titans_pytorch-0.1.12.dist-info/METADATA,sha256=dL8HpHt6V5gN8p8px7sc2IgJGqXthE7rULKIrRFCwF8,6340
+titans_pytorch-0.1.12.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.1.12.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.1.12.dist-info/RECORD,,
titans_pytorch-0.1.11.dist-info/RECORD
REMOVED
@@ -1,8 +0,0 @@
-titans_pytorch/__init__.py,sha256=u0tta_KqhOdfzCEDWT9P4_jejJEK2q1XxhsEzB5MnQU,223
-titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/mac_transformer.py,sha256=zxknstaI_Uz47Y8WvZ3S7geJ-TNdqKV5Rvj0Jlw8njs,19271
-titans_pytorch/titans.py,sha256=gZvYk1j6aBMp0uE6l1a2GH_4ea9W2uXKytJb3CDPTlk,21162
-titans_pytorch-0.1.11.dist-info/METADATA,sha256=6u7nYbl0juOqG2b6b7jSjbKJCQEtPEUpiZVe3JWSn1A,4769
-titans_pytorch-0.1.11.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.1.11.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.1.11.dist-info/RECORD,,
{titans_pytorch-0.1.11.dist-info → titans_pytorch-0.1.12.dist-info}/WHEEL
File without changes
{titans_pytorch-0.1.11.dist-info → titans_pytorch-0.1.12.dist-info}/licenses/LICENSE
File without changes