titans-pytorch 0.0.3__py3-none-any.whl → 0.0.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
titans_pytorch/titans.py CHANGED
@@ -1,4 +1,5 @@
 from __future__ import annotations
+import math
 from functools import partial

 import torch
@@ -11,12 +12,13 @@ from tensordict import TensorDict

 from titans_pytorch.associative_scan import (
     associative_scan,
-    binary_operator
+    binary_operator,
+    pad_at_dim
 )

 import einx
 from einops import rearrange, pack, unpack
-from einops.layers.torch import Rearrange
+from einops.layers.torch import Rearrange, Reduce

 """
 ein notation:
@@ -41,6 +43,9 @@ def default(v, d):
 def round_down_multiple(seq, mult):
     return seq // mult * mult

+def round_up_multiple(seq, mult):
+    return math.ceil(seq / mult) * mult
+
 def pack_one_with_inverse(t, pattern):
     packed, packed_shape = pack([t], pattern)

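For reference, the two rounding helpers restated with example values - a quick sanity check, not part of the package:

import math

def round_down_multiple(seq, mult):
    return seq // mult * mult

def round_up_multiple(seq, mult):
    return math.ceil(seq / mult) * mult

assert round_down_multiple(10, 4) == 8    # trim a length down to a multiple of the chunk size
assert round_up_multiple(10, 4) == 12     # pad a length up to the next multiple
assert round_up_multiple(8, 4) == 8       # already a multiple - unchanged
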
@@ -84,6 +89,7 @@ class NeuralMemory(Module):
     def __init__(
         self,
         dim,
+        chunk_size = 1,
         model: Module | None = None,
         store_memory_loss_fn: Callable = default_loss_fn
     ):
@@ -98,11 +104,15 @@ class NeuralMemory(Module):

         self.memory_model = model

+        # the chunk size within the paper where adaptive step, momentum, weight decay are shared
+
+        self.chunk_size = chunk_size
+
         # prepare function for per sample gradients from model above, using torch.func

         def forward_and_loss(params, inputs, target):
             pred = functional_call(self.memory_model, params, inputs)
-            loss = self.store_memory_loss_fn(pred, target) # simple mse loss in paper - eq (12) - |M(k) == v|²
+            loss = self.store_memory_loss_fn(pred, target) # simple mse loss in paper - eq (12) - |M(k) - v|²
             return loss

         self.per_sample_grad_and_value_fn = vmap(grad_and_value(forward_and_loss), in_dims = (None, 0, 0))
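
The per-sample gradient machinery above uses torch.func. Below is a minimal, self-contained sketch of the same pattern with a toy linear model; the model, shapes and variable names are illustrative, not taken from the package:

import torch
from torch import nn
from torch.func import functional_call, grad_and_value, vmap

# toy model standing in for the neural memory MLP
model = nn.Linear(4, 4, bias = False)
params = dict(model.named_parameters())

def forward_and_loss(params, inputs, target):
    pred = functional_call(model, params, inputs)
    return ((pred - target) ** 2).mean()

# one gradient (and loss) per leading batch element, parameters shared across the batch
per_sample_grad_and_value = vmap(grad_and_value(forward_and_loss), in_dims = (None, 0, 0))

keys = torch.randn(8, 3, 4)      # (batch * chunks, chunk_size, dim)
values = torch.randn(8, 3, 4)

grads, losses = per_sample_grad_and_value(params, keys, values)
print(grads['weight'].shape)     # torch.Size([8, 4, 4]) - one gradient per sample
print(losses.shape)              # torch.Size([8])
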
@@ -119,9 +129,23 @@ class NeuralMemory(Module):
         # learned adaptive learning rate and momentum
         # todo - explore mlp layerwise learned lr / momentum

-        self.to_momentum = LinearNoBias(dim, 1)
-        self.to_adaptive_step = nn.Sequential(LinearNoBias(dim, 1), Rearrange('... 1 -> ...'))
-        self.to_decay_factor = LinearNoBias(dim, 1) # weight decay factor
+        self.to_momentum = nn.Sequential(
+            Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size),
+            LinearNoBias(dim, 1)
+        )
+
+        self.to_adaptive_step = nn.Sequential(
+            Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size),
+            LinearNoBias(dim, 1),
+            Rearrange('... 1 -> ...')
+        )
+
+        # weight decay factor
+
+        self.to_decay_factor = nn.Sequential(
+            Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size),
+            LinearNoBias(dim, 1)
+        )

     def init_weights_and_momentum(self):
         params = TensorDict(dict(self.memory_model.named_parameters()))
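
The Reduce layer prepended to each of these projections is what makes the adaptive step, momentum and decay factor shared per chunk: tokens are mean-pooled within each chunk before the linear head. A small shape check of that behaviour (sizes chosen arbitrarily):

import torch
from einops.layers.torch import Reduce

# mean-pool every chunk of 2 tokens before projecting
pool = Reduce('b (n c) ... -> b n ...', 'mean', c = 2)

x = torch.randn(2, 64, 32)       # (batch, seq, dim)
print(pool(x).shape)             # torch.Size([2, 32, 32]) - one pooled token per chunk of 2
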
@@ -137,6 +161,16 @@ class NeuralMemory(Module):
         past_state: tuple[dict[str, Tensor], dict[str, Tensor]]
     ):

+        # curtail sequence by multiple of the chunk size
+        # only a complete chunk of the sequence provides the memory for the next chunk
+
+        seq_len = seq.shape[-2]
+        round_down_seq_len = round_down_multiple(seq_len, self.chunk_size)
+
+        seq = seq[:, :round_down_seq_len]
+
+        # curr weights + past weights, in the case that the initial weights are learned
+
         curr_weights = TensorDict(dict(self.memory_model.named_parameters()))

         past_state = tuple(TensorDict(d) for d in past_state)
@@ -148,16 +182,19 @@ class NeuralMemory(Module):

         batch = seq.shape[0]

-        adaptive_lr = self.to_adaptive_step(seq).tanh() * 0.5 + 0.5.
+        adaptive_lr = self.to_adaptive_step(seq).tanh() * 0.5 + 0.5

         adaptive_momentum = self.to_momentum(seq).sigmoid()
         decay_factor = self.to_decay_factor(seq).sigmoid()

         # keys and values

-        seq = rearrange(seq, 'b n d -> (b n) d')
         keys, values = self.to_keys_values(seq).chunk(2, dim = -1)

+        # take care of chunking
+
+        keys, values = tuple(rearrange(t, 'b (n c) d -> (b n) c d', c = self.chunk_size) for t in (keys, values))
+
         # get grads and extra auxiliary loss (for backwarding through qkv projection in base neural memory module)

         grads, aux_store_loss = self.per_sample_grad_and_value_fn(dict(curr_weights), keys, values)
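
Folding chunks into the batch dimension is what lets the vmapped per-sample gradient function treat every chunk as an independent sample. A quick shape trace using the README's example sizes (batch 2, sequence 64, dim 32, chunk_size 2):

import torch
from einops import rearrange

keys = torch.randn(2, 64, 32)    # (batch, seq, dim), seq already a multiple of chunk_size
keys = rearrange(keys, 'b (n c) d -> (b n) c d', c = 2)
print(keys.shape)                # torch.Size([64, 2, 32]) - 2 batches * 32 chunks, each of 2 tokens
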
@@ -172,31 +209,24 @@ class NeuralMemory(Module):

         surprises = grads.apply(lambda t: einx.multiply('b n ..., b n -> b n ...', t, -adaptive_lr))

-        # derive momentum with associative scan - eq (10)
+        # momentum + weight decay - momentum is the new contribution, as most linear RNNs have learned forgetting gates

         next_momentum = TensorDict()
+        updates = TensorDict()

         for param_name, surprise in surprises.items():
             surprise, inverse_pack = pack_one_with_inverse(surprise, 'b n *')

-            _, momentum = associative_scan(binary_operator, (adaptive_momentum, surprise)) # momentum is S / surprise in the paper
-
-            momentum = inverse_pack(momentum)
-
-            next_momentum[param_name] = momentum
+            # derive momentum with associative scan - eq (10)

-        # use associative scan again for learned forgetting (weight decay) - eq (13)
-
-        updates = TensorDict()
+            _, momentum = associative_scan(binary_operator, (adaptive_momentum, surprise)) # momentum is S / surprise in the paper

-        for param_name, momentum in next_momentum.items():
-            momentum, inverse_pack = pack_one_with_inverse(momentum, 'b n *')
+            # use associative scan again for learned forgetting (weight decay) - eq (13)

             _, update = associative_scan(binary_operator, (1. - decay_factor, momentum)) # momentum is S / surprise in the paper

-            update = inverse_pack(update)
-
-            updates[param_name] = update
+            updates[param_name] = inverse_pack(update)
+            next_momentum[param_name] = inverse_pack(momentum)

         # compute the next weight per batch

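Both associative_scan calls compute first-order linear recurrences in parallel. Assuming binary_operator implements the usual gated recurrence h_t = a_t * h_{t-1} + b_t (as in other parallel-scan linear RNN code), a sequential sketch of what the loop computes for a single parameter, ignoring batching and packing, would be:

# toy per-chunk gates and surprises (arbitrary values), scalar state for illustration only
adaptive_momentum = [0.9, 0.8, 0.95]
decay_factor      = [0.1, 0.2, 0.05]
surprise          = [1.0, -0.5, 0.3]

momentum_t, update_t = 0., 0.
momenta, updates = [], []

for eta_t, alpha_t, s_t in zip(adaptive_momentum, decay_factor, surprise):
    momentum_t = eta_t * momentum_t + s_t               # eq (10): surprise accumulated with momentum
    update_t = (1. - alpha_t) * update_t + momentum_t   # eq (13): learned forgetting / weight decay
    momenta.append(momentum_t)
    updates.append(update_t)
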
@@ -211,7 +241,19 @@ class NeuralMemory(Module):
         seq,
         past_weights: dict[str, Tensor] | None = None,
     ):
-        batch = seq.shape[0]
+        chunk_size = self.chunk_size
+        batch, seq_len = seq.shape[:2]
+
+        assert seq_len >= chunk_size
+
+        seq = seq[:, (chunk_size - 1):]
+        curtailed_seq_len = seq.shape[-2]
+
+        next_seq_len = round_up_multiple(curtailed_seq_len, chunk_size)
+
+        padding = next_seq_len - curtailed_seq_len
+
+        seq = pad_at_dim(seq, (0, padding), dim = 1)

         # the parameters of the memory model stores the memories of the key / values
         # when the MLP has only 1 weight matrix, it is equivalent to `kv` fast weight memories from linear attention literature (recall fetching of memories is q @ (kv)) / schmidhuber's paper
@@ -231,7 +273,7 @@ class NeuralMemory(Module):
         # fetch values from memory model

         curr_weights = curr_weights.apply(lambda t: rearrange(t, 'b n ... -> (b n) ...'))
-        queries = rearrange(queries, 'b n d -> (b n) 1 d')
+        queries = rearrange(queries, 'b (n c) d -> (b n) c d', c = chunk_size)

         # forward functional call

@@ -239,7 +281,12 @@ class NeuralMemory(Module):

         # reconstitute batch dimension

-        values = rearrange(values, '(b n) 1 d -> b n d', b = batch)
+        values = rearrange(values, '(b n) c d -> b (n c) d', b = batch)
+
+        # restore
+
+        values = pad_at_dim(values, (chunk_size - 1, 0), dim = 1, value = 0.) # todo, used a learned null memory embedding instead of 0s for retrieving from empty neural memory
+        values = values[:, :-padding]

         return values

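Taken together, the retrieval changes shift queries back by chunk_size - 1 so each position only reads memories written by fully completed chunks, pad the sequence to a chunk multiple, and then undo both operations on the output. A length-only sketch of that arithmetic using the README's example sizes (seq_len 64, chunk_size 2), not the package's code:

import math

def round_up_multiple(seq, mult):
    return math.ceil(seq / mult) * mult

seq_len, chunk_size = 64, 2

curtailed = seq_len - (chunk_size - 1)                             # 63: drop the first chunk_size - 1 positions
padding = round_up_multiple(curtailed, chunk_size) - curtailed     # 1: right-pad to a chunk multiple
padded = curtailed + padding                                       # 64: length run through the memory model

restored = padded + (chunk_size - 1)                               # 65: left-pad retrieved values with zeros
final = restored - padding                                         # 64: trim the right padding back off

assert final == seq_len
# note: the trim is values[:, :-padding], so a padding of 0 would empty the tensor
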
titans_pytorch-0.0.5.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.0.3
+Version: 0.0.5
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -49,7 +49,7 @@ Description-Content-Type: text/markdown

 ## Titans - Pytorch (wip)

-Unofficial implementation of [Titans](https://arxiv.org/abs/2501.00663) in Pytorch. Will also contain some explorations into architectures beyond their simple 1-4 layer MLP for the neural memory module.
+Unofficial implementation of [Titans](https://arxiv.org/abs/2501.00663) in Pytorch. Will also contain some explorations into architectures beyond their simple 1-4 layer MLP for the neural memory module, if it works well to any degree.

 ## Install

@@ -65,7 +65,10 @@ from titans_pytorch import NeuralMemory

 x = torch.randn(2, 64, 32)

-mem = NeuralMemory(32)
+mem = NeuralMemory(
+    dim = 32,
+    chunk_size = 2
+)

 out = mem(x)

titans_pytorch-0.0.5.dist-info/RECORD ADDED
@@ -0,0 +1,7 @@
+titans_pytorch/__init__.py,sha256=QKuJPCOJCdgtaPeKoHEkYkiQe65_LV9_8-cIMbBPU30,55
+titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+titans_pytorch/titans.py,sha256=3Mewuysj0g7iAlfjdqMlJhn9-pKJuOerB1frQmQYXuc,9428
+titans_pytorch-0.0.5.dist-info/METADATA,sha256=f1DgCKZz9nqNfZOrqbOpyn-yEx2v5M5zgGIW0Zeu84I,3032
+titans_pytorch-0.0.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.0.5.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.0.5.dist-info/RECORD,,
titans_pytorch-0.0.3.dist-info/RECORD DELETED
@@ -1,7 +0,0 @@
-titans_pytorch/__init__.py,sha256=QKuJPCOJCdgtaPeKoHEkYkiQe65_LV9_8-cIMbBPU30,55
-titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/titans.py,sha256=0Mh9LJv5hLVbB2MvRJX5QanAeTtU9LAuj6YOQUwsyUQ,7813
-titans_pytorch-0.0.3.dist-info/METADATA,sha256=AXfDl_MTIu24VRagi_rgiH8rHXFBU5euwSD6DMwLgsg,2968
-titans_pytorch-0.0.3.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.0.3.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.0.3.dist-info/RECORD,,