titans-pytorch 0.3.2__tar.gz → 0.3.4__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {titans_pytorch-0.3.2 → titans_pytorch-0.3.4}/PKG-INFO +1 -1
- {titans_pytorch-0.3.2 → titans_pytorch-0.3.4}/pyproject.toml +1 -1
- {titans_pytorch-0.3.2 → titans_pytorch-0.3.4}/tests/test_titans.py +29 -26
- {titans_pytorch-0.3.2 → titans_pytorch-0.3.4}/titans_pytorch/associative_scan.py +87 -0
- {titans_pytorch-0.3.2 → titans_pytorch-0.3.4}/titans_pytorch/memory_models.py +1 -1
- {titans_pytorch-0.3.2 → titans_pytorch-0.3.4}/titans_pytorch/neural_memory.py +66 -101
- {titans_pytorch-0.3.2 → titans_pytorch-0.3.4}/train_mac.py +19 -5
- {titans_pytorch-0.3.2 → titans_pytorch-0.3.4}/.github/workflows/python-publish.yml +0 -0
- {titans_pytorch-0.3.2 → titans_pytorch-0.3.4}/.github/workflows/test.yaml +0 -0
- {titans_pytorch-0.3.2 → titans_pytorch-0.3.4}/.gitignore +0 -0
- {titans_pytorch-0.3.2 → titans_pytorch-0.3.4}/LICENSE +0 -0
- {titans_pytorch-0.3.2 → titans_pytorch-0.3.4}/README.md +0 -0
- {titans_pytorch-0.3.2 → titans_pytorch-0.3.4}/data/README.md +0 -0
- {titans_pytorch-0.3.2 → titans_pytorch-0.3.4}/data/enwik8.gz +0 -0
- {titans_pytorch-0.3.2 → titans_pytorch-0.3.4}/fig1.png +0 -0
- {titans_pytorch-0.3.2 → titans_pytorch-0.3.4}/fig2.png +0 -0
- {titans_pytorch-0.3.2 → titans_pytorch-0.3.4}/titans_pytorch/__init__.py +0 -0
- {titans_pytorch-0.3.2 → titans_pytorch-0.3.4}/titans_pytorch/mac_transformer.py +0 -0
{titans_pytorch-0.3.2 → titans_pytorch-0.3.4}/tests/test_titans.py

@@ -31,6 +31,7 @@ def torch_default_dtype(dtype):
 @pytest.mark.parametrize('qk_rmsnorm', (False, True))
 @pytest.mark.parametrize('max_grad_norm', (None, 2.))
 @pytest.mark.parametrize('per_parameter_lr_modulation', (False, True))
+@pytest.mark.parametrize('per_head_learned_parameters', (False, True))
 def test_titans(
     seq_len,
     silu,
@@ -39,10 +40,11 @@ def test_titans(
     momentum,
     qk_rmsnorm,
     max_grad_norm,
-    per_parameter_lr_modulation
+    per_parameter_lr_modulation,
+    per_head_learned_parameters
 ):
     mem = NeuralMemory(
-        dim =
+        dim = 16,
         chunk_size = chunk_size,
         activation = nn.SiLU() if silu else None,
         attn_pool_chunks = attn_pool_chunks,
@@ -50,9 +52,10 @@ def test_titans(
         momentum = momentum,
         qk_rmsnorm = qk_rmsnorm,
         per_parameter_lr_modulation = per_parameter_lr_modulation,
+        per_head_learned_parameters = per_head_learned_parameters
     )

-    seq = torch.randn(2, seq_len,
+    seq = torch.randn(2, seq_len, 16)
     retrieved, _ = mem(seq)

     assert seq.shape == retrieved.shape
@@ -61,14 +64,14 @@ def test_titans_attn_memory():
     from titans_pytorch.memory_models import MemoryAttention

     mem = NeuralMemory(
-        dim =
+        dim = 16,
         chunk_size = 64,
         model = MemoryAttention(
-            dim =
+            dim = 16
         )
     )

-    seq = torch.randn(2, 1024,
+    seq = torch.randn(2, 1024, 16)
     retrieved, _ = mem(seq)

     assert seq.shape == retrieved.shape
@@ -78,14 +81,14 @@ def test_neural_mem_chaining_chunks(
     gated_transition
 ):
     mem = NeuralMemory(
-        dim =
-        dim_head =
+        dim = 16,
+        dim_head = 16,
         heads = 2,
         chunk_size = 16,
         gated_transition = gated_transition
     )

-    seq = torch.randn(2, 48,
+    seq = torch.randn(2, 48, 16)

     parallel_retrieved, state = mem(seq)

@@ -99,21 +102,21 @@ def test_neural_mem_chaining_chunks(

 def test_neural_mem_chaining_with_weight_residual():
     mem = NeuralMemory(
-        dim =
-        dim_head =
+        dim = 16,
+        dim_head = 16,
         heads = 2,
         chunk_size = 64
     )

     mem2 = NeuralMemory(
-        dim =
-        dim_head =
+        dim = 16,
+        dim_head = 16,
         heads = 2,
         chunk_size = 64,
         accept_weight_residual = True
     )

-    seq = torch.randn(2, 256,
+    seq = torch.randn(2, 256, 16)

     seq, state = mem(seq)

@@ -124,18 +127,18 @@ def test_neural_mem_chaining_with_weight_residual():
     first_retrieved, state1 = mem2(seq_first, prev_weights = state.updates)
     second_retrieved, state2 = mem2(seq_second, state = state1, prev_weights = state.updates)

-    assert torch.allclose(parallel_retrieved, torch.cat((first_retrieved, second_retrieved), dim = 1), atol = 1e-
+    assert torch.allclose(parallel_retrieved, torch.cat((first_retrieved, second_retrieved), dim = 1), atol = 1e-5)

 def test_neural_mem_chaining_with_batch_size():
     mem = NeuralMemory(
-        dim =
-        dim_head =
+        dim = 16,
+        dim_head = 16,
         heads = 2,
         chunk_size = 16,
         batch_size = 64
     )

-    seq = torch.randn(2, 112,
+    seq = torch.randn(2, 112, 16)

     parallel_retrieved, state = mem(seq)

@@ -169,7 +172,7 @@ def test_mac(
 ):
     transformer = MemoryAsContextTransformer(
         num_tokens = 256,
-        dim =
+        dim = 16,
         depth = 2,
         num_persist_mem_tokens = num_persist_mem_tokens,
         num_longterm_mem_tokens = num_longterm_mem_tokens,
@@ -201,7 +204,7 @@ def test_mac_sampling(
 ):
     transformer = MemoryAsContextTransformer(
         num_tokens = 256,
-        dim =
+        dim = 16,
         depth = 4,
         segment_len = 32,
         num_persist_mem_tokens = 4,
@@ -235,12 +238,12 @@ def test_neural_mem_inference(
 ):

     mem = NeuralMemory(
-        dim =
+        dim = 16,
         chunk_size = mem_chunk_size,
         gated_transition = gated_transition
     )

-    seq = torch.randn(2, seq_len,
+    seq = torch.randn(2, seq_len, 16)
     parallel_retrieved, _ = mem(seq)

     assert seq.shape == parallel_retrieved.shape
@@ -282,7 +285,7 @@ def test_flex(
         pytest.skip()

     attn = SegmentedAttention(
-        dim =
+        dim = 16,
         segment_len = 32,
         num_persist_mem_tokens = 1,
         num_longterm_mem_tokens = 1,
@@ -290,7 +293,7 @@ def test_flex(
         sliding = sliding
     ).cuda()

-    seq = torch.randn(1, seq_len,
+    seq = torch.randn(1, seq_len, 16).cuda()

     out_flex, _ = attn(seq)
     out_non_flex, _ = attn(seq, disable_flex_attn = True)
@@ -307,8 +310,8 @@ def test_assoc_scan():
     seq_len = 128
     mid_point = seq_len // 2

-    gates = torch.randn(2, seq_len,
-    inputs = torch.randn(2, seq_len,
+    gates = torch.randn(2, seq_len, 16).sigmoid()
+    inputs = torch.randn(2, seq_len, 16)

     output = scan(gates, inputs)

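The test changes above exercise the new `per_head_learned_parameters` flag on `NeuralMemory`. Below is a minimal, hedged sketch of driving that flag directly, reusing only shapes and keyword names that appear in the diff; the surrounding setup is illustrative rather than taken from the package docs.

```python
import torch
from titans_pytorch import NeuralMemory

# per_head_learned_parameters = True keeps a learned copy of the memory-model
# weights per head; False shares one copy tiled across batch * heads
mem = NeuralMemory(
    dim = 16,
    dim_head = 16,
    heads = 2,
    chunk_size = 16,
    per_head_learned_parameters = True
)

seq = torch.randn(2, 48, 16)      # (batch, seq_len, dim), as in the chaining test
retrieved, state = mem(seq)       # retrieval preserves the input shape

assert retrieved.shape == seq.shape
```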
{titans_pytorch-0.3.2 → titans_pytorch-0.3.4}/titans_pytorch/associative_scan.py

@@ -3,18 +3,39 @@ from typing import Callable

 import torch
 from torch import Tensor
+from torch.nn import Module
 import torch.nn.functional as F

+from einops import rearrange, repeat, reduce, pack, unpack
+
 # taken from S5-pytorch repository
 # https://github.com/i404788/s5-pytorch/blob/74e2fdae00b915a62c914bf3615c0b8a4279eb84/s5/jax_compat.py#L51-L134

 # helper functions

+def exists(v):
+    return v is not None
+
+def default(*args):
+    for arg in args:
+        if exists(arg):
+            return arg
+    return None
+
 def pad_at_dim(t, pad, dim = -1, value = 0.):
     dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
     zeros = ((0, 0) * dims_from_right)
     return F.pad(t, (*zeros, *pad), value = value)

+def pack_one_with_inverse(t, pattern):
+    packed, packed_shape = pack([t], pattern)
+
+    def inverse(out, inv_pattern = None):
+        inv_pattern = default(inv_pattern, pattern)
+        return unpack(out, packed_shape, inv_pattern)[0]
+
+    return packed, inverse
+
 # the operator that is needed

 @torch.jit.script
@@ -88,3 +109,69 @@ def _interleave(a, b):
     interleaved = torch.flatten(stacked, start_dim=1, end_dim=2)

     return interleaved[:, :output_axis_len]
+
+# associative scan wrapper around naive and accelerated version
+
+class AssocScan(Module):
+    def __init__(
+        self,
+        use_accelerated = False
+    ):
+        super().__init__()
+        self.use_accelerated = use_accelerated
+
+    def forward(
+        self,
+        gates,
+        inputs,
+        prev = None,
+        remove_prev = None
+    ):
+        remove_prev = default(remove_prev, exists(prev))
+
+        inputs, inverse_pack_weight_shape = pack_one_with_inverse(inputs, 'b n *')
+        gates, _ = pack_one_with_inverse(gates, 'b n *')
+
+        if exists(prev):
+            prev, _ = pack_one_with_inverse(prev, 'b *')
+
+        if exists(prev):
+            inputs, _ = pack([prev, inputs], 'b * d')
+            gates = pad_at_dim(gates, (1, 0), value = 1., dim = -2)
+
+        if not self.use_accelerated:
+            _, out = associative_scan(binary_operator, (gates, inputs))
+
+            if remove_prev:
+                out = out[:, 1:]
+
+            return inverse_pack_weight_shape(out)
+
+        from accelerated_scan.triton import scan as triton_scan
+        from accelerated_scan.warp import scan as warp_scan
+
+        scan = triton_scan if gates.is_cuda else warp_scan
+
+        def accelerate_scan_fn(gates, inputs):
+            gates = gates.expand_as(inputs)
+            gates, inputs = tuple(rearrange(t, 'b n d -> b d n') for t in (gates, inputs))
+
+            seq_len = gates.shape[-1]
+            next_power_two_seq_len = 2 ** max(5, int(math.ceil(math.log2(seq_len))))
+
+            gates = F.pad(gates, (0, next_power_two_seq_len - seq_len))
+            inputs = F.pad(inputs, (0, next_power_two_seq_len - seq_len))
+
+            outputs = scan(gates.contiguous(), inputs.contiguous())
+
+            outputs = outputs[..., :seq_len]
+            outputs = rearrange(outputs, 'b d n -> b n d')
+
+            return outputs
+
+        out = accelerate_scan_fn(gates, inputs)
+
+        if remove_prev:
+            out = out[:, 1:]
+
+        return inverse_pack_weight_shape(out)
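`AssocScan` now lives in `associative_scan.py` (it previously sat in `neural_memory.py`; see the removal further down). A short usage sketch of the wrapper under the default naive path, with shapes borrowed from `test_assoc_scan`; the `prev` carry behaviour is read off the new `forward` above and should be treated as illustrative.

```python
import torch
from titans_pytorch.associative_scan import AssocScan

scan = AssocScan(use_accelerated = False)    # True switches to the accelerated-scan kernels

gates  = torch.randn(2, 128, 16).sigmoid()   # decay gates in (0, 1), shape (batch, seq, dim)
inputs = torch.randn(2, 128, 16)

out = scan(gates, inputs)                    # scanned output, same shape as inputs
assert out.shape == inputs.shape

# an optional carry seeds the scan; by default the seed position is then
# dropped again, so the output length still matches the inputs
prev = torch.randn(2, 16)
out_with_prev = scan(gates, inputs, prev = prev)
assert out_with_prev.shape == inputs.shape
```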
{titans_pytorch-0.3.2 → titans_pytorch-0.3.4}/titans_pytorch/neural_memory.py

@@ -8,16 +8,12 @@ from collections import namedtuple
 import torch
 from torch import nn, cat, tensor, Tensor
 import torch.nn.functional as F
-from torch.nn import Linear, Module, Parameter, ParameterList
+from torch.nn import Linear, Module, Parameter, ParameterList, ParameterDict
 from torch.func import functional_call, vmap, grad

 from tensordict import TensorDict

-from titans_pytorch.associative_scan import (
-    associative_scan,
-    binary_operator,
-    pad_at_dim
-)
+from titans_pytorch.associative_scan import AssocScan

 from titans_pytorch.memory_models import(
     MemoryMLP
@@ -79,8 +75,8 @@ def safe_cat(inputs, dim = -2):
 def is_empty_tensor(t):
     return t.numel() == 0

-def
-    return
+def dict_get_value_shapes(td):
+    return [v.shape for k, v in td.items()]

 def rearrange_dict_values(td, pattern, **kwargs):
     return td.apply(lambda t: rearrange(t, pattern, **kwargs))
@@ -97,6 +93,11 @@ def round_down_multiple(seq, mult):
 def round_up_multiple(seq, mult):
     return math.ceil(seq / mult) * mult

+def pad_at_dim(t, pad, dim = -1, value = 0.):
+    dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
+    zeros = ((0, 0) * dims_from_right)
+    return F.pad(t, (*zeros, *pad), value = value)
+
 def pack_one_with_inverse(t, pattern):
     packed, packed_shape = pack([t], pattern)

@@ -197,72 +198,6 @@ class AttentionPool(Module):

         return reduce(x * attn, 'b n c d -> b n d', 'sum')

-# associative scan wrapper
-
-class AssocScan(Module):
-    def __init__(
-        self,
-        use_accelerated = False
-    ):
-        super().__init__()
-        self.use_accelerated = use_accelerated
-
-    def forward(
-        self,
-        gates,
-        inputs,
-        prev = None,
-        remove_prev = None
-    ):
-        remove_prev = default(remove_prev, exists(prev))
-
-        inputs, inverse_pack_weight_shape = pack_one_with_inverse(inputs, 'b n *')
-        gates, _ = pack_one_with_inverse(gates, 'b n *')
-
-        if exists(prev):
-            prev, _ = pack_one_with_inverse(prev, 'b *')
-
-        if exists(prev):
-            inputs, _ = pack([prev, inputs], 'b * d')
-            gates = pad_at_dim(gates, (1, 0), value = 1., dim = -2)
-
-        if not self.use_accelerated:
-            _, out = associative_scan(binary_operator, (gates, inputs))
-
-            if remove_prev:
-                out = out[:, 1:]
-
-            return inverse_pack_weight_shape(out)
-
-        from accelerated_scan.triton import scan as triton_scan
-        from accelerated_scan.warp import scan as warp_scan
-
-        scan = triton_scan if gates.is_cuda else warp_scan
-
-        def accelerate_scan_fn(gates, inputs):
-            gates = gates.expand_as(inputs)
-            gates, inputs = tuple(rearrange(t, 'b n d -> b d n') for t in (gates, inputs))
-
-            seq_len = gates.shape[-1]
-            next_power_two_seq_len = 2 ** max(5, int(math.ceil(math.log2(seq_len))))
-
-            gates = F.pad(gates, (0, next_power_two_seq_len - seq_len))
-            inputs = F.pad(inputs, (0, next_power_two_seq_len - seq_len))
-
-            outputs = scan(gates.contiguous(), inputs.contiguous())
-
-            outputs = outputs[..., :seq_len]
-            outputs = rearrange(outputs, 'b d n -> b n d')
-
-            return outputs
-
-        out = accelerate_scan_fn(gates, inputs)
-
-        if remove_prev:
-            out = out[:, 1:]
-
-        return inverse_pack_weight_shape(out)
-
 # main neural memory

 def default_adaptive_step_transform(adaptive_step, max_lr = 1e-2):
@@ -285,6 +220,7 @@ class NeuralMemory(Module):
         default_step_transform_max_lr = 1.,
         per_parameter_lr_modulation = False, # allow outer network to control learning rate per weight matrix of memory network
         max_mem_layer_modulation = 1., # max of 10.
+        per_head_learned_parameters = True,
         attn_pool_chunks = False,
         momentum = True,
         pre_rmsnorm = True,
@@ -370,9 +306,21 @@

         self.memory_model = model

-
+        mem_model_params = dict(model.named_parameters())
+
+        self.num_memory_parameter_tensors = len(mem_model_params)
+
+        self.memory_model_parameter_names = [*mem_model_params.keys()]
+
+        memory_model_parameters = [*mem_model_params.values()]
+
+        if per_head_learned_parameters:
+            memory_model_parameters = [repeat(p, '... -> h ...', h = heads) for p in memory_model_parameters]
+
+        self.init_weight_shape = [p.shape for p in memory_model_parameters]

-        self.
+        self.memory_model_parameters = ParameterList(memory_model_parameters)
+        self.per_head_learned_parameters = per_head_learned_parameters

         # the chunk size within the paper where adaptive step, momentum, weight decay are shared

@@ -488,21 +436,32 @@

         self.register_buffer('zero', torch.tensor(0.), persistent = False)

+    @property
+    def memory_model_parameter_dict(self):
+        return TensorDict(dict(zip(self.memory_model_parameter_names, self.memory_model_parameters)))
+
     def init_weights(
         self,
         batch,
     ):
-
-
+        if self.per_head_learned_parameters:
+            weights = repeat_dict_values(self.memory_model_parameter_dict, 'h ... -> (b h) ...', b = batch)
+        else:
+            weights = repeat_dict_values(self.memory_model_parameter_dict, '... -> bh ...', bh = batch * self.heads)
+
         return weights

     def init_momentum(
         self,
         batch,
     ):
-
-
-
+        zeros = self.memory_model_parameter_dict.clone().zero_()
+
+        if self.per_head_learned_parameters:
+            zeros = repeat_dict_values(zeros, 'h ... -> (b h) ...', b = batch)
+        else:
+            zeros = repeat_dict_values(zeros, '... -> bh ...', bh = batch * self.heads)
+
         return zeros

     def store_memories(
@@ -690,16 +649,27 @@
     def retrieve_memories(
         self,
         seq,
-
-        chunk_size = None,
-        need_pad = True
+        weights: dict[str, Tensor],
     ):
-        chunk_size =
+        chunk_size = self.retrieve_chunk_size
+
+        weights_have_expanded_shape = dict_get_value_shapes(weights) != self.init_weight_shape
+
         batch, seq_len = seq.shape[:2]

-
+        # auto infer single token decoding, if there are only 1 set of weights and 1 token
+
+        is_one_token = seq_len == 1
+        is_one_weight = (not weights_have_expanded_shape) or next(iter(weights.values())).shape[1] == 1
+
+        is_single_token_decode = is_one_token and is_one_weight

-
+        if is_single_token_decode:
+            chunk_size = 1
+
+        # padding related, for chunked processing
+
+        need_pad = chunk_size > 1 or not is_one_weight

         if need_pad:
             seq = pad_at_dim(seq, (1, 0), dim = 1)
@@ -714,7 +684,11 @@
         # the parameters of the memory model stores the memories of the key / values
         # when the MLP has only 1 weight matrix, it is equivalent to `kv` fast weight memories from linear attention literature (recall fetching of memories is q @ (kv)) / schmidhuber's paper

-
+        weights = TensorDict(weights)
+
+        # pre norm
+
+        seq = self.retrieve_norm(seq)

         # sequence Float['b n d'] to queries

@@ -730,14 +704,14 @@

         # fetch values from memory model

-        if
-
+        if weights_have_expanded_shape:
+            weights = rearrange_dict_values(weights, 'b n ... -> (b n) ...')

         queries = rearrange(queries, 'b h (n c) d -> (b h n) c d', c = chunk_size)

         # forward functional call

-        values = functional_call(self.memory_model, dict(
+        values = functional_call(self.memory_model, dict(weights), queries)

         # reconstitute batch dimension

@@ -885,22 +859,13 @@

         # retrieve

-        need_pad = True
-        retrieve_chunk_size = None
-
         if is_single_token:
-            retrieve_chunk_size = 1
-            need_pad = False
-
             last_update, _ = next_neural_mem_state.states
-
             updates = rearrange_dict_values(last_update, 'b ... -> b 1 ...')

         retrieved = self.retrieve_memories(
             seq,
-            updates
-            chunk_size = retrieve_chunk_size,
-            need_pad = need_pad,
+            updates
         )

         return retrieved, next_neural_mem_state
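The per-head bookkeeping added above boils down to two einops `repeat` patterns: a per-head copy stored at construction time, then a per-batch expansion in `init_weights` (or a plain `batch * heads` tile when the copies are shared). A standalone sketch of just those expansions, with arbitrary tensor sizes:

```python
import torch
from einops import repeat

heads, batch = 2, 3
weight = torch.randn(16, 16)    # stand-in for one memory-model parameter tensor

# per_head_learned_parameters = True: a learned copy per head ...
per_head = repeat(weight, '... -> h ...', h = heads)                  # (heads, 16, 16)
# ... expanded over the batch when the fast weights are initialized
init_per_head = repeat(per_head, 'h ... -> (b h) ...', b = batch)     # (batch * heads, 16, 16)

# per_head_learned_parameters = False: one shared copy tiled over batch * heads
init_shared = repeat(weight, '... -> bh ...', bh = batch * heads)     # (batch * heads, 16, 16)

assert init_per_head.shape == init_shared.shape == (batch * heads, 16, 16)
```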
{titans_pytorch-0.3.2 → titans_pytorch-0.3.4}/train_mac.py

@@ -10,7 +10,11 @@ from torch.utils.data import DataLoader, Dataset

 from adam_atan2_pytorch import AdoptAtan2

-from titans_pytorch import
+from titans_pytorch import (
+    MemoryAsContextTransformer,
+    MemoryMLP,
+    MemoryAttention
+)

 # constants

@@ -35,6 +39,7 @@ NEURAL_MEM_GATE_ATTN_OUTPUT = False
 NEURAL_MEM_MOMENTUM = True
 NEURAL_MEM_QK_NORM = True
 NEURAL_MEM_MAX_LR = 1e-1
+USE_MEM_ATTENTION_MODEL = False
 WINDOW_SIZE = 32
 NEURAL_MEM_SEGMENT_LEN = 4 # set smaller for more granularity for learning rate / momentum etc
 NEURAL_MEM_BATCH_SIZE = 128 # set smaller to update the neural memory weights more often as it traverses the sequence
@@ -75,6 +80,18 @@ def decode_token(token):
 def decode_tokens(tokens):
     return ''.join(list(map(decode_token, tokens)))

+# memory model
+
+if USE_MEM_ATTENTION_MODEL:
+    neural_memory_model = MemoryAttention(
+        dim = 64
+    )
+else:
+    neural_memory_model = MemoryMLP(
+        dim = 64,
+        depth = NEURAL_MEMORY_DEPTH
+    )
+
 # instantiate memory-as-context transformer

 model = MemoryAsContextTransformer(
@@ -91,10 +108,7 @@ model = MemoryAsContextTransformer(
     neural_mem_weight_residual = NEURAL_MEM_WEIGHT_RESIDUAL,
     use_flex_attn = USE_FLEX_ATTN,
     sliding_window_attn = SLIDING_WINDOWS,
-    neural_memory_model =
-        dim = 64,
-        depth = NEURAL_MEMORY_DEPTH
-    ),
+    neural_memory_model = neural_memory_model,
     neural_memory_kwargs = dict(
         dim_head = 64,
         heads = 4,
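`train_mac.py` now builds the inner memory model up front and hands it to the transformer via `neural_memory_model`. The same swap works when constructing a `NeuralMemory` directly through its `model` kwarg (as `test_titans_attn_memory` does); a hedged sketch with arbitrary dimensions, not taken from the package docs:

```python
import torch
from titans_pytorch import NeuralMemory, MemoryMLP, MemoryAttention

USE_MEM_ATTENTION_MODEL = False

# mirror the train_mac.py toggle between the two memory model types
if USE_MEM_ATTENTION_MODEL:
    memory_model = MemoryAttention(dim = 64)
else:
    memory_model = MemoryMLP(dim = 64, depth = 2)

mem = NeuralMemory(
    dim = 64,
    chunk_size = 64,
    model = memory_model    # inner model dim chosen to match, following test_titans_attn_memory
)

retrieved, _ = mem(torch.randn(2, 128, 64))
assert retrieved.shape == (2, 128, 64)
```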