titans-pytorch 0.3.3__tar.gz → 0.3.5__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {titans_pytorch-0.3.3 → titans_pytorch-0.3.5}/PKG-INFO +1 -1
- {titans_pytorch-0.3.3 → titans_pytorch-0.3.5}/pyproject.toml +1 -1
- {titans_pytorch-0.3.3 → titans_pytorch-0.3.5}/tests/test_titans.py +4 -1
- {titans_pytorch-0.3.3 → titans_pytorch-0.3.5}/titans_pytorch/associative_scan.py +87 -0
- {titans_pytorch-0.3.3 → titans_pytorch-0.3.5}/titans_pytorch/memory_models.py +23 -13
- {titans_pytorch-0.3.3 → titans_pytorch-0.3.5}/titans_pytorch/neural_memory.py +47 -83
- {titans_pytorch-0.3.3 → titans_pytorch-0.3.5}/train_mac.py +19 -5
- {titans_pytorch-0.3.3 → titans_pytorch-0.3.5}/.github/workflows/python-publish.yml +0 -0
- {titans_pytorch-0.3.3 → titans_pytorch-0.3.5}/.github/workflows/test.yaml +0 -0
- {titans_pytorch-0.3.3 → titans_pytorch-0.3.5}/.gitignore +0 -0
- {titans_pytorch-0.3.3 → titans_pytorch-0.3.5}/LICENSE +0 -0
- {titans_pytorch-0.3.3 → titans_pytorch-0.3.5}/README.md +0 -0
- {titans_pytorch-0.3.3 → titans_pytorch-0.3.5}/data/README.md +0 -0
- {titans_pytorch-0.3.3 → titans_pytorch-0.3.5}/data/enwik8.gz +0 -0
- {titans_pytorch-0.3.3 → titans_pytorch-0.3.5}/fig1.png +0 -0
- {titans_pytorch-0.3.3 → titans_pytorch-0.3.5}/fig2.png +0 -0
- {titans_pytorch-0.3.3 → titans_pytorch-0.3.5}/titans_pytorch/__init__.py +0 -0
- {titans_pytorch-0.3.3 → titans_pytorch-0.3.5}/titans_pytorch/mac_transformer.py +0 -0
{titans_pytorch-0.3.3 → titans_pytorch-0.3.5}/tests/test_titans.py
@@ -31,6 +31,7 @@ def torch_default_dtype(dtype):
 @pytest.mark.parametrize('qk_rmsnorm', (False, True))
 @pytest.mark.parametrize('max_grad_norm', (None, 2.))
 @pytest.mark.parametrize('per_parameter_lr_modulation', (False, True))
+@pytest.mark.parametrize('per_head_learned_parameters', (False, True))
 def test_titans(
     seq_len,
     silu,
@@ -39,7 +40,8 @@ def test_titans(
     momentum,
     qk_rmsnorm,
     max_grad_norm,
-    per_parameter_lr_modulation
+    per_parameter_lr_modulation,
+    per_head_learned_parameters
 ):
     mem = NeuralMemory(
         dim = 16,
@@ -50,6 +52,7 @@ def test_titans(
         momentum = momentum,
         qk_rmsnorm = qk_rmsnorm,
         per_parameter_lr_modulation = per_parameter_lr_modulation,
+        per_head_learned_parameters = per_head_learned_parameters
     )

     seq = torch.randn(2, seq_len, 16)
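
Note: the new per_head_learned_parameters flag can also be exercised outside the test suite. Below is a minimal sketch of that usage, assuming titans-pytorch 0.3.5; the chunk size and sequence length are arbitrary choices, and the call assumes the library's usual two-value return of retrieved output plus memory state.

    # minimal sketch; dims and chunk size are arbitrary
    import torch
    from titans_pytorch import NeuralMemory

    mem = NeuralMemory(
        dim = 16,
        chunk_size = 2,
        per_head_learned_parameters = True   # new flag introduced in this release
    )

    seq = torch.randn(2, 64, 16)

    # retrieval output matches the input shape; the second value carries the updated memory state
    retrieved, state = mem(seq)
    assert retrieved.shape == seq.shape
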
{titans_pytorch-0.3.3 → titans_pytorch-0.3.5}/titans_pytorch/associative_scan.py
@@ -3,18 +3,39 @@ from typing import Callable

 import torch
 from torch import Tensor
+from torch.nn import Module
 import torch.nn.functional as F

+from einops import rearrange, repeat, reduce, pack, unpack
+
 # taken from S5-pytorch repository
 # https://github.com/i404788/s5-pytorch/blob/74e2fdae00b915a62c914bf3615c0b8a4279eb84/s5/jax_compat.py#L51-L134

 # helper functions

+def exists(v):
+    return v is not None
+
+def default(*args):
+    for arg in args:
+        if exists(arg):
+            return arg
+    return None
+
 def pad_at_dim(t, pad, dim = -1, value = 0.):
     dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
     zeros = ((0, 0) * dims_from_right)
     return F.pad(t, (*zeros, *pad), value = value)

+def pack_one_with_inverse(t, pattern):
+    packed, packed_shape = pack([t], pattern)
+
+    def inverse(out, inv_pattern = None):
+        inv_pattern = default(inv_pattern, pattern)
+        return unpack(out, packed_shape, inv_pattern)[0]
+
+    return packed, inverse
+
 # the operator that is needed

 @torch.jit.script
@@ -88,3 +109,69 @@ def _interleave(a, b):
     interleaved = torch.flatten(stacked, start_dim=1, end_dim=2)

     return interleaved[:, :output_axis_len]
+
+# associative scan wrapper around naive and accelerated version
+
+class AssocScan(Module):
+    def __init__(
+        self,
+        use_accelerated = False
+    ):
+        super().__init__()
+        self.use_accelerated = use_accelerated
+
+    def forward(
+        self,
+        gates,
+        inputs,
+        prev = None,
+        remove_prev = None
+    ):
+        remove_prev = default(remove_prev, exists(prev))
+
+        inputs, inverse_pack_weight_shape = pack_one_with_inverse(inputs, 'b n *')
+        gates, _ = pack_one_with_inverse(gates, 'b n *')
+
+        if exists(prev):
+            prev, _ = pack_one_with_inverse(prev, 'b *')
+
+        if exists(prev):
+            inputs, _ = pack([prev, inputs], 'b * d')
+            gates = pad_at_dim(gates, (1, 0), value = 1., dim = -2)
+
+        if not self.use_accelerated:
+            _, out = associative_scan(binary_operator, (gates, inputs))
+
+            if remove_prev:
+                out = out[:, 1:]
+
+            return inverse_pack_weight_shape(out)
+
+        from accelerated_scan.triton import scan as triton_scan
+        from accelerated_scan.warp import scan as warp_scan
+
+        scan = triton_scan if gates.is_cuda else warp_scan
+
+        def accelerate_scan_fn(gates, inputs):
+            gates = gates.expand_as(inputs)
+            gates, inputs = tuple(rearrange(t, 'b n d -> b d n') for t in (gates, inputs))
+
+            seq_len = gates.shape[-1]
+            next_power_two_seq_len = 2 ** max(5, int(math.ceil(math.log2(seq_len))))
+
+            gates = F.pad(gates, (0, next_power_two_seq_len - seq_len))
+            inputs = F.pad(inputs, (0, next_power_two_seq_len - seq_len))
+
+            outputs = scan(gates.contiguous(), inputs.contiguous())
+
+            outputs = outputs[..., :seq_len]
+            outputs = rearrange(outputs, 'b d n -> b n d')
+
+            return outputs
+
+        out = accelerate_scan_fn(gates, inputs)
+
+        if remove_prev:
+            out = out[:, 1:]
+
+        return inverse_pack_weight_shape(out)
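
Note: the AssocScan wrapper relocated here from neural_memory.py can be driven on its own. The sketch below assumes the naive (non-accelerated) path and illustrates the first-order recurrence out[t] = gates[t] * out[t - 1] + inputs[t] that both code paths compute; the tensor sizes are arbitrary.

    # minimal sketch of the relocated AssocScan on the naive path
    import torch
    from titans_pytorch.associative_scan import AssocScan

    scan = AssocScan(use_accelerated = False)

    gates  = torch.rand(1, 8, 4)     # per-step decay factors in [0, 1]
    inputs = torch.randn(1, 8, 4)    # per-step values to accumulate

    out = scan(gates, inputs)        # parallel prefix scan over the sequence dimension
    assert out.shape == inputs.shape

    # an optional previous state is prepended as an extra step and stripped from the output by default
    prev = torch.randn(1, 4)
    out_from_prev = scan(gates, inputs, prev = prev)
    assert out_from_prev.shape == inputs.shape
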
{titans_pytorch-0.3.3 → titans_pytorch-0.3.5}/titans_pytorch/memory_models.py
@@ -30,6 +30,25 @@ class LayerNorm(Module):

         return self.ln(x) * (gamma + 1.)

+# norm + residual wrapper, as used in original TTT paper
+# but could be removed
+
+class ResidualNorm(Module):
+    def __init__(
+        self,
+        dim,
+        model: Module
+    ):
+        super().__init__()
+        self.norm = LayerNorm(dim)
+        self.model = model
+
+    def forward(self, x):
+
+        out = self.model(x)
+
+        return self.norm(out) + x
+
 # memory mlp proposed in TTT

 class MemoryMLP(Module):
@@ -45,8 +64,6 @@ class MemoryMLP(Module):

         self.weights = ParameterList([Parameter(torch.randn(dim_in, dim_out)) for dim_in, dim_out in zip(dims[:-1], dims[1:])])

-        self.ln = LayerNorm(dim)
-
         for weight in self.weights:
             nn.init.xavier_uniform_(weight)

@@ -54,8 +71,6 @@ class MemoryMLP(Module):
         self,
         x
     ):
-        residual = x
-
         for ind, weight in enumerate(self.weights):
             is_first = ind == 0

@@ -64,7 +79,7 @@ class MemoryMLP(Module):

             x = x @ weight

-        return
+        return x

 # memory mlp, but with gated residual + final projection

@@ -97,7 +112,6 @@ class GatedResidualMemoryMLP(Module):
         self,
         x
     ):
-        residual = x

         for weight1, weight2, to_gates in self.weights:
             res = x
@@ -111,9 +125,7 @@ class GatedResidualMemoryMLP(Module):
             gates = cat((branch_out, res), dim = -1) @ to_gates
             x = res.lerp(branch_out, gates.sigmoid())

-
-
-        return self.ln(out) + residual
+        return x @ self.final_proj

 # memory mlp with factorized weights
 # so can tradeoff capacity for smaller chunk sizes
@@ -143,7 +155,6 @@ class FactorizedMemoryMLP(Module):
         self,
         x
     ):
-        residual = x

         for ind, (weight1, weight2) in enumerate(self.weights):
             is_first = ind == 0
@@ -153,7 +164,7 @@ class FactorizedMemoryMLP(Module):

             x = x @ weight1 @ weight2

-        return
+        return x

 # improvised attention as memory module

@@ -182,7 +193,6 @@ class MemoryAttention(Module):
             nn.init.xavier_uniform_(weight)

     def forward(self, x):
-        residual = x

         wq, wk, wv, ffw1, ffw2 = self.weights

@@ -202,4 +212,4 @@ class MemoryAttention(Module):
         h = F.gelu(x @ ffw1)
         ff_out = h @ ffw2

-        return
+        return attn_out + ff_out
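
Note: with the per-model layernorm and residual factored out of the individual memory modules, the new ResidualNorm wrapper restores that behavior externally. A minimal sketch of wrapping a MemoryMLP follows; the dimensions are arbitrary.

    # ResidualNorm applies the wrapped model, layernorms its output, and adds the input back
    import torch
    from titans_pytorch.memory_models import MemoryMLP, ResidualNorm

    mlp = MemoryMLP(dim = 64, depth = 2)
    wrapped = ResidualNorm(dim = 64, model = mlp)

    x = torch.randn(2, 32, 64)
    out = wrapped(x)            # norm(mlp(x)) + x
    assert out.shape == x.shape
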
{titans_pytorch-0.3.3 → titans_pytorch-0.3.5}/titans_pytorch/neural_memory.py
@@ -8,19 +8,16 @@ from collections import namedtuple
 import torch
 from torch import nn, cat, tensor, Tensor
 import torch.nn.functional as F
-from torch.nn import Linear, Module, Parameter, ParameterList
+from torch.nn import Linear, Module, Parameter, ParameterList, ParameterDict
 from torch.func import functional_call, vmap, grad

 from tensordict import TensorDict

-from titans_pytorch.associative_scan import (
-    associative_scan,
-    binary_operator,
-    pad_at_dim
-)
+from titans_pytorch.associative_scan import AssocScan

 from titans_pytorch.memory_models import(
-    MemoryMLP
+    MemoryMLP,
+    ResidualNorm
 )

 import einx
@@ -79,8 +76,8 @@ def safe_cat(inputs, dim = -2):
 def is_empty_tensor(t):
     return t.numel() == 0

-def
-    return
+def dict_get_value_shapes(td):
+    return [v.shape for k, v in td.items()]

 def rearrange_dict_values(td, pattern, **kwargs):
     return td.apply(lambda t: rearrange(t, pattern, **kwargs))
@@ -97,6 +94,11 @@ def round_down_multiple(seq, mult):
 def round_up_multiple(seq, mult):
     return math.ceil(seq / mult) * mult

+def pad_at_dim(t, pad, dim = -1, value = 0.):
+    dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
+    zeros = ((0, 0) * dims_from_right)
+    return F.pad(t, (*zeros, *pad), value = value)
+
 def pack_one_with_inverse(t, pattern):
     packed, packed_shape = pack([t], pattern)

@@ -197,72 +199,6 @@ class AttentionPool(Module):

         return reduce(x * attn, 'b n c d -> b n d', 'sum')

-# associative scan wrapper
-
-class AssocScan(Module):
-    def __init__(
-        self,
-        use_accelerated = False
-    ):
-        super().__init__()
-        self.use_accelerated = use_accelerated
-
-    def forward(
-        self,
-        gates,
-        inputs,
-        prev = None,
-        remove_prev = None
-    ):
-        remove_prev = default(remove_prev, exists(prev))
-
-        inputs, inverse_pack_weight_shape = pack_one_with_inverse(inputs, 'b n *')
-        gates, _ = pack_one_with_inverse(gates, 'b n *')
-
-        if exists(prev):
-            prev, _ = pack_one_with_inverse(prev, 'b *')
-
-        if exists(prev):
-            inputs, _ = pack([prev, inputs], 'b * d')
-            gates = pad_at_dim(gates, (1, 0), value = 1., dim = -2)
-
-        if not self.use_accelerated:
-            _, out = associative_scan(binary_operator, (gates, inputs))
-
-            if remove_prev:
-                out = out[:, 1:]
-
-            return inverse_pack_weight_shape(out)
-
-        from accelerated_scan.triton import scan as triton_scan
-        from accelerated_scan.warp import scan as warp_scan
-
-        scan = triton_scan if gates.is_cuda else warp_scan
-
-        def accelerate_scan_fn(gates, inputs):
-            gates = gates.expand_as(inputs)
-            gates, inputs = tuple(rearrange(t, 'b n d -> b d n') for t in (gates, inputs))
-
-            seq_len = gates.shape[-1]
-            next_power_two_seq_len = 2 ** max(5, int(math.ceil(math.log2(seq_len))))
-
-            gates = F.pad(gates, (0, next_power_two_seq_len - seq_len))
-            inputs = F.pad(inputs, (0, next_power_two_seq_len - seq_len))
-
-            outputs = scan(gates.contiguous(), inputs.contiguous())
-
-            outputs = outputs[..., :seq_len]
-            outputs = rearrange(outputs, 'b d n -> b n d')
-
-            return outputs
-
-        out = accelerate_scan_fn(gates, inputs)
-
-        if remove_prev:
-            out = out[:, 1:]
-
-        return inverse_pack_weight_shape(out)
-
 # main neural memory

 def default_adaptive_step_transform(adaptive_step, max_lr = 1e-2):
@@ -285,6 +221,7 @@ class NeuralMemory(Module):
         default_step_transform_max_lr = 1.,
         per_parameter_lr_modulation = False, # allow outer network to control learning rate per weight matrix of memory network
         max_mem_layer_modulation = 1., # max of 10.
+        per_head_learned_parameters = True,
         attn_pool_chunks = False,
         momentum = True,
         pre_rmsnorm = True,
@@ -298,6 +235,7 @@ class NeuralMemory(Module):
         init_decay_bias = None,
         accept_weight_residual = False,
         gated_transition = False,
+        mem_model_norm_add_residual = True, # by default, layernorm output and add residual as proposed in TTT paper, but could be removed
         default_model_kwargs: dict = dict(
             depth = 2,
             expansion_factor = 4.
@@ -368,11 +306,26 @@ class NeuralMemory(Module):

         # the memory is the weights of the model

+        if mem_model_norm_add_residual:
+            model = ResidualNorm(dim = dim_head, model = model)
+
         self.memory_model = model

-
+        mem_model_params = dict(model.named_parameters())
+
+        self.num_memory_parameter_tensors = len(mem_model_params)

-        self.
+        self.memory_model_parameter_names = [*mem_model_params.keys()]
+
+        memory_model_parameters = [*mem_model_params.values()]
+
+        if per_head_learned_parameters:
+            memory_model_parameters = [repeat(p, '... -> h ...', h = heads) for p in memory_model_parameters]
+
+        self.init_weight_shape = [p.shape for p in memory_model_parameters]
+
+        self.memory_model_parameters = ParameterList(memory_model_parameters)
+        self.per_head_learned_parameters = per_head_learned_parameters

         # the chunk size within the paper where adaptive step, momentum, weight decay are shared

@@ -488,21 +441,32 @@ class NeuralMemory(Module):

         self.register_buffer('zero', torch.tensor(0.), persistent = False)

+    @property
+    def memory_model_parameter_dict(self):
+        return TensorDict(dict(zip(self.memory_model_parameter_names, self.memory_model_parameters)))
+
     def init_weights(
         self,
         batch,
     ):
-
-
+        if self.per_head_learned_parameters:
+            weights = repeat_dict_values(self.memory_model_parameter_dict, 'h ... -> (b h) ...', b = batch)
+        else:
+            weights = repeat_dict_values(self.memory_model_parameter_dict, '... -> bh ...', bh = batch * self.heads)
+
         return weights

     def init_momentum(
         self,
         batch,
     ):
-
-
-
+        zeros = self.memory_model_parameter_dict.clone().zero_()
+
+        if self.per_head_learned_parameters:
+            zeros = repeat_dict_values(zeros, 'h ... -> (b h) ...', b = batch)
+        else:
+            zeros = repeat_dict_values(zeros, '... -> bh ...', bh = batch * self.heads)
+
         return zeros

     def store_memories(
@@ -694,7 +658,7 @@ class NeuralMemory(Module):
     ):
         chunk_size = self.retrieve_chunk_size

-        weights_have_expanded_shape =
+        weights_have_expanded_shape = dict_get_value_shapes(weights) != self.init_weight_shape

         batch, seq_len = seq.shape[:2]

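
Note: with per_head_learned_parameters enabled (the default), each memory-model parameter tensor is registered once per head, which is visible directly on the module. A small sketch follows; the dim, head, and chunk-size values are illustrative.

    # sketch of the per-head parameter layout
    import torch
    from titans_pytorch import NeuralMemory

    mem = NeuralMemory(
        dim = 64,
        dim_head = 64,
        heads = 4,
        chunk_size = 16,
        per_head_learned_parameters = True
    )

    # every learned memory parameter now carries a leading head dimension
    for name, param in zip(mem.memory_model_parameter_names, mem.memory_model_parameters):
        assert param.shape[0] == 4
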
{titans_pytorch-0.3.3 → titans_pytorch-0.3.5}/train_mac.py
@@ -10,7 +10,11 @@ from torch.utils.data import DataLoader, Dataset

 from adam_atan2_pytorch import AdoptAtan2

-from titans_pytorch import
+from titans_pytorch import (
+    MemoryAsContextTransformer,
+    MemoryMLP,
+    MemoryAttention
+)

 # constants

@@ -35,6 +39,7 @@ NEURAL_MEM_GATE_ATTN_OUTPUT = False
 NEURAL_MEM_MOMENTUM = True
 NEURAL_MEM_QK_NORM = True
 NEURAL_MEM_MAX_LR = 1e-1
+USE_MEM_ATTENTION_MODEL = False
 WINDOW_SIZE = 32
 NEURAL_MEM_SEGMENT_LEN = 4 # set smaller for more granularity for learning rate / momentum etc
 NEURAL_MEM_BATCH_SIZE = 128 # set smaller to update the neural memory weights more often as it traverses the sequence
@@ -75,6 +80,18 @@ def decode_token(token):
 def decode_tokens(tokens):
     return ''.join(list(map(decode_token, tokens)))

+# memory model
+
+if USE_MEM_ATTENTION_MODEL:
+    neural_memory_model = MemoryAttention(
+        dim = 64
+    )
+else:
+    neural_memory_model = MemoryMLP(
+        dim = 64,
+        depth = NEURAL_MEMORY_DEPTH
+    )
+
 # instantiate memory-as-context transformer

 model = MemoryAsContextTransformer(
@@ -91,10 +108,7 @@ model = MemoryAsContextTransformer(
     neural_mem_weight_residual = NEURAL_MEM_WEIGHT_RESIDUAL,
     use_flex_attn = USE_FLEX_ATTN,
     sliding_window_attn = SLIDING_WINDOWS,
-    neural_memory_model =
-        dim = 64,
-        depth = NEURAL_MEMORY_DEPTH
-    ),
+    neural_memory_model = neural_memory_model,
     neural_memory_kwargs = dict(
         dim_head = 64,
         heads = 4,
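
Note: the training script now builds the memory model up front and passes it in, so switching between the MLP and attention memories is a one-flag change. A condensed sketch of that wiring follows; the transformer hyperparameters are illustrative and most of the script's settings (persistent memory tokens, neural memory layers, sliding windows, and so on) are left at their defaults.

    # condensed sketch of the new memory-model toggle
    import torch
    from titans_pytorch import MemoryAsContextTransformer, MemoryMLP, MemoryAttention

    USE_MEM_ATTENTION_MODEL = False

    neural_memory_model = (
        MemoryAttention(dim = 64)
        if USE_MEM_ATTENTION_MODEL else
        MemoryMLP(dim = 64, depth = 2)
    )

    transformer = MemoryAsContextTransformer(
        num_tokens = 256,
        dim = 384,
        depth = 8,
        segment_len = 32,                          # WINDOW_SIZE in the script
        neural_memory_model = neural_memory_model,
        neural_memory_kwargs = dict(dim_head = 64, heads = 4)
    )

    ids = torch.randint(0, 256, (1, 1024))
    loss = transformer(ids, return_loss = True)    # next-token cross entropy, as in the script
    loss.backward()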