titans-pytorch 0.1.11__tar.gz → 0.1.14__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {titans_pytorch-0.1.11 → titans_pytorch-0.1.14}/PKG-INFO +49 -7
- {titans_pytorch-0.1.11 → titans_pytorch-0.1.14}/README.md +48 -6
- {titans_pytorch-0.1.11 → titans_pytorch-0.1.14}/pyproject.toml +1 -1
- {titans_pytorch-0.1.11 → titans_pytorch-0.1.14}/tests/test_titans.py +16 -0
- {titans_pytorch-0.1.11 → titans_pytorch-0.1.14}/titans_pytorch/titans.py +65 -32
- {titans_pytorch-0.1.11 → titans_pytorch-0.1.14}/train_mac.py +2 -0
- {titans_pytorch-0.1.11 → titans_pytorch-0.1.14}/.github/workflows/python-publish.yml +0 -0
- {titans_pytorch-0.1.11 → titans_pytorch-0.1.14}/.github/workflows/test.yaml +0 -0
- {titans_pytorch-0.1.11 → titans_pytorch-0.1.14}/.gitignore +0 -0
- {titans_pytorch-0.1.11 → titans_pytorch-0.1.14}/LICENSE +0 -0
- {titans_pytorch-0.1.11 → titans_pytorch-0.1.14}/data/README.md +0 -0
- {titans_pytorch-0.1.11 → titans_pytorch-0.1.14}/data/enwik8.gz +0 -0
- {titans_pytorch-0.1.11 → titans_pytorch-0.1.14}/fig1.png +0 -0
- {titans_pytorch-0.1.11 → titans_pytorch-0.1.14}/fig2.png +0 -0
- {titans_pytorch-0.1.11 → titans_pytorch-0.1.14}/titans_pytorch/__init__.py +0 -0
- {titans_pytorch-0.1.11 → titans_pytorch-0.1.14}/titans_pytorch/associative_scan.py +0 -0
- {titans_pytorch-0.1.11 → titans_pytorch-0.1.14}/titans_pytorch/mac_transformer.py +0 -0
{titans_pytorch-0.1.11 → titans_pytorch-0.1.14}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.1.11
+Version: 0.1.14
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -137,12 +137,13 @@ $ python train_mac.py
 ```
 
 ```bibtex
-@
-
-
-
-
-
+@article{Sun2024LearningT,
+    title = {Learning to (Learn at Test Time): RNNs with Expressive Hidden States},
+    author = {Yu Sun and Xinhao Li and Karan Dalal and Jiarui Xu and Arjun Vikram and Genghan Zhang and Yann Dubois and Xinlei Chen and Xiaolong Wang and Oluwasanmi Koyejo and Tatsunori Hashimoto and Carlos Guestrin},
+    journal = {ArXiv},
+    year = {2024},
+    volume = {abs/2407.04620},
+    url = {https://api.semanticscholar.org/CorpusID:271039606}
 }
 ```
 
@@ -154,3 +155,44 @@ $ python train_mac.py
     url = {https://api.semanticscholar.org/CorpusID:274598177}
 }
 ```
+
+```bibtex
+@inproceedings{Nguyen2024TurningUT,
+    title = {Turning Up the Heat: Min-p Sampling for Creative and Coherent LLM Outputs},
+    author = {Minh Nguyen and Andrew Baker and Clement Neo and Allen Roush and Andreas Kirsch and Ravid Shwartz-Ziv},
+    year = {2024},
+    url = {https://api.semanticscholar.org/CorpusID:270870613}
+}
+```
+
+```bibtex
+@article{Zhu2024HyperConnections,
+    title = {Hyper-Connections},
+    author = {Defa Zhu and Hongzhi Huang and Zihao Huang and Yutao Zeng and Yunyao Mao and Banggu Wu and Qiyang Min and Xun Zhou},
+    journal = {ArXiv},
+    year = {2024},
+    volume = {abs/2409.19606},
+    url = {https://api.semanticscholar.org/CorpusID:272987528}
+}
+```
+
+```bibtex
+@article{Zhou2024ValueRL,
+    title = {Value Residual Learning For Alleviating Attention Concentration In Transformers},
+    author = {Zhanchao Zhou and Tianyi Wu and Zhiyun Jiang and Zhenzhong Lan},
+    journal = {ArXiv},
+    year = {2024},
+    volume = {abs/2410.17897},
+    url = {https://api.semanticscholar.org/CorpusID:273532030}
+}
+```
+
+```bibtex
+@software{Kyrylov_Accelerated_Scan_2024,
+    author = {Kyrylov, Volodymyr},
+    doi = {10.5281/zenodo.10600962},
+    title = {Accelerated Scan},
+    version = {0.1.2},
+    year = {2024}
+}
+```
{titans_pytorch-0.1.11 → titans_pytorch-0.1.14}/README.md

@@ -83,12 +83,13 @@ $ python train_mac.py
 ```
 
 ```bibtex
-@
-
-
-
-
-
+@article{Sun2024LearningT,
+    title = {Learning to (Learn at Test Time): RNNs with Expressive Hidden States},
+    author = {Yu Sun and Xinhao Li and Karan Dalal and Jiarui Xu and Arjun Vikram and Genghan Zhang and Yann Dubois and Xinlei Chen and Xiaolong Wang and Oluwasanmi Koyejo and Tatsunori Hashimoto and Carlos Guestrin},
+    journal = {ArXiv},
+    year = {2024},
+    volume = {abs/2407.04620},
+    url = {https://api.semanticscholar.org/CorpusID:271039606}
 }
 ```
 
@@ -100,3 +101,44 @@ $ python train_mac.py
     url = {https://api.semanticscholar.org/CorpusID:274598177}
 }
 ```
+
+```bibtex
+@inproceedings{Nguyen2024TurningUT,
+    title = {Turning Up the Heat: Min-p Sampling for Creative and Coherent LLM Outputs},
+    author = {Minh Nguyen and Andrew Baker and Clement Neo and Allen Roush and Andreas Kirsch and Ravid Shwartz-Ziv},
+    year = {2024},
+    url = {https://api.semanticscholar.org/CorpusID:270870613}
+}
+```
+
+```bibtex
+@article{Zhu2024HyperConnections,
+    title = {Hyper-Connections},
+    author = {Defa Zhu and Hongzhi Huang and Zihao Huang and Yutao Zeng and Yunyao Mao and Banggu Wu and Qiyang Min and Xun Zhou},
+    journal = {ArXiv},
+    year = {2024},
+    volume = {abs/2409.19606},
+    url = {https://api.semanticscholar.org/CorpusID:272987528}
+}
+```
+
+```bibtex
+@article{Zhou2024ValueRL,
+    title = {Value Residual Learning For Alleviating Attention Concentration In Transformers},
+    author = {Zhanchao Zhou and Tianyi Wu and Zhiyun Jiang and Zhenzhong Lan},
+    journal = {ArXiv},
+    year = {2024},
+    volume = {abs/2410.17897},
+    url = {https://api.semanticscholar.org/CorpusID:273532030}
+}
+```
+
+```bibtex
+@software{Kyrylov_Accelerated_Scan_2024,
+    author = {Kyrylov, Volodymyr},
+    doi = {10.5281/zenodo.10600962},
+    title = {Accelerated Scan},
+    version = {0.1.2},
+    year = {2024}
+}
+```
{titans_pytorch-0.1.11 → titans_pytorch-0.1.14}/tests/test_titans.py

@@ -12,6 +12,7 @@ def exists(v):
 @pytest.mark.parametrize('silu', (False, True))
 @pytest.mark.parametrize('learned_mem_model_weights', (False, True))
 @pytest.mark.parametrize('attn_pool_chunks', (False, True))
+@pytest.mark.parametrize('momentum', (False, True))
 @pytest.mark.parametrize('max_grad_norm', (None, 2.))
 @pytest.mark.parametrize('per_parameter_lr_modulation', (False, True))
 def test_titans(
@@ -19,6 +20,7 @@ def test_titans(
     silu,
     learned_mem_model_weights,
     attn_pool_chunks,
+    momentum,
     max_grad_norm,
     per_parameter_lr_modulation
 ):
@@ -28,6 +30,7 @@ def test_titans(
         activation = nn.SiLU() if silu else None,
         attn_pool_chunks = attn_pool_chunks,
         max_grad_norm = max_grad_norm,
+        momentum = momentum,
         per_parameter_lr_modulation = per_parameter_lr_modulation,
         learned_mem_model_weights = learned_mem_model_weights
     )
@@ -66,6 +69,19 @@ def test_retrieve_store_diff_seq():
 
     assert retrieve_seq.shape == retrieved.shape
 
+def test_overriding_chunk_size():
+    mem = NeuralMemory(
+        dim = 384,
+        chunk_size = 64,
+    )
+
+    seq = torch.randn(2, 128 * 16, 384)
+    store_seq = torch.randn(2, 128 * 8, 384)
+
+    retrieved = mem(seq, store_seq, chunk_size = 16, store_chunk_size = 8)
+
+    assert seq.shape == retrieved.shape
+
 @pytest.mark.parametrize('seq_len', (1023, 17))
 @pytest.mark.parametrize('num_persist_mem_tokens', (0, 16))
 @pytest.mark.parametrize('num_longterm_mem_tokens', (0, 16))
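The `test_overriding_chunk_size` test above exercises the per-call chunk size override added in these releases: `chunk_size` and `store_chunk_size` can now be passed at forward time instead of being fixed at construction. A minimal usage sketch, assuming `NeuralMemory` is imported from `titans_pytorch` as in the test suite and using illustrative tensor shapes:

```python
# Sketch of the per-call chunk size override shown in test_overriding_chunk_size.
# Shapes are illustrative; sequence lengths are chosen to be divisible by the
# chunk sizes passed at call time.
import torch
from titans_pytorch import NeuralMemory

mem = NeuralMemory(
    dim = 384,
    chunk_size = 64,  # default chunk size, used when no override is passed
)

seq = torch.randn(2, 128 * 16, 384)       # sequence to retrieve against
store_seq = torch.randn(2, 128 * 8, 384)  # sequence whose memories are stored

# override retrieval and storage chunk sizes for this forward call only
retrieved = mem(seq, store_seq, chunk_size = 16, store_chunk_size = 8)

assert retrieved.shape == seq.shape
```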
{titans_pytorch-0.1.11 → titans_pytorch-0.1.14}/titans_pytorch/titans.py

@@ -99,7 +99,23 @@ class MultiheadRMSNorm(Module):
     def forward(self, x):
         return self.rmsnorm(x) * (self.gamma + 1.)
 
-#
+# chunk pooling
+
+class AveragePool(Module):
+    def __init__(
+        self,
+        chunk_size
+    ):
+        super().__init__()
+        self.chunk_size = chunk_size
+
+    def forward(
+        self,
+        x,
+        chunk_size = None
+    ):
+        chunk_size = default(chunk_size, self.chunk_size)
+        return reduce(x, 'b (n c) d -> b n d', 'mean', c = chunk_size)
 
 class AttentionPool(Module):
     def __init__(
@@ -111,7 +127,7 @@ class AttentionPool(Module):
         taken from Enformer https://www.nature.com/articles/s41592-021-01252-x , in turn taken from somewhere else
         """
         super().__init__()
-        self.
+        self.chunk_size = chunk_size
         self.to_attn_logits = nn.Linear(dim, dim)
 
         # default to average pool
@@ -121,9 +137,13 @@ class AttentionPool(Module):
 
     def forward(
         self,
-        x
+        x,
+        chunk_size = None
     ):
-
+        chunk_size = default(chunk_size, self.chunk_size)
+
+        x = rearrange(x, 'b (n c) d -> b n c d', c = chunk_size)
+
         attn_logits = self.to_attn_logits(x)
 
         attn = attn_logits.softmax(dim = -2)
@@ -303,6 +323,7 @@ class NeuralMemory(Module):
         per_parameter_lr_modulation = False, # allow outer network to control learning rate per weight matrix of memory network
         max_mem_layer_modulation = 1e1, # max of 10.
         attn_pool_chunks = False,
+        momentum = True,
         pre_rmsnorm = True,
         post_rmsnorm = True,
         learned_mem_model_weights = True,
@@ -394,17 +415,16 @@ class NeuralMemory(Module):
         assert not (attn_pool_chunks and chunk_size == 1), '`attn_pool_chunks` cannot be set to True if `chunk_size` is set to 1'
 
         if not attn_pool_chunks:
-
+            self.reduce_to_chunk_rep = AveragePool(chunk_size = chunk_size)
         else:
-
+            self.reduce_to_chunk_rep = AttentionPool(dim, chunk_size = chunk_size)
 
         # learned adaptive learning rate and momentum
 
         self.to_momentum = Sequential(
-            chunk_reduce_module,
             LinearNoBias(dim, heads),
             Rearrange('b n h -> (b h) n 1')
-        )
+        ) if momentum else None
 
         self.to_adaptive_step = Sequential(
             LinearNoBias(dim, heads),
@@ -419,7 +439,6 @@ class NeuralMemory(Module):
         # per layer learning rate modulation
 
         self.to_layer_modulation = Sequential(
-            chunk_reduce_module,
             LinearNoBias(dim, heads * self.num_memory_parameter_tensors),
             Rearrange('b n (h w) -> w (b h) n', h = heads),
             nn.Sigmoid()
@@ -434,7 +453,6 @@ class NeuralMemory(Module):
         # weight decay factor
 
         self.to_decay_factor = Sequential(
-            chunk_reduce_module,
             LinearNoBias(dim, heads),
             Rearrange('b n h -> (b h) n 1')
         )
@@ -445,12 +463,15 @@ class NeuralMemory(Module):
 
         self.register_buffer('zero', torch.tensor(0.), persistent = False)
 
-    def init_weights_and_momentum(self):
+    def init_weights_and_momentum(self, zero_weights = False):
         params = TensorDict(dict(self.memory_model.named_parameters()))
 
-        init_weights = params
+        init_weights = params
         init_momentum = params.clone().zero_()
 
+        if zero_weights:
+            init_weights = params.clone().zero_()
+
         return init_weights, init_momentum
 
     def init_empty_memory_embed(self, batch, seq_len):
@@ -460,9 +481,10 @@ class NeuralMemory(Module):
         self,
         seq,
         past_state: tuple[dict[str, Tensor], dict[str, Tensor]],
-        return_aux_kv_loss = False
+        return_aux_kv_loss = False,
+        chunk_size = None
     ):
-        seq_len, chunk_size = seq.shape[-2], self.store_chunk_size
+        seq_len, chunk_size = seq.shape[-2], default(chunk_size, self.store_chunk_size)
 
         # handle edge case
 
@@ -479,27 +501,28 @@ class NeuralMemory(Module):
 
         seq = seq[:, :round_down_seq_len]
 
-        #
-
-        curr_weights = TensorDict(dict(self.memory_model.named_parameters()))
+        # get the weights of the memory network
 
         past_state = tuple(TensorDict(d) for d in past_state)
-
-
-        curr_weights = curr_weights + past_weights
+        curr_weights, past_momentum = past_state
 
-        #
+        # derive learned hparams for optimization of memory network
 
         adaptive_lr = self.to_adaptive_step(seq)
         adaptive_lr = self.adaptive_step_transform(adaptive_lr)
 
-
-
+        chunked_seq = self.reduce_to_chunk_rep(seq, chunk_size = chunk_size)
+
+        decay_factor = self.to_decay_factor(chunked_seq).sigmoid()
 
         need_layer_lr_mod = exists(self.to_layer_modulation)
+        has_momentum = exists(self.to_momentum)
+
+        if has_momentum:
+            adaptive_momentum = self.to_momentum(chunked_seq).sigmoid()
 
         if need_layer_lr_mod:
-            layer_lr_mod = self.to_layer_modulation(
+            layer_lr_mod = self.to_layer_modulation(chunked_seq) * self.max_mem_layer_modulation
 
         # keys and values
 
@@ -575,23 +598,29 @@ class NeuralMemory(Module):
 
         # momentum + weight decay - momentum is the new contribution, as most linear RNNs have learned forgetting gates
 
-        next_momentum = TensorDict()
+        next_momentum = TensorDict() if has_momentum else None
         updates = TensorDict()
 
         for param_name, surprise in surprises.items():
 
             surprise, inverse_pack = pack_one_with_inverse(surprise, 'b n *')
 
+            update = surprise
+
             # derive momentum with associative scan - eq (10)
 
-
+            if has_momentum:
+                update = scan_fn(adaptive_momentum, surprise) # momentum is S / surprise in the paper
+                momentum = update
 
             # use associative scan again for learned forgetting (weight decay) - eq (13)
 
-            update = scan_fn(1. - decay_factor,
+            update = scan_fn(1. - decay_factor, update)
 
             updates[param_name] = inverse_pack(update)
-
+
+            if has_momentum:
+                next_momentum[param_name] = inverse_pack(momentum)
 
             # compute the next weight per batch
 
@@ -606,8 +635,9 @@ class NeuralMemory(Module):
         self,
         seq,
         past_weights: dict[str, Tensor] | None = None,
+        chunk_size = None
     ):
-        chunk_size = self.retrieve_chunk_size
+        chunk_size = default(chunk_size, self.retrieve_chunk_size)
         batch, seq_len = seq.shape[:2]
 
         seq = self.retrieve_norm(seq)
@@ -680,7 +710,9 @@ class NeuralMemory(Module):
         seq,
         store_seq = None,
         past_state: tuple[dict[str, Tensor], dict[str, Tensor]] | None = None,
-        return_aux_kv_loss = False
+        return_aux_kv_loss = False,
+        chunk_size = None,
+        store_chunk_size = None
     ):
         batch, seq_len = seq.shape[:2]
 
@@ -699,12 +731,13 @@ class NeuralMemory(Module):
             past_state = self.init_weights_and_momentum()
 
         store_seq = default(store_seq, seq)
+        store_chunk_size = default(store_chunk_size, chunk_size)
 
-        updates, aux_kv_recon_loss = self.store_memories(store_seq, past_state, return_aux_kv_loss = True)
+        updates, aux_kv_recon_loss = self.store_memories(store_seq, past_state, chunk_size = store_chunk_size, return_aux_kv_loss = True)
 
         past_weights, _ = past_state
 
-        retrieved = self.retrieve_memories(seq, past_weights + updates)
+        retrieved = self.retrieve_memories(seq, past_weights + updates, chunk_size = chunk_size)
 
         if not return_aux_kv_loss:
             return retrieved
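The titans.py changes above make the momentum path optional: `to_momentum` is only built when `momentum = True`, and `store_memories` applies the eq. (10) momentum scan only when it exists, otherwise the surprise feeds directly into the weight decay scan of eq. (13). A rough sketch of toggling the new flag, assuming the same constructor arguments used in the updated tests:

```python
# Sketch: constructing NeuralMemory with the new `momentum` flag.
# With momentum = False, the learned momentum projection (to_momentum) is never
# created, so the surprise update skips the eq. (10) scan and only passes
# through the learned weight decay scan.
import torch
from titans_pytorch import NeuralMemory

mem = NeuralMemory(
    dim = 384,
    chunk_size = 64,
    momentum = False,  # disable the momentum contribution
)

seq = torch.randn(2, 1024, 384)  # illustrative shape, divisible by chunk_size
retrieved = mem(seq)             # store_seq defaults to seq

assert retrieved.shape == seq.shape
```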
{titans_pytorch-0.1.11 → titans_pytorch-0.1.14}/train_mac.py

@@ -31,6 +31,7 @@ NUM_PERSIST_MEM = 4
 NUM_LONGTERM_MEM = 4
 NEURAL_MEM_LAYERS = (2, 4)
 NEURAL_MEM_GATE_ATTN_OUTPUT = True
+NEURAL_MEM_MOMENTUM = True
 WINDOW_SIZE = 32
 NEURAL_MEM_SEGMENT_LEN = WINDOW_SIZE // 2 # set smaller for more granularity for learning rate / momentum etc
 SLIDING_WINDOWS = True
@@ -88,6 +89,7 @@ model = MemoryAsContextTransformer(
         dim_head = 64,
         heads = 4,
         attn_pool_chunks = STORE_ATTN_POOL_CHUNKS,
+        momentum = NEURAL_MEM_MOMENTUM,
         use_accelerated_scan = USE_ACCELERATED_SCAN,
         learned_mem_model_weights = LEARNED_MEM_MODEL_WEIGHTS,
         default_model_kwargs = dict(