titans-pytorch 0.3.2__tar.gz → 0.3.4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.3.2
+Version: 0.3.4
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -1,6 +1,6 @@
 [project]
 name = "titans-pytorch"
-version = "0.3.2"
+version = "0.3.4"
 description = "Titans"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -31,6 +31,7 @@ def torch_default_dtype(dtype):
 @pytest.mark.parametrize('qk_rmsnorm', (False, True))
 @pytest.mark.parametrize('max_grad_norm', (None, 2.))
 @pytest.mark.parametrize('per_parameter_lr_modulation', (False, True))
+@pytest.mark.parametrize('per_head_learned_parameters', (False, True))
 def test_titans(
     seq_len,
     silu,
@@ -39,10 +40,11 @@ def test_titans(
     momentum,
     qk_rmsnorm,
     max_grad_norm,
-    per_parameter_lr_modulation
+    per_parameter_lr_modulation,
+    per_head_learned_parameters
 ):
     mem = NeuralMemory(
-        dim = 384,
+        dim = 16,
         chunk_size = chunk_size,
         activation = nn.SiLU() if silu else None,
         attn_pool_chunks = attn_pool_chunks,
@@ -50,9 +52,10 @@ def test_titans(
         momentum = momentum,
         qk_rmsnorm = qk_rmsnorm,
         per_parameter_lr_modulation = per_parameter_lr_modulation,
+        per_head_learned_parameters = per_head_learned_parameters
     )
 
-    seq = torch.randn(2, seq_len, 384)
+    seq = torch.randn(2, seq_len, 16)
     retrieved, _ = mem(seq)
 
     assert seq.shape == retrieved.shape
@@ -61,14 +64,14 @@ def test_titans_attn_memory():
     from titans_pytorch.memory_models import MemoryAttention
 
     mem = NeuralMemory(
-        dim = 384,
+        dim = 16,
         chunk_size = 64,
         model = MemoryAttention(
-            dim = 384
+            dim = 16
         )
     )
 
-    seq = torch.randn(2, 1024, 384)
+    seq = torch.randn(2, 1024, 16)
     retrieved, _ = mem(seq)
 
     assert seq.shape == retrieved.shape
@@ -78,14 +81,14 @@ def test_neural_mem_chaining_chunks(
     gated_transition
 ):
     mem = NeuralMemory(
-        dim = 384,
-        dim_head = 64,
+        dim = 16,
+        dim_head = 16,
         heads = 2,
         chunk_size = 16,
         gated_transition = gated_transition
     )
 
-    seq = torch.randn(2, 48, 384)
+    seq = torch.randn(2, 48, 16)
 
     parallel_retrieved, state = mem(seq)
 
@@ -99,21 +102,21 @@ def test_neural_mem_chaining_chunks(
 
 def test_neural_mem_chaining_with_weight_residual():
     mem = NeuralMemory(
-        dim = 384,
-        dim_head = 64,
+        dim = 16,
+        dim_head = 16,
         heads = 2,
         chunk_size = 64
     )
 
     mem2 = NeuralMemory(
-        dim = 384,
-        dim_head = 64,
+        dim = 16,
+        dim_head = 16,
         heads = 2,
         chunk_size = 64,
         accept_weight_residual = True
     )
 
-    seq = torch.randn(2, 256, 384)
+    seq = torch.randn(2, 256, 16)
 
     seq, state = mem(seq)
 
@@ -124,18 +127,18 @@ def test_neural_mem_chaining_with_weight_residual():
     first_retrieved, state1 = mem2(seq_first, prev_weights = state.updates)
     second_retrieved, state2 = mem2(seq_second, state = state1, prev_weights = state.updates)
 
-    assert torch.allclose(parallel_retrieved, torch.cat((first_retrieved, second_retrieved), dim = 1), atol = 1e-6)
+    assert torch.allclose(parallel_retrieved, torch.cat((first_retrieved, second_retrieved), dim = 1), atol = 1e-5)
 
 def test_neural_mem_chaining_with_batch_size():
     mem = NeuralMemory(
-        dim = 384,
-        dim_head = 64,
+        dim = 16,
+        dim_head = 16,
         heads = 2,
         chunk_size = 16,
         batch_size = 64
     )
 
-    seq = torch.randn(2, 112, 384)
+    seq = torch.randn(2, 112, 16)
 
     parallel_retrieved, state = mem(seq)
 
@@ -169,7 +172,7 @@ def test_mac(
 ):
     transformer = MemoryAsContextTransformer(
         num_tokens = 256,
-        dim = 256,
+        dim = 16,
         depth = 2,
         num_persist_mem_tokens = num_persist_mem_tokens,
         num_longterm_mem_tokens = num_longterm_mem_tokens,
@@ -201,7 +204,7 @@ def test_mac_sampling(
 ):
     transformer = MemoryAsContextTransformer(
         num_tokens = 256,
-        dim = 256,
+        dim = 16,
         depth = 4,
         segment_len = 32,
         num_persist_mem_tokens = 4,
@@ -235,12 +238,12 @@ def test_neural_mem_inference(
 ):
 
     mem = NeuralMemory(
-        dim = 384,
+        dim = 16,
         chunk_size = mem_chunk_size,
         gated_transition = gated_transition
     )
 
-    seq = torch.randn(2, seq_len, 384)
+    seq = torch.randn(2, seq_len, 16)
     parallel_retrieved, _ = mem(seq)
 
     assert seq.shape == parallel_retrieved.shape
@@ -282,7 +285,7 @@ def test_flex(
         pytest.skip()
 
     attn = SegmentedAttention(
-        dim = 512,
+        dim = 16,
         segment_len = 32,
         num_persist_mem_tokens = 1,
         num_longterm_mem_tokens = 1,
@@ -290,7 +293,7 @@ def test_flex(
         sliding = sliding
     ).cuda()
 
-    seq = torch.randn(1, seq_len, 512).cuda()
+    seq = torch.randn(1, seq_len, 16).cuda()
 
     out_flex, _ = attn(seq)
     out_non_flex, _ = attn(seq, disable_flex_attn = True)
@@ -307,8 +310,8 @@ def test_assoc_scan():
     seq_len = 128
     mid_point = seq_len // 2
 
-    gates = torch.randn(2, seq_len, 512).sigmoid()
-    inputs = torch.randn(2, seq_len, 512)
+    gates = torch.randn(2, seq_len, 16).sigmoid()
+    inputs = torch.randn(2, seq_len, 16)
 
     output = scan(gates, inputs)
 
@@ -3,18 +3,39 @@ from typing import Callable
 
 import torch
 from torch import Tensor
+from torch.nn import Module
 import torch.nn.functional as F
 
+from einops import rearrange, repeat, reduce, pack, unpack
+
 # taken from S5-pytorch repository
 # https://github.com/i404788/s5-pytorch/blob/74e2fdae00b915a62c914bf3615c0b8a4279eb84/s5/jax_compat.py#L51-L134
 
 # helper functions
 
+def exists(v):
+    return v is not None
+
+def default(*args):
+    for arg in args:
+        if exists(arg):
+            return arg
+    return None
+
 def pad_at_dim(t, pad, dim = -1, value = 0.):
     dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
     zeros = ((0, 0) * dims_from_right)
     return F.pad(t, (*zeros, *pad), value = value)
 
+def pack_one_with_inverse(t, pattern):
+    packed, packed_shape = pack([t], pattern)
+
+    def inverse(out, inv_pattern = None):
+        inv_pattern = default(inv_pattern, pattern)
+        return unpack(out, packed_shape, inv_pattern)[0]
+
+    return packed, inverse
+
 # the operator that is needed
 
 @torch.jit.script
@@ -88,3 +109,69 @@ def _interleave(a, b):
     interleaved = torch.flatten(stacked, start_dim=1, end_dim=2)
 
     return interleaved[:, :output_axis_len]
+
+# associative scan wrapper around naive and accelerated version
+
+class AssocScan(Module):
+    def __init__(
+        self,
+        use_accelerated = False
+    ):
+        super().__init__()
+        self.use_accelerated = use_accelerated
+
+    def forward(
+        self,
+        gates,
+        inputs,
+        prev = None,
+        remove_prev = None
+    ):
+        remove_prev = default(remove_prev, exists(prev))
+
+        inputs, inverse_pack_weight_shape = pack_one_with_inverse(inputs, 'b n *')
+        gates, _ = pack_one_with_inverse(gates, 'b n *')
+
+        if exists(prev):
+            prev, _ = pack_one_with_inverse(prev, 'b *')
+
+        if exists(prev):
+            inputs, _ = pack([prev, inputs], 'b * d')
+            gates = pad_at_dim(gates, (1, 0), value = 1., dim = -2)
+
+        if not self.use_accelerated:
+            _, out = associative_scan(binary_operator, (gates, inputs))
+
+            if remove_prev:
+                out = out[:, 1:]
+
+            return inverse_pack_weight_shape(out)
+
+        from accelerated_scan.triton import scan as triton_scan
+        from accelerated_scan.warp import scan as warp_scan
+
+        scan = triton_scan if gates.is_cuda else warp_scan
+
+        def accelerate_scan_fn(gates, inputs):
+            gates = gates.expand_as(inputs)
+            gates, inputs = tuple(rearrange(t, 'b n d -> b d n') for t in (gates, inputs))
+
+            seq_len = gates.shape[-1]
+            next_power_two_seq_len = 2 ** max(5, int(math.ceil(math.log2(seq_len))))
+
+            gates = F.pad(gates, (0, next_power_two_seq_len - seq_len))
+            inputs = F.pad(inputs, (0, next_power_two_seq_len - seq_len))
+
+            outputs = scan(gates.contiguous(), inputs.contiguous())
+
+            outputs = outputs[..., :seq_len]
+            outputs = rearrange(outputs, 'b d n -> b n d')
+
+            return outputs
+
+        out = accelerate_scan_fn(gates, inputs)
+
+        if remove_prev:
+            out = out[:, 1:]
+
+        return inverse_pack_weight_shape(out)
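A minimal usage sketch of the `AssocScan` wrapper added in the hunk above (an illustration, not part of the diff; assumes torch and einops are installed, and the accelerated-scan package only if `use_accelerated = True`). Gates and inputs are `(batch, seq, dim)` tensors; an optional `prev` carry of shape `(batch, dim)` is prepended and, by default, stripped from the output.

import torch
from titans_pytorch.associative_scan import AssocScan

scan = AssocScan(use_accelerated = False)   # naive path; True requires accelerated-scan + CUDA

gates = torch.randn(2, 128, 16).sigmoid()   # decay gates in (0, 1)
inputs = torch.randn(2, 128, 16)

out = scan(gates, inputs)                   # same shape as inputs
assert out.shape == inputs.shape

# carry an initial state from a previous segment
out = scan(gates, inputs, prev = torch.randn(2, 16))
assert out.shape == inputs.shape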
@@ -37,7 +37,7 @@ class MemoryMLP(Module):
         self,
         dim,
         depth,
-        expansion_factor = 4.
+        expansion_factor = 2.
     ):
         super().__init__()
         dim_hidden = int(dim * expansion_factor)
@@ -8,16 +8,12 @@ from collections import namedtuple
 import torch
 from torch import nn, cat, tensor, Tensor
 import torch.nn.functional as F
-from torch.nn import Linear, Module, Parameter, ParameterList
+from torch.nn import Linear, Module, Parameter, ParameterList, ParameterDict
 from torch.func import functional_call, vmap, grad
 
 from tensordict import TensorDict
 
-from titans_pytorch.associative_scan import (
-    associative_scan,
-    binary_operator,
-    pad_at_dim
-)
+from titans_pytorch.associative_scan import AssocScan
 
 from titans_pytorch.memory_models import(
     MemoryMLP
@@ -79,8 +75,8 @@ def safe_cat(inputs, dim = -2):
 def is_empty_tensor(t):
     return t.numel() == 0
 
-def dict_get_shape(td):
-    return {k: v.shape for k, v in td.items()}
+def dict_get_value_shapes(td):
+    return [v.shape for k, v in td.items()]
 
 def rearrange_dict_values(td, pattern, **kwargs):
     return td.apply(lambda t: rearrange(t, pattern, **kwargs))
@@ -97,6 +93,11 @@ def round_down_multiple(seq, mult):
 def round_up_multiple(seq, mult):
     return math.ceil(seq / mult) * mult
 
+def pad_at_dim(t, pad, dim = -1, value = 0.):
+    dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
+    zeros = ((0, 0) * dims_from_right)
+    return F.pad(t, (*zeros, *pad), value = value)
+
 def pack_one_with_inverse(t, pattern):
     packed, packed_shape = pack([t], pattern)
 
@@ -197,72 +198,6 @@ class AttentionPool(Module):
 
         return reduce(x * attn, 'b n c d -> b n d', 'sum')
 
-# associative scan wrapper
-
-class AssocScan(Module):
-    def __init__(
-        self,
-        use_accelerated = False
-    ):
-        super().__init__()
-        self.use_accelerated = use_accelerated
-
-    def forward(
-        self,
-        gates,
-        inputs,
-        prev = None,
-        remove_prev = None
-    ):
-        remove_prev = default(remove_prev, exists(prev))
-
-        inputs, inverse_pack_weight_shape = pack_one_with_inverse(inputs, 'b n *')
-        gates, _ = pack_one_with_inverse(gates, 'b n *')
-
-        if exists(prev):
-            prev, _ = pack_one_with_inverse(prev, 'b *')
-
-        if exists(prev):
-            inputs, _ = pack([prev, inputs], 'b * d')
-            gates = pad_at_dim(gates, (1, 0), value = 1., dim = -2)
-
-        if not self.use_accelerated:
-            _, out = associative_scan(binary_operator, (gates, inputs))
-
-            if remove_prev:
-                out = out[:, 1:]
-
-            return inverse_pack_weight_shape(out)
-
-        from accelerated_scan.triton import scan as triton_scan
-        from accelerated_scan.warp import scan as warp_scan
-
-        scan = triton_scan if gates.is_cuda else warp_scan
-
-        def accelerate_scan_fn(gates, inputs):
-            gates = gates.expand_as(inputs)
-            gates, inputs = tuple(rearrange(t, 'b n d -> b d n') for t in (gates, inputs))
-
-            seq_len = gates.shape[-1]
-            next_power_two_seq_len = 2 ** max(5, int(math.ceil(math.log2(seq_len))))
-
-            gates = F.pad(gates, (0, next_power_two_seq_len - seq_len))
-            inputs = F.pad(inputs, (0, next_power_two_seq_len - seq_len))
-
-            outputs = scan(gates.contiguous(), inputs.contiguous())
-
-            outputs = outputs[..., :seq_len]
-            outputs = rearrange(outputs, 'b d n -> b n d')
-
-            return outputs
-
-        out = accelerate_scan_fn(gates, inputs)
-
-        if remove_prev:
-            out = out[:, 1:]
-
-        return inverse_pack_weight_shape(out)
-
 # main neural memory
 
 def default_adaptive_step_transform(adaptive_step, max_lr = 1e-2):
@@ -285,6 +220,7 @@ class NeuralMemory(Module):
         default_step_transform_max_lr = 1.,
         per_parameter_lr_modulation = False, # allow outer network to control learning rate per weight matrix of memory network
         max_mem_layer_modulation = 1., # max of 10.
+        per_head_learned_parameters = True,
         attn_pool_chunks = False,
         momentum = True,
         pre_rmsnorm = True,
@@ -370,9 +306,21 @@ class NeuralMemory(Module):
 
         self.memory_model = model
 
-        self.num_memory_parameter_tensors = len(set(model.parameters()))
+        mem_model_params = dict(model.named_parameters())
+
+        self.num_memory_parameter_tensors = len(mem_model_params)
+
+        self.memory_model_parameter_names = [*mem_model_params.keys()]
+
+        memory_model_parameters = [*mem_model_params.values()]
+
+        if per_head_learned_parameters:
+            memory_model_parameters = [repeat(p, '... -> h ...', h = heads) for p in memory_model_parameters]
+
+        self.init_weight_shape = [p.shape for p in memory_model_parameters]
 
-        self.init_weight_shape = dict_get_shape(dict(model.named_parameters()))
+        self.memory_model_parameters = ParameterList(memory_model_parameters)
+        self.per_head_learned_parameters = per_head_learned_parameters
 
         # the chunk size within the paper where adaptive step, momentum, weight decay are shared
 
@@ -488,21 +436,32 @@ class NeuralMemory(Module):
 
         self.register_buffer('zero', torch.tensor(0.), persistent = False)
 
+    @property
+    def memory_model_parameter_dict(self):
+        return TensorDict(dict(zip(self.memory_model_parameter_names, self.memory_model_parameters)))
+
     def init_weights(
         self,
         batch,
     ):
-        weights = TensorDict(dict(self.memory_model.named_parameters()))
-        weights = repeat_dict_values(weights, '... -> bh ...', bh = batch * self.heads)
+        if self.per_head_learned_parameters:
+            weights = repeat_dict_values(self.memory_model_parameter_dict, 'h ... -> (b h) ...', b = batch)
+        else:
+            weights = repeat_dict_values(self.memory_model_parameter_dict, '... -> bh ...', bh = batch * self.heads)
+
         return weights
 
     def init_momentum(
         self,
         batch,
     ):
-        weights = TensorDict(dict(self.memory_model.named_parameters()))
-        zeros = weights.clone().zero_()
-        zeros = repeat_dict_values(zeros, '... -> bh ...', bh = batch * self.heads)
+        zeros = self.memory_model_parameter_dict.clone().zero_()
+
+        if self.per_head_learned_parameters:
+            zeros = repeat_dict_values(zeros, 'h ... -> (b h) ...', b = batch)
+        else:
+            zeros = repeat_dict_values(zeros, '... -> bh ...', bh = batch * self.heads)
+
         return zeros
 
     def store_memories(
@@ -690,16 +649,27 @@ class NeuralMemory(Module):
     def retrieve_memories(
         self,
         seq,
-        past_weights: dict[str, Tensor],
-        chunk_size = None,
-        need_pad = True
+        weights: dict[str, Tensor],
     ):
-        chunk_size = default(chunk_size, self.retrieve_chunk_size)
+        chunk_size = self.retrieve_chunk_size
+
+        weights_have_expanded_shape = dict_get_value_shapes(weights) != self.init_weight_shape
+
         batch, seq_len = seq.shape[:2]
 
-        seq = self.retrieve_norm(seq)
+        # auto infer single token decoding, if there are only 1 set of weights and 1 token
+
+        is_one_token = seq_len == 1
+        is_one_weight = (not weights_have_expanded_shape) or next(iter(weights.values())).shape[1] == 1
+
+        is_single_token_decode = is_one_token and is_one_weight
 
-        need_pad = need_pad or chunk_size > 1
+        if is_single_token_decode:
+            chunk_size = 1
+
+        # padding related, for chunked processing
+
+        need_pad = chunk_size > 1 or not is_one_weight
 
         if need_pad:
             seq = pad_at_dim(seq, (1, 0), dim = 1)
@@ -714,7 +684,11 @@ class NeuralMemory(Module):
         # the parameters of the memory model stores the memories of the key / values
         # when the MLP has only 1 weight matrix, it is equivalent to `kv` fast weight memories from linear attention literature (recall fetching of memories is q @ (kv)) / schmidhuber's paper
 
-        curr_weights = TensorDict(past_weights)
+        weights = TensorDict(weights)
+
+        # pre norm
+
+        seq = self.retrieve_norm(seq)
 
         # sequence Float['b n d'] to queries
 
@@ -730,14 +704,14 @@ class NeuralMemory(Module):
 
         # fetch values from memory model
 
-        if dict_get_shape(curr_weights) != self.init_weight_shape:
-            curr_weights = rearrange_dict_values(curr_weights, 'b n ... -> (b n) ...')
+        if weights_have_expanded_shape:
+            weights = rearrange_dict_values(weights, 'b n ... -> (b n) ...')
 
         queries = rearrange(queries, 'b h (n c) d -> (b h n) c d', c = chunk_size)
 
         # forward functional call
 
-        values = functional_call(self.memory_model, dict(curr_weights), queries)
+        values = functional_call(self.memory_model, dict(weights), queries)
 
         # reconstitute batch dimension
 
@@ -885,22 +859,13 @@ class NeuralMemory(Module):
 
         # retrieve
 
-        need_pad = True
-        retrieve_chunk_size = None
-
         if is_single_token:
-            retrieve_chunk_size = 1
-            need_pad = False
-
             last_update, _ = next_neural_mem_state.states
-
             updates = rearrange_dict_values(last_update, 'b ... -> b 1 ...')
 
         retrieved = self.retrieve_memories(
             seq,
-            updates,
-            chunk_size = retrieve_chunk_size,
-            need_pad = need_pad,
+            updates
         )
 
         return retrieved, next_neural_mem_state
@@ -10,7 +10,11 @@ from torch.utils.data import DataLoader, Dataset
 
 from adam_atan2_pytorch import AdoptAtan2
 
-from titans_pytorch import MemoryAsContextTransformer, MemoryMLP
+from titans_pytorch import (
+    MemoryAsContextTransformer,
+    MemoryMLP,
+    MemoryAttention
+)
 
 # constants
 
@@ -35,6 +39,7 @@ NEURAL_MEM_GATE_ATTN_OUTPUT = False
 NEURAL_MEM_MOMENTUM = True
 NEURAL_MEM_QK_NORM = True
 NEURAL_MEM_MAX_LR = 1e-1
+USE_MEM_ATTENTION_MODEL = False
 WINDOW_SIZE = 32
 NEURAL_MEM_SEGMENT_LEN = 4 # set smaller for more granularity for learning rate / momentum etc
 NEURAL_MEM_BATCH_SIZE = 128 # set smaller to update the neural memory weights more often as it traverses the sequence
@@ -75,6 +80,18 @@ def decode_token(token):
 def decode_tokens(tokens):
     return ''.join(list(map(decode_token, tokens)))
 
+# memory model
+
+if USE_MEM_ATTENTION_MODEL:
+    neural_memory_model = MemoryAttention(
+        dim = 64
+    )
+else:
+    neural_memory_model = MemoryMLP(
+        dim = 64,
+        depth = NEURAL_MEMORY_DEPTH
+    )
+
 # instantiate memory-as-context transformer
 
 model = MemoryAsContextTransformer(
@@ -91,10 +108,7 @@ model = MemoryAsContextTransformer(
     neural_mem_weight_residual = NEURAL_MEM_WEIGHT_RESIDUAL,
     use_flex_attn = USE_FLEX_ATTN,
     sliding_window_attn = SLIDING_WINDOWS,
-    neural_memory_model = MemoryMLP(
-        dim = 64,
-        depth = NEURAL_MEMORY_DEPTH
-    ),
+    neural_memory_model = neural_memory_model,
     neural_memory_kwargs = dict(
         dim_head = 64,
         heads = 4,
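A minimal sketch of the new `per_head_learned_parameters` option on `NeuralMemory` (an illustration mirroring the updated tests above, not part of the diff; assumes the package is installed). When it is True (the default), each head keeps its own learned initialization of the memory-model weights, repeated per head at construction time; when False, all heads share a single learned initialization, as in 0.3.2.

import torch
from titans_pytorch import NeuralMemory

mem = NeuralMemory(
    dim = 16,
    dim_head = 16,
    heads = 2,
    chunk_size = 16,
    per_head_learned_parameters = True  # set False to share one init across heads
)

seq = torch.randn(2, 64, 16)
retrieved, state = mem(seq)

assert retrieved.shape == seq.shape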