titans-pytorch 0.2.9__tar.gz → 0.2.11__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {titans_pytorch-0.2.9 → titans_pytorch-0.2.11}/PKG-INFO +1 -1
- {titans_pytorch-0.2.9 → titans_pytorch-0.2.11}/pyproject.toml +1 -1
- {titans_pytorch-0.2.9 → titans_pytorch-0.2.11}/tests/test_titans.py +31 -68
- {titans_pytorch-0.2.9 → titans_pytorch-0.2.11}/titans_pytorch/mac_transformer.py +8 -48
- {titans_pytorch-0.2.9 → titans_pytorch-0.2.11}/titans_pytorch/neural_memory.py +183 -128
- {titans_pytorch-0.2.9 → titans_pytorch-0.2.11}/train_mac.py +8 -12
- {titans_pytorch-0.2.9 → titans_pytorch-0.2.11}/.github/workflows/python-publish.yml +0 -0
- {titans_pytorch-0.2.9 → titans_pytorch-0.2.11}/.github/workflows/test.yaml +0 -0
- {titans_pytorch-0.2.9 → titans_pytorch-0.2.11}/.gitignore +0 -0
- {titans_pytorch-0.2.9 → titans_pytorch-0.2.11}/LICENSE +0 -0
- {titans_pytorch-0.2.9 → titans_pytorch-0.2.11}/README.md +0 -0
- {titans_pytorch-0.2.9 → titans_pytorch-0.2.11}/data/README.md +0 -0
- {titans_pytorch-0.2.9 → titans_pytorch-0.2.11}/data/enwik8.gz +0 -0
- {titans_pytorch-0.2.9 → titans_pytorch-0.2.11}/fig1.png +0 -0
- {titans_pytorch-0.2.9 → titans_pytorch-0.2.11}/fig2.png +0 -0
- {titans_pytorch-0.2.9 → titans_pytorch-0.2.11}/titans_pytorch/__init__.py +0 -0
- {titans_pytorch-0.2.9 → titans_pytorch-0.2.11}/titans_pytorch/associative_scan.py +0 -0
- {titans_pytorch-0.2.9 → titans_pytorch-0.2.11}/titans_pytorch/memory_models.py +0 -0
@@ -73,101 +73,62 @@ def test_titans_attn_memory():
|
|
73
73
|
|
74
74
|
assert seq.shape == retrieved.shape
|
75
75
|
|
76
|
-
def
|
77
|
-
mem
|
76
|
+
def test_neural_mem_chaining_chunks():
|
77
|
+
mem = NeuralMemory(
|
78
78
|
dim = 384,
|
79
|
-
|
79
|
+
dim_head = 64,
|
80
|
+
heads = 2,
|
81
|
+
chunk_size = 16
|
80
82
|
)
|
81
83
|
|
82
|
-
|
83
|
-
store_seq = torch.randn(2, 64 * 32, 384)
|
84
|
+
seq = torch.randn(2, 48, 384)
|
84
85
|
|
85
|
-
|
86
|
+
parallel_retrieved, state = mem(seq)
|
86
87
|
|
87
|
-
|
88
|
+
seq_first, seq_second, seq_third = seq.split(16, dim = 1)
|
88
89
|
|
89
|
-
|
90
|
-
|
90
|
+
first_retrieved, state = mem(seq_first)
|
91
|
+
second_retrieved, state = mem(seq_second, state = state)
|
92
|
+
third_retrieved, state = mem(seq_third, state = state)
|
91
93
|
|
92
|
-
|
94
|
+
assert torch.allclose(parallel_retrieved, torch.cat((first_retrieved, second_retrieved, third_retrieved), dim = 1), atol = 1e-5)
|
93
95
|
|
96
|
+
def test_neural_mem_chaining_with_batch_size():
|
94
97
|
mem = NeuralMemory(
|
95
98
|
dim = 384,
|
96
99
|
dim_head = 64,
|
97
100
|
heads = 2,
|
98
|
-
chunk_size =
|
99
|
-
|
100
|
-
)
|
101
|
-
|
102
|
-
mem2 = NeuralMemory(
|
103
|
-
dim = 384,
|
104
|
-
dim_head = 64,
|
105
|
-
heads = 2,
|
106
|
-
chunk_size = 2,
|
107
|
-
model = mlp
|
108
|
-
)
|
109
|
-
|
110
|
-
mem3 = NeuralMemory(
|
111
|
-
dim = 384,
|
112
|
-
dim_head = 64,
|
113
|
-
heads = 2,
|
114
|
-
chunk_size = 2,
|
115
|
-
model = mlp
|
116
|
-
)
|
117
|
-
|
118
|
-
seq = torch.randn(2, 128, 384)
|
119
|
-
|
120
|
-
seq, cache = mem(seq)
|
121
|
-
seq, cache2 = mem2(seq, prev_layer_updates = cache.updates)
|
122
|
-
seq, cache3 = mem3(seq, prev_layer_updates = cache2.updates)
|
123
|
-
|
124
|
-
def test_mac_with_weight_tied_neural_mem():
|
125
|
-
from titans_pytorch import MemoryMLP, MemoryAsContextTransformer
|
126
|
-
|
127
|
-
transformer = MemoryAsContextTransformer(
|
128
|
-
num_tokens = 256,
|
129
|
-
dim = 256,
|
130
|
-
depth = 2,
|
131
|
-
segment_len = 2,
|
132
|
-
num_persist_mem_tokens = 0,
|
133
|
-
num_longterm_mem_tokens = 2,
|
134
|
-
neural_memory_segment_len = 2,
|
135
|
-
sliding_window_attn = True,
|
136
|
-
neural_memory_layers = (1, 2),
|
137
|
-
neural_memory_model = MemoryMLP(256, depth = 1),
|
138
|
-
num_residual_streams = 4,
|
139
|
-
weight_tie_memory_model = True,
|
140
|
-
neural_mem_gate_attn_output = True,
|
101
|
+
chunk_size = 16,
|
102
|
+
batch_size = 64
|
141
103
|
)
|
142
104
|
|
105
|
+
seq = torch.randn(2, 112, 384)
|
143
106
|
|
144
|
-
|
145
|
-
logits = transformer(ids)
|
107
|
+
parallel_retrieved, state = mem(seq)
|
146
108
|
|
147
|
-
|
109
|
+
seq_first, seq_second, seq_third = seq[:, :16], seq[:, 16:64], seq[:, 64:]
|
148
110
|
|
149
|
-
|
150
|
-
mem =
|
151
|
-
|
152
|
-
chunk_size = 64,
|
153
|
-
)
|
154
|
-
|
155
|
-
seq = torch.randn(2, 128 * 16, 384)
|
156
|
-
store_seq = torch.randn(2, 128 * 8, 384)
|
111
|
+
first_retrieved, state = mem(seq_first)
|
112
|
+
second_retrieved, state = mem(seq_second, state = state)
|
113
|
+
third_retrieved, state = mem(seq_third, state = state)
|
157
114
|
|
158
|
-
|
115
|
+
parallel_part_retrieved = torch.cat((first_retrieved, second_retrieved, third_retrieved), dim = 1)
|
159
116
|
|
160
|
-
assert
|
117
|
+
assert torch.allclose(parallel_retrieved, parallel_part_retrieved, atol = 1e-5)
|
161
118
|
|
162
119
|
@pytest.mark.parametrize('seq_len', (1023, 17))
|
163
120
|
@pytest.mark.parametrize('num_persist_mem_tokens', (0, 16))
|
164
121
|
@pytest.mark.parametrize('num_longterm_mem_tokens', (0, 16))
|
165
122
|
@pytest.mark.parametrize('neural_mem_gate_attn_output', (False, True))
|
123
|
+
@pytest.mark.parametrize('neural_mem_segment_len', (8, 16))
|
124
|
+
@pytest.mark.parametrize('neural_mem_batch_size', (None, 64))
|
166
125
|
def test_mac(
|
167
126
|
seq_len,
|
168
127
|
num_persist_mem_tokens,
|
169
128
|
num_longterm_mem_tokens,
|
170
|
-
neural_mem_gate_attn_output
|
129
|
+
neural_mem_gate_attn_output,
|
130
|
+
neural_mem_segment_len,
|
131
|
+
neural_mem_batch_size
|
171
132
|
):
|
172
133
|
transformer = MemoryAsContextTransformer(
|
173
134
|
num_tokens = 256,
|
@@ -176,7 +137,9 @@ def test_mac(
|
|
176
137
|
num_persist_mem_tokens = num_persist_mem_tokens,
|
177
138
|
num_longterm_mem_tokens = num_longterm_mem_tokens,
|
178
139
|
segment_len = 128,
|
179
|
-
neural_mem_gate_attn_output = neural_mem_gate_attn_output
|
140
|
+
neural_mem_gate_attn_output = neural_mem_gate_attn_output,
|
141
|
+
neural_memory_segment_len = neural_mem_segment_len,
|
142
|
+
neural_memory_batch_size = neural_mem_batch_size,
|
180
143
|
)
|
181
144
|
|
182
145
|
x = torch.randint(0, 256, (1, seq_len))
|
@@ -481,6 +481,7 @@ class MemoryAsContextTransformer(Module):
|
|
481
481
|
neural_memory_add_value_residual = False,
|
482
482
|
num_longterm_mem_tokens = 0,
|
483
483
|
num_persist_mem_tokens = 0,
|
484
|
+
neural_memory_batch_size = None,
|
484
485
|
dim_head = 64,
|
485
486
|
heads = 8,
|
486
487
|
ff_mult = 4,
|
@@ -488,11 +489,8 @@ class MemoryAsContextTransformer(Module):
|
|
488
489
|
neural_memory_model: Module | None = None,
|
489
490
|
neural_memory_kwargs: dict = dict(),
|
490
491
|
neural_memory_layers: tuple[int, ...] | None = None,
|
491
|
-
aux_kv_recon_loss_weight = 1.,
|
492
492
|
use_flex_attn = False,
|
493
493
|
sliding_window_attn = False,
|
494
|
-
weight_tie_memory_model = False,
|
495
|
-
prev_neural_mem_update_for_weights = None
|
496
494
|
):
|
497
495
|
super().__init__()
|
498
496
|
|
@@ -526,16 +524,6 @@ class MemoryAsContextTransformer(Module):
|
|
526
524
|
|
527
525
|
neural_memory_layers = default(neural_memory_layers, layers)
|
528
526
|
|
529
|
-
# weight tying neural memory model
|
530
|
-
|
531
|
-
maybe_copy = deepcopy if not weight_tie_memory_model else identity
|
532
|
-
|
533
|
-
if weight_tie_memory_model:
|
534
|
-
assert exists(neural_memory_model), '`neural_memory_model` must be explicitly set'
|
535
|
-
|
536
|
-
self.weight_tie_memory_model = weight_tie_memory_model
|
537
|
-
self.prev_neural_mem_update_for_weights = default(prev_neural_mem_update_for_weights, weight_tie_memory_model)
|
538
|
-
|
539
527
|
# mem, attn, and feedforward layers
|
540
528
|
|
541
529
|
for layer in layers:
|
@@ -564,7 +552,8 @@ class MemoryAsContextTransformer(Module):
|
|
564
552
|
mem = NeuralMemory(
|
565
553
|
dim = dim,
|
566
554
|
chunk_size = self.neural_memory_segment_len,
|
567
|
-
|
555
|
+
batch_size = neural_memory_batch_size,
|
556
|
+
model = deepcopy(neural_memory_model),
|
568
557
|
**neural_memory_kwargs
|
569
558
|
)
|
570
559
|
|
@@ -585,10 +574,7 @@ class MemoryAsContextTransformer(Module):
|
|
585
574
|
|
586
575
|
self.gate_attn_output = neural_mem_gate_attn_output
|
587
576
|
|
588
|
-
#
|
589
|
-
|
590
|
-
self.has_aux_kv_recon_loss = aux_kv_recon_loss_weight > 0.
|
591
|
-
self.aux_kv_recon_loss_weight = aux_kv_recon_loss_weight
|
577
|
+
# zero for maybe aux loss + device
|
592
578
|
|
593
579
|
self.register_buffer('zero', torch.tensor(0.), persistent = False)
|
594
580
|
|
@@ -696,7 +682,7 @@ class MemoryAsContextTransformer(Module):
|
|
696
682
|
|
697
683
|
# math
|
698
684
|
|
699
|
-
batch, seq_len, neural_mem_segment_len, segment_len, num_longterm_mem_tokens, attn_window_size
|
685
|
+
batch, seq_len, neural_mem_segment_len, segment_len, num_longterm_mem_tokens, attn_window_size = *x.shape, self.neural_memory_segment_len, self.segment_len, self.num_longterm_mem_tokens, self.attn_window_size
|
700
686
|
|
701
687
|
seq_len_with_mem = self.seq_len_with_longterm_mem(seq_len)
|
702
688
|
|
@@ -749,18 +735,10 @@ class MemoryAsContextTransformer(Module):
|
|
749
735
|
next_kv_caches = []
|
750
736
|
next_neural_mem_caches = []
|
751
737
|
|
752
|
-
# weight tied neural memory
|
753
|
-
|
754
|
-
neural_memory_updates = None
|
755
|
-
|
756
738
|
# value residual
|
757
739
|
|
758
740
|
value_residual = None
|
759
741
|
|
760
|
-
# aux losses
|
761
|
-
|
762
|
-
kv_recon_losses = self.zero
|
763
|
-
|
764
742
|
# when inferencing, only do one token at a time
|
765
743
|
|
766
744
|
if is_inferencing:
|
@@ -784,24 +762,16 @@ class MemoryAsContextTransformer(Module):
|
|
784
762
|
mem_input, add_residual = mem_hyper_conn(x)
|
785
763
|
|
786
764
|
if not is_inferencing:
|
787
|
-
|
788
|
-
mem_input
|
789
|
-
return_aux_kv_loss = True,
|
790
|
-
prev_layer_updates = neural_memory_updates
|
765
|
+
retrieved, next_neural_mem_cache = mem(
|
766
|
+
mem_input
|
791
767
|
)
|
792
768
|
|
793
|
-
kv_recon_losses = kv_recon_losses + mem_kv_aux_loss
|
794
|
-
|
795
769
|
else:
|
796
770
|
(retrieved, next_neural_mem_cache) = mem.forward_inference(
|
797
771
|
mem_input,
|
798
772
|
state = next(neural_mem_caches, None),
|
799
|
-
prev_layer_updates = neural_memory_updates
|
800
773
|
)
|
801
774
|
|
802
|
-
if prev_neural_mem_update_for_weights:
|
803
|
-
neural_memory_updates = next_neural_mem_cache.updates
|
804
|
-
|
805
775
|
if self.gate_attn_output:
|
806
776
|
attn_out_gates = retrieved.sigmoid()
|
807
777
|
else:
|
@@ -883,14 +853,4 @@ class MemoryAsContextTransformer(Module):
|
|
883
853
|
|
884
854
|
return logits, next_cache
|
885
855
|
|
886
|
-
|
887
|
-
|
888
|
-
losses = ar_loss
|
889
|
-
|
890
|
-
if self.has_aux_kv_recon_loss:
|
891
|
-
losses = losses + kv_recon_losses * self.aux_kv_recon_loss_weight
|
892
|
-
|
893
|
-
if not return_loss_breakdown:
|
894
|
-
return losses
|
895
|
-
|
896
|
-
return losses, (ar_loss, kv_recon_losses)
|
856
|
+
return F.cross_entropy(rearrange(logits, 'b n l -> b l n'), labels)
|
@@ -6,7 +6,7 @@ from functools import partial
|
|
6
6
|
from collections import namedtuple
|
7
7
|
|
8
8
|
import torch
|
9
|
-
from torch import nn, cat, Tensor
|
9
|
+
from torch import nn, cat, tensor, Tensor
|
10
10
|
import torch.nn.functional as F
|
11
11
|
from torch.nn import Linear, Module, Parameter, ParameterList
|
12
12
|
from torch.func import functional_call, vmap, grad
|
@@ -38,7 +38,13 @@ w - num memory network weight parameters
|
|
38
38
|
|
39
39
|
LinearNoBias = partial(Linear, bias = False)
|
40
40
|
|
41
|
-
NeuralMemCache = namedtuple('NeuralMemCache', [
|
41
|
+
NeuralMemCache = namedtuple('NeuralMemCache', [
|
42
|
+
'seq_index',
|
43
|
+
'weights',
|
44
|
+
'cache_store_segment',
|
45
|
+
'states',
|
46
|
+
'updates',
|
47
|
+
])
|
42
48
|
|
43
49
|
# functions
|
44
50
|
|
@@ -57,6 +63,9 @@ def identity(t):
|
|
57
63
|
def xnor(x, y):
|
58
64
|
return not (x ^ y)
|
59
65
|
|
66
|
+
def divisible_by(num, den):
|
67
|
+
return (num % den) == 0
|
68
|
+
|
60
69
|
def safe_cat(inputs, dim = -2):
|
61
70
|
inputs = tuple(filter(exists, inputs))
|
62
71
|
|
@@ -67,9 +76,18 @@ def safe_cat(inputs, dim = -2):
|
|
67
76
|
|
68
77
|
return cat(inputs, dim = dim)
|
69
78
|
|
79
|
+
def is_empty_tensor(t):
|
80
|
+
return t.numel() == 0
|
81
|
+
|
70
82
|
def dict_get_shape(td):
|
71
83
|
return {k: v.shape for k, v in td.items()}
|
72
84
|
|
85
|
+
def rearrange_dict_values(td, pattern, **kwargs):
|
86
|
+
return td.apply(lambda t: rearrange(t, pattern, **kwargs))
|
87
|
+
|
88
|
+
def repeat_dict_values(td, pattern, **kwargs):
|
89
|
+
return td.apply(lambda t: repeat(t, pattern, **kwargs))
|
90
|
+
|
73
91
|
def pair(v):
|
74
92
|
return (v, v) if not isinstance(v, tuple) else v
|
75
93
|
|
@@ -106,6 +124,9 @@ def softclamp_max(t, max_value):
|
|
106
124
|
return ((t / half_max_value).tanh() * half_max_value) + half_max_value
|
107
125
|
|
108
126
|
def softclamp_grad_norm(t, max_value):
|
127
|
+
if is_empty_tensor(t):
|
128
|
+
return t
|
129
|
+
|
109
130
|
t, inverse = pack_one_with_inverse(t, 'bn *')
|
110
131
|
|
111
132
|
norm = t.norm(dim = -1, keepdim = True)
|
@@ -195,6 +216,12 @@ class AssocScan(Module):
|
|
195
216
|
):
|
196
217
|
remove_prev = default(remove_prev, exists(prev))
|
197
218
|
|
219
|
+
inputs, inverse_pack_weight_shape = pack_one_with_inverse(inputs, 'b n *')
|
220
|
+
gates, _ = pack_one_with_inverse(gates, 'b n *')
|
221
|
+
|
222
|
+
if exists(prev):
|
223
|
+
prev, _ = pack_one_with_inverse(prev, 'b *')
|
224
|
+
|
198
225
|
if exists(prev):
|
199
226
|
inputs, _ = pack([prev, inputs], 'b * d')
|
200
227
|
gates = pad_at_dim(gates, (1, 0), value = 1., dim = -2)
|
@@ -205,7 +232,7 @@ class AssocScan(Module):
|
|
205
232
|
if remove_prev:
|
206
233
|
out = out[:, 1:]
|
207
234
|
|
208
|
-
return out
|
235
|
+
return inverse_pack_weight_shape(out)
|
209
236
|
|
210
237
|
from accelerated_scan.triton import scan as triton_scan
|
211
238
|
from accelerated_scan.warp import scan as warp_scan
|
@@ -226,6 +253,7 @@ class AssocScan(Module):
|
|
226
253
|
|
227
254
|
outputs = outputs[..., :seq_len]
|
228
255
|
outputs = rearrange(outputs, 'b d n -> b n d')
|
256
|
+
|
229
257
|
return outputs
|
230
258
|
|
231
259
|
out = accelerate_scan_fn(gates, inputs)
|
@@ -233,7 +261,7 @@ class AssocScan(Module):
|
|
233
261
|
if remove_prev:
|
234
262
|
out = out[:, 1:]
|
235
263
|
|
236
|
-
return out
|
264
|
+
return inverse_pack_weight_shape(out)
|
237
265
|
|
238
266
|
# main neural memory
|
239
267
|
|
@@ -248,12 +276,13 @@ class NeuralMemory(Module):
|
|
248
276
|
self,
|
249
277
|
dim,
|
250
278
|
chunk_size: int | tuple[int, int] = 1,
|
279
|
+
batch_size = None,
|
251
280
|
dim_head = None,
|
252
281
|
heads = 1,
|
253
282
|
model: Module | None = None,
|
254
283
|
store_memory_loss_fn: Callable = default_loss_fn,
|
255
284
|
adaptive_step_transform: Callable | None = None,
|
256
|
-
default_step_transform_max_lr =
|
285
|
+
default_step_transform_max_lr = 1.,
|
257
286
|
per_parameter_lr_modulation = False, # allow outer network to control learning rate per weight matrix of memory network
|
258
287
|
max_mem_layer_modulation = 1e1, # max of 10.
|
259
288
|
attn_pool_chunks = False,
|
@@ -274,6 +303,13 @@ class NeuralMemory(Module):
|
|
274
303
|
|
275
304
|
self.retrieve_chunk_size, self.store_chunk_size = pair(chunk_size)
|
276
305
|
|
306
|
+
# batch size
|
307
|
+
|
308
|
+
if exists(batch_size):
|
309
|
+
assert divisible_by(batch_size, self.store_chunk_size)
|
310
|
+
|
311
|
+
self.batch_size = batch_size
|
312
|
+
|
277
313
|
# associative scan
|
278
314
|
|
279
315
|
self.assoc_scan = AssocScan(use_accelerated = use_accelerated_scan)
|
@@ -342,14 +378,13 @@ class NeuralMemory(Module):
|
|
342
378
|
pred = functional_call(self.memory_model, params, inputs)
|
343
379
|
loss = self.store_memory_loss_fn(pred, target) # simple mse loss in paper - eq (12) - |M(k) - v|²
|
344
380
|
weighted_loss = loss * loss_weights
|
345
|
-
return weighted_loss.sum()
|
381
|
+
return weighted_loss.sum()
|
346
382
|
|
347
383
|
# two functions
|
348
384
|
|
349
|
-
grad_fn = grad(forward_and_loss
|
385
|
+
grad_fn = grad(forward_and_loss)
|
350
386
|
|
351
|
-
self.per_sample_grad_fn = vmap(grad_fn, in_dims = (
|
352
|
-
self.per_sample_grad_fn_expanded_weights = vmap(grad_fn, in_dims = (0,) * 4)
|
387
|
+
self.per_sample_grad_fn = vmap(grad_fn, in_dims = (0, 0, 0, 0))
|
353
388
|
|
354
389
|
# queries for retrieving from the model
|
355
390
|
|
@@ -417,56 +452,58 @@ class NeuralMemory(Module):
|
|
417
452
|
|
418
453
|
self.register_buffer('zero', torch.tensor(0.), persistent = False)
|
419
454
|
|
420
|
-
def init_weights(
|
455
|
+
def init_weights(
|
456
|
+
self,
|
457
|
+
batch,
|
458
|
+
):
|
421
459
|
weights = TensorDict(dict(self.memory_model.named_parameters()))
|
460
|
+
weights = repeat_dict_values(weights, '... -> bh ...', bh = batch * self.heads)
|
422
461
|
return weights
|
423
462
|
|
463
|
+
def init_momentum(
|
464
|
+
self,
|
465
|
+
batch,
|
466
|
+
):
|
467
|
+
weights = TensorDict(dict(self.memory_model.named_parameters()))
|
468
|
+
zeros = weights.clone().zero_()
|
469
|
+
zeros = repeat_dict_values(zeros, '... -> bh ...', bh = batch * self.heads)
|
470
|
+
return zeros
|
471
|
+
|
424
472
|
def store_memories(
|
425
473
|
self,
|
426
474
|
seq,
|
427
|
-
weights: dict[str, Tensor],
|
475
|
+
weights: dict[str, Tensor] | None = None,
|
428
476
|
past_state: tuple[dict[str, Tensor], dict[str, Tensor]] | None = None,
|
429
|
-
|
430
|
-
return_aux_kv_loss = False,
|
431
|
-
chunk_size = None,
|
477
|
+
seq_index = 0
|
432
478
|
):
|
433
|
-
seq_len, heads, chunk_size = seq.shape[
|
434
|
-
|
435
|
-
# handle edge case
|
436
|
-
|
437
|
-
if seq_len < chunk_size:
|
438
|
-
return TensorDict(weights).clone().zero_(), self.zero
|
439
|
-
|
440
|
-
seq = self.store_norm(seq)
|
479
|
+
batch, seq_len, heads, chunk_size = *seq.shape[:2], self.heads, self.store_chunk_size
|
441
480
|
|
442
481
|
# curtail sequence by multiple of the chunk size
|
443
482
|
# only a complete chunk of the sequence provides the memory for the next chunk
|
444
483
|
|
445
484
|
round_down_seq_len = round_down_multiple(seq_len, chunk_size)
|
485
|
+
num_chunks = round_down_seq_len // chunk_size
|
446
486
|
|
447
487
|
seq, remainder = seq[:, :round_down_seq_len], seq[:, round_down_seq_len:]
|
448
488
|
|
449
|
-
|
450
|
-
|
451
|
-
per_sample_grad_fn = self.per_sample_grad_fn
|
489
|
+
next_seq_len_index = seq_index + round_down_seq_len
|
452
490
|
|
491
|
+
# init weights if needed
|
453
492
|
# weights of the memory network
|
454
493
|
|
494
|
+
if not exists(weights):
|
495
|
+
weights = self.init_weights(batch)
|
496
|
+
|
455
497
|
weights = TensorDict(weights)
|
456
498
|
|
457
499
|
# allow for neural memory of a previous layer to influence surprise of current layer
|
458
500
|
|
459
|
-
weights_for_surprise = weights
|
460
|
-
|
461
|
-
if exists(prev_layer_updates):
|
462
|
-
prev_layer_updates = TensorDict(prev_layer_updates)
|
463
|
-
|
464
|
-
weights_for_surprise = weights_for_surprise + prev_layer_updates
|
465
|
-
|
466
|
-
per_sample_grad_fn = self.per_sample_grad_fn_expanded_weights # the weights will now have a batch * chunk dimension
|
501
|
+
weights_for_surprise = repeat_dict_values(weights, 'b ... -> b n ...', n = num_chunks)
|
467
502
|
|
468
503
|
# derive learned hparams for optimization of memory network
|
469
504
|
|
505
|
+
seq = self.store_norm(seq)
|
506
|
+
|
470
507
|
adaptive_lr = self.to_adaptive_step(seq)
|
471
508
|
adaptive_lr = self.adaptive_step_transform(adaptive_lr)
|
472
509
|
|
@@ -474,7 +511,7 @@ class NeuralMemory(Module):
|
|
474
511
|
|
475
512
|
decay_factor = self.to_decay_factor(chunked_seq).sigmoid()
|
476
513
|
|
477
|
-
need_layer_lr_mod = exists(self.to_layer_modulation)
|
514
|
+
need_layer_lr_mod = exists(self.to_layer_modulation) and num_chunks > 0
|
478
515
|
has_momentum = exists(self.to_momentum)
|
479
516
|
|
480
517
|
if has_momentum:
|
@@ -505,12 +542,11 @@ class NeuralMemory(Module):
|
|
505
542
|
|
506
543
|
# flatten batch and time if surprise depends on previous layer memory model
|
507
544
|
|
508
|
-
|
509
|
-
weights_for_surprise = weights_for_surprise.apply(lambda t: rearrange(t, 'b n ... -> (b n) ...'))
|
545
|
+
weights_for_surprise = rearrange_dict_values(weights_for_surprise, 'b n ... -> (b n) ...')
|
510
546
|
|
511
547
|
# get grads and extra auxiliary loss (for backwarding through qkv projection in base neural memory module)
|
512
548
|
|
513
|
-
grads
|
549
|
+
grads = self.per_sample_grad_fn(dict(weights_for_surprise), keys, adaptive_lr, values)
|
514
550
|
|
515
551
|
grads = TensorDict(grads)
|
516
552
|
|
@@ -521,7 +557,7 @@ class NeuralMemory(Module):
|
|
521
557
|
|
522
558
|
# restore batch and sequence dimension
|
523
559
|
|
524
|
-
grads = grads
|
560
|
+
grads = rearrange_dict_values(grads, '(b n) ... -> b n ...', b = batch * heads)
|
525
561
|
|
526
562
|
# maybe per layer modulation
|
527
563
|
|
@@ -535,19 +571,25 @@ class NeuralMemory(Module):
|
|
535
571
|
# past states
|
536
572
|
|
537
573
|
if not exists(past_state):
|
538
|
-
empty_dict = {key: None for key in weights.keys()}
|
539
|
-
|
540
574
|
# minibatch_init_weight corresponds to W0 in figure 7 of TTT paper
|
541
575
|
|
542
576
|
minibatch_init_weight = weights
|
577
|
+
init_momentum = self.init_momentum(batch)
|
543
578
|
|
544
|
-
|
545
|
-
minibatch_init_weight = weights.apply(lambda t: repeat(t, '... -> b 1 (...)', b = batch * heads))
|
546
|
-
|
547
|
-
past_state = (minibatch_init_weight, empty_dict)
|
579
|
+
past_state = (minibatch_init_weight, init_momentum)
|
548
580
|
|
549
581
|
past_last_update, past_last_momentum = past_state
|
550
582
|
|
583
|
+
# early return if sequence length less than chunk size
|
584
|
+
|
585
|
+
if num_chunks == 0:
|
586
|
+
updates = rearrange_dict_values(weights, 'bh ... -> bh 1 ...')
|
587
|
+
next_store_state = NeuralMemCache(next_seq_len_index, weights, remainder, past_state, updates)
|
588
|
+
|
589
|
+
output = (updates, next_store_state)
|
590
|
+
|
591
|
+
return output
|
592
|
+
|
551
593
|
# momentum + weight decay - momentum is the new contribution, as most linear RNNs have learned forgetting gates
|
552
594
|
|
553
595
|
next_momentum = TensorDict() if has_momentum else None
|
@@ -558,8 +600,6 @@ class NeuralMemory(Module):
|
|
558
600
|
|
559
601
|
for (param_name, surprise), (_, last_update), (_, last_momentum) in zip(surprises.items(), past_last_update.items(), past_last_momentum.items()):
|
560
602
|
|
561
|
-
surprise, inverse_pack = pack_one_with_inverse(surprise, 'b n *')
|
562
|
-
|
563
603
|
update = surprise
|
564
604
|
|
565
605
|
# derive momentum with associative scan - eq (10)
|
@@ -571,62 +611,51 @@ class NeuralMemory(Module):
|
|
571
611
|
|
572
612
|
# use associative scan again for learned forgetting (weight decay) - eq (13)
|
573
613
|
|
574
|
-
update = self.assoc_scan(1. - decay_factor, update, prev = last_update)
|
614
|
+
update = self.assoc_scan(1. - decay_factor, update, prev = last_update, remove_prev = False)
|
575
615
|
next_last_update[param_name] = update[:, -1]
|
576
616
|
|
577
|
-
updates[param_name] =
|
617
|
+
updates[param_name] = update
|
578
618
|
|
579
619
|
if has_momentum:
|
580
|
-
next_momentum[param_name] =
|
620
|
+
next_momentum[param_name] = momentum
|
581
621
|
|
582
622
|
# determine next state for the storing of memories
|
583
623
|
|
584
624
|
next_state = (next_last_update, next_last_momentum)
|
585
625
|
|
586
|
-
next_store_state = NeuralMemCache(
|
626
|
+
next_store_state = NeuralMemCache(next_seq_len_index, weights, remainder, next_state, updates)
|
587
627
|
|
588
628
|
# returns
|
589
629
|
|
590
630
|
output = (updates, next_store_state)
|
591
631
|
|
592
|
-
|
593
|
-
return output
|
594
|
-
|
595
|
-
return output, aux_kv_recon_loss.mean()
|
632
|
+
return output
|
596
633
|
|
597
634
|
def retrieve_memories(
|
598
635
|
self,
|
599
636
|
seq,
|
600
637
|
past_weights: dict[str, Tensor],
|
601
|
-
chunk_size = None,
|
602
|
-
prev_layer_updates: dict[str, Tensor] | None = None
|
603
638
|
):
|
604
|
-
chunk_size =
|
639
|
+
chunk_size = self.retrieve_chunk_size
|
605
640
|
batch, seq_len = seq.shape[:2]
|
606
641
|
|
607
642
|
seq = self.retrieve_norm(seq)
|
608
643
|
|
609
|
-
assert seq_len >= chunk_size, 'must be handled outside of retrieve'
|
610
|
-
|
611
644
|
needs_pad = chunk_size > 1
|
612
645
|
|
613
|
-
|
614
|
-
|
615
|
-
seq_len_plus_one = seq.shape[-2]
|
646
|
+
seq = pad_at_dim(seq, (1, 0), dim = 1)
|
647
|
+
seq_len_plus_one = seq.shape[-2]
|
616
648
|
|
617
|
-
|
649
|
+
next_seq_len = round_up_multiple(seq_len_plus_one, chunk_size)
|
618
650
|
|
619
|
-
|
620
|
-
|
651
|
+
padding = next_seq_len - seq_len_plus_one
|
652
|
+
seq = pad_at_dim(seq, (0, padding), dim = 1)
|
621
653
|
|
622
654
|
# the parameters of the memory model stores the memories of the key / values
|
623
655
|
# when the MLP has only 1 weight matrix, it is equivalent to `kv` fast weight memories from linear attention literature (recall fetching of memories is q @ (kv)) / schmidhuber's paper
|
624
656
|
|
625
657
|
curr_weights = TensorDict(past_weights)
|
626
658
|
|
627
|
-
if exists(prev_layer_updates):
|
628
|
-
curr_weights = curr_weights + TensorDict(prev_layer_updates)
|
629
|
-
|
630
659
|
# sequence Float['b n d'] to queries
|
631
660
|
|
632
661
|
queries = self.to_queries(seq)
|
@@ -642,7 +671,7 @@ class NeuralMemory(Module):
|
|
642
671
|
# fetch values from memory model
|
643
672
|
|
644
673
|
if dict_get_shape(curr_weights) != self.init_weight_shape:
|
645
|
-
curr_weights = curr_weights
|
674
|
+
curr_weights = rearrange_dict_values(curr_weights, 'b n ... -> (b n) ...')
|
646
675
|
|
647
676
|
queries = rearrange(queries, 'b h (n c) d -> (b h n) c d', c = chunk_size)
|
648
677
|
|
@@ -669,8 +698,7 @@ class NeuralMemory(Module):
|
|
669
698
|
|
670
699
|
# restore, pad with empty memory embed
|
671
700
|
|
672
|
-
|
673
|
-
values = values[:, 1:(seq_len + 1)]
|
701
|
+
values = values[:, 1:(seq_len + 1)]
|
674
702
|
|
675
703
|
return values
|
676
704
|
|
@@ -678,16 +706,14 @@ class NeuralMemory(Module):
|
|
678
706
|
def forward_inference(
|
679
707
|
self,
|
680
708
|
token: Tensor,
|
681
|
-
state = None,
|
682
|
-
prev_layer_updates: dict[str, Tensor] | None = None,
|
709
|
+
state: NeuralMemCache | None = None,
|
683
710
|
):
|
684
|
-
|
685
711
|
# unpack previous state
|
686
712
|
|
687
713
|
if not exists(state):
|
688
|
-
state = (0, None, None, None)
|
714
|
+
state = (0, None, None, None, None)
|
689
715
|
|
690
|
-
seq_index, cache_store_seq, past_states, updates = state
|
716
|
+
seq_index, weights, cache_store_seq, past_states, updates = state
|
691
717
|
|
692
718
|
curr_seq_len = seq_index + 1
|
693
719
|
batch = token.shape[0]
|
@@ -695,9 +721,7 @@ class NeuralMemory(Module):
|
|
695
721
|
if token.ndim == 2:
|
696
722
|
token = rearrange(token, 'b d -> b 1 d')
|
697
723
|
|
698
|
-
|
699
|
-
|
700
|
-
weights = self.init_weights()
|
724
|
+
assert token.shape[1] == 1
|
701
725
|
|
702
726
|
# increment the sequence cache which is at most the chunk size
|
703
727
|
|
@@ -708,7 +732,7 @@ class NeuralMemory(Module):
|
|
708
732
|
if curr_seq_len < self.chunk_size:
|
709
733
|
retrieve = self.retrieve_memories(token, weights, chunk_size = 1)
|
710
734
|
|
711
|
-
output = retrieve, NeuralMemCache(curr_seq_len, cache_store_seq, past_states, updates)
|
735
|
+
output = retrieve, NeuralMemCache(curr_seq_len, weights, cache_store_seq, past_states, updates)
|
712
736
|
|
713
737
|
return output
|
714
738
|
|
@@ -719,21 +743,16 @@ class NeuralMemory(Module):
|
|
719
743
|
|
720
744
|
if not exists(updates):
|
721
745
|
updates = weights.clone().zero_()
|
722
|
-
updates = updates
|
746
|
+
updates = repeat_dict_values(updates, '... -> b 1 ...', b = batch)
|
723
747
|
else:
|
724
748
|
updates = updates.apply(lambda t: t[:, -1:])
|
725
749
|
|
726
|
-
if exists(prev_layer_updates):
|
727
|
-
prev_layer_updates = TensorDict(prev_layer_updates)
|
728
|
-
prev_layer_updates = prev_layer_updates.apply(lambda t: t[:, -1:])
|
729
|
-
|
730
750
|
if store_seq_cache_len == self.chunk_size:
|
731
751
|
|
732
752
|
next_updates, store_state = self.store_memories(
|
733
753
|
cache_store_seq,
|
734
754
|
weights,
|
735
755
|
past_state = past_states,
|
736
|
-
prev_layer_updates = prev_layer_updates,
|
737
756
|
)
|
738
757
|
|
739
758
|
updates = next_updates
|
@@ -746,7 +765,7 @@ class NeuralMemory(Module):
|
|
746
765
|
|
747
766
|
# next state tuple
|
748
767
|
|
749
|
-
next_store_state = NeuralMemCache(curr_seq_len, cache_store_seq, next_states, updates)
|
768
|
+
next_store_state = NeuralMemCache(curr_seq_len, weights, cache_store_seq, next_states, updates)
|
750
769
|
|
751
770
|
return retrieved, next_store_state
|
752
771
|
|
@@ -754,63 +773,99 @@ class NeuralMemory(Module):
|
|
754
773
|
self,
|
755
774
|
seq,
|
756
775
|
store_seq = None,
|
757
|
-
|
758
|
-
past_state: tuple[dict[str, Tensor], dict[str, Tensor]] | None = None,
|
759
|
-
return_aux_kv_loss = False,
|
760
|
-
chunk_size = None,
|
761
|
-
store_chunk_size = None,
|
762
|
-
return_next_state = False,
|
763
|
-
prev_layer_updates: dict[str, Tensor] | None = None
|
776
|
+
state: NeuralMemCache | None = None,
|
764
777
|
):
|
765
|
-
|
778
|
+
if not exists(state):
|
779
|
+
state = (0, None, None, None, None)
|
766
780
|
|
767
|
-
|
768
|
-
mem_model_weights = self.init_weights()
|
781
|
+
seq_index, weights, cache_store_seq, past_state, updates = state
|
769
782
|
|
770
|
-
|
771
|
-
retrieved = self.retrieve_memories(seq, mem_model_weights, chunk_size = 1)
|
783
|
+
assert not exists(cache_store_seq) or is_empty_tensor(cache_store_seq)
|
772
784
|
|
773
|
-
|
785
|
+
# store
|
774
786
|
|
775
|
-
|
787
|
+
store_seq = default(store_seq, seq)
|
776
788
|
|
777
|
-
|
778
|
-
return out
|
789
|
+
# functions
|
779
790
|
|
780
|
-
|
791
|
+
# compute split sizes of sequence
|
792
|
+
# for now manually update weights to last update at the correct boundaries
|
781
793
|
|
782
|
-
|
794
|
+
store_seq_len, chunk_size, batch_size = store_seq.shape[-2], self.chunk_size, self.batch_size
|
783
795
|
|
784
|
-
|
796
|
+
need_update_weights = exists(batch_size)
|
785
797
|
|
786
|
-
|
787
|
-
store_seq,
|
788
|
-
mem_model_weights,
|
789
|
-
chunk_size = store_chunk_size,
|
790
|
-
prev_layer_updates = prev_layer_updates,
|
791
|
-
return_aux_kv_loss = True
|
792
|
-
)
|
798
|
+
# determine split sizes and when to update
|
793
799
|
|
794
|
-
|
800
|
+
if need_update_weights:
|
801
|
+
update_after_final_store = divisible_by(seq_index + store_seq_len, batch_size)
|
802
|
+
|
803
|
+
seq_range = torch.arange(store_seq_len) + seq_index + 1
|
804
|
+
batch_boundary = divisible_by(seq_range, batch_size)
|
805
|
+
|
806
|
+
indices = seq_range[batch_boundary] - seq_index
|
807
|
+
|
808
|
+
indices = F.pad(indices, (1, 0), value = 0)
|
809
|
+
|
810
|
+
if indices[-1] != store_seq_len:
|
811
|
+
indices = F.pad(indices, (0, 1), value = store_seq_len)
|
812
|
+
|
813
|
+
split_sizes = (indices[1:] - indices[:-1]).tolist()
|
814
|
+
|
815
|
+
assert sum(split_sizes) == store_seq_len
|
816
|
+
else:
|
817
|
+
split_sizes = (store_seq_len,)
|
818
|
+
update_after_final_store = False
|
795
819
|
|
796
|
-
|
820
|
+
# accumulate updates
|
797
821
|
|
798
|
-
|
799
|
-
if exists(prev_layer_updates):
|
800
|
-
prev_layer_updates = prev_layer_updates.apply(lambda t: pad_at_dim(t, (1, 0), dim = 1))
|
822
|
+
updates = None
|
801
823
|
|
802
|
-
|
824
|
+
def accum_updates(past_updates, future_updates):
|
825
|
+
if not exists(past_updates):
|
826
|
+
return future_updates
|
827
|
+
|
828
|
+
return TensorDict({param_name: cat((past_update[:, :-1], future_update), dim = 1) for (param_name, past_update), (_, future_update) in zip(past_updates.items(), future_updates.items())})
|
829
|
+
|
830
|
+
# loop through chunks of store sequences
|
831
|
+
|
832
|
+
store_seqs = store_seq.split(split_sizes, dim = -2)
|
833
|
+
|
834
|
+
for ind, store_seq_chunk in enumerate(store_seqs):
|
835
|
+
is_last = ind == (len(store_seqs) - 1)
|
836
|
+
|
837
|
+
# store
|
838
|
+
|
839
|
+
next_updates, next_neural_mem_state = self.store_memories(
|
840
|
+
store_seq_chunk,
|
841
|
+
weights,
|
842
|
+
seq_index = seq_index,
|
843
|
+
past_state = past_state,
|
844
|
+
)
|
845
|
+
|
846
|
+
seq_index = next_neural_mem_state.seq_index
|
847
|
+
past_state = next_neural_mem_state.states
|
848
|
+
|
849
|
+
updates = accum_updates(updates, next_updates)
|
850
|
+
|
851
|
+
if is_last and not update_after_final_store:
|
852
|
+
continue
|
853
|
+
|
854
|
+
# update weights once batch size is fulfilled
|
855
|
+
|
856
|
+
last_update, _ = past_state
|
857
|
+
|
858
|
+
weights = last_update
|
859
|
+
|
860
|
+
next_neural_mem_state = list(next_neural_mem_state)
|
861
|
+
next_neural_mem_state[1] = last_update
|
862
|
+
next_neural_mem_state = NeuralMemCache(*next_neural_mem_state)
|
863
|
+
|
864
|
+
# retrieve
|
803
865
|
|
804
866
|
retrieved = self.retrieve_memories(
|
805
867
|
seq,
|
806
|
-
updates
|
807
|
-
chunk_size = chunk_size,
|
808
|
-
prev_layer_updates = prev_layer_updates
|
868
|
+
updates
|
809
869
|
)
|
810
870
|
|
811
|
-
|
812
|
-
|
813
|
-
if not return_aux_kv_loss:
|
814
|
-
return output
|
815
|
-
|
816
|
-
return output, aux_kv_recon_loss
|
871
|
+
return retrieved, next_neural_mem_state
|
@@ -35,13 +35,11 @@ NEURAL_MEM_GATE_ATTN_OUTPUT = False
|
|
35
35
|
NEURAL_MEM_MOMENTUM = True
|
36
36
|
NEURAL_MEM_QK_NORM = True
|
37
37
|
WINDOW_SIZE = 32
|
38
|
-
NEURAL_MEM_SEGMENT_LEN =
|
38
|
+
NEURAL_MEM_SEGMENT_LEN = 2 # set smaller for more granularity for learning rate / momentum etc
|
39
|
+
NEURAL_MEM_BATCH_SIZE = 128 # set smaller to update the neural memory weights more often as it traverses the sequence
|
39
40
|
SLIDING_WINDOWS = True
|
40
|
-
WEIGHT_TIE_MEMORY_MODEL = False # set to have memory MLP shared across layers
|
41
|
-
PREV_MEM_UPDATE_FOR_WEIGHTS = True,
|
42
41
|
STORE_ATTN_POOL_CHUNKS = True # whether to use attention pooling for chunk derived momentum, per-layer lr mod, decay
|
43
42
|
MEMORY_MODEL_PER_LAYER_LEARNED_LR = True
|
44
|
-
KV_RECON_LOSS_WEIGHT = 1.
|
45
43
|
|
46
44
|
# experiment related
|
47
45
|
|
@@ -86,12 +84,10 @@ model = MemoryAsContextTransformer(
|
|
86
84
|
num_longterm_mem_tokens = NUM_LONGTERM_MEM,
|
87
85
|
neural_memory_layers = NEURAL_MEM_LAYERS,
|
88
86
|
neural_memory_segment_len = NEURAL_MEM_SEGMENT_LEN,
|
87
|
+
neural_memory_batch_size = NEURAL_MEM_BATCH_SIZE,
|
89
88
|
neural_mem_gate_attn_output = NEURAL_MEM_GATE_ATTN_OUTPUT,
|
90
|
-
aux_kv_recon_loss_weight = KV_RECON_LOSS_WEIGHT,
|
91
89
|
use_flex_attn = USE_FLEX_ATTN,
|
92
90
|
sliding_window_attn = SLIDING_WINDOWS,
|
93
|
-
weight_tie_memory_model = WEIGHT_TIE_MEMORY_MODEL,
|
94
|
-
prev_neural_mem_update_for_weights = PREV_MEM_UPDATE_FOR_WEIGHTS,
|
95
91
|
neural_memory_model = MemoryMLP(
|
96
92
|
dim = 64,
|
97
93
|
depth = NEURAL_MEMORY_DEPTH
|
@@ -143,20 +139,20 @@ for i in tqdm.tqdm(range(NUM_BATCHES), mininterval = 10., desc = 'training'):
|
|
143
139
|
model.train()
|
144
140
|
|
145
141
|
for __ in range(GRADIENT_ACCUMULATE_EVERY):
|
146
|
-
loss
|
142
|
+
loss = model(next(train_loader), return_loss = True)
|
147
143
|
loss.backward()
|
148
144
|
|
149
|
-
print(f'training loss: {
|
145
|
+
print(f'training loss: {loss.item()}')
|
150
146
|
torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
|
151
147
|
optim.step()
|
152
148
|
optim.zero_grad()
|
153
|
-
wandb.log(dict(loss =
|
149
|
+
wandb.log(dict(loss = loss.item()))
|
154
150
|
|
155
151
|
if i % VALIDATE_EVERY == 0:
|
156
152
|
model.eval()
|
157
153
|
with torch.no_grad():
|
158
|
-
loss
|
159
|
-
print(f'validation loss: {
|
154
|
+
loss = model(next(val_loader), return_loss = True)
|
155
|
+
print(f'validation loss: {loss.item()}')
|
160
156
|
|
161
157
|
if SHOULD_GENERATE and i % GENERATE_EVERY == 0:
|
162
158
|
model.eval()
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|