titans-pytorch 0.0.16__py3-none-any.whl → 0.0.18__py3-none-any.whl

titans_pytorch/titans.py CHANGED
@@ -269,7 +269,7 @@ class NeuralMemory(Module):
                 gates = F.pad(gates, (0, next_power_two_seq_len - seq_len))
                 inputs = F.pad(inputs, (0, next_power_two_seq_len - seq_len))
 
-                outputs = scan(gates, inputs)
+                outputs = scan(gates.contiguous(), inputs.contiguous())
 
                 outputs = outputs[..., :seq_len]
                 outputs = rearrange(outputs, 'b d n -> b n d')
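Why this change helps (a minimal sketch, not part of the diff): rearrange(t, 'b n d -> b d n') is a pure permutation, so it returns a non-contiguous view of the original storage, and the fused kernels in accelerated_scan generally expect contiguous row-major inputs. Calling .contiguous() materializes the permuted layout before the kernel runs. The snippet below only demonstrates the contiguity behaviour and assumes torch and einops are installed; the tensor sizes are arbitrary.

# Sketch only: shows why .contiguous() is called before the accelerated scan.
import torch
from einops import rearrange

x = torch.randn(2, 128, 64)             # (batch, seq, dim)
y = rearrange(x, 'b n d -> b d n')      # permutation -> non-contiguous view

print(y.is_contiguous())                # False
print(y.contiguous().is_contiguous())   # True: copies into contiguous memory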
titans_pytorch/titans_attn_memory.py ADDED
@@ -0,0 +1,419 @@
+from __future__ import annotations
+import math
+from functools import partial
+
+import torch
+from torch import nn, Tensor
+import torch.nn.functional as F
+from torch.nn import Linear, Module
+from torch.func import functional_call, vmap, grad
+
+from tensordict import TensorDict
+
+from titans_pytorch.associative_scan import (
+    associative_scan,
+    binary_operator,
+    pad_at_dim
+)
+
+import einx
+from einops import rearrange, pack, unpack
+from einops.layers.torch import Rearrange, Reduce
+
+"""
+ein notation:
+b - batch
+n - sequence
+d - feature dimension
+c - intra-chunk
+"""
+
+# constants
+
+LinearNoBias = partial(Linear, bias = False)
+
+# functions
+
+def exists(v):
+    return v is not None
+
+def default(v, d):
+    return v if exists(v) else d
+
+def round_down_multiple(seq, mult):
+    return seq // mult * mult
+
+def round_up_multiple(seq, mult):
+    return math.ceil(seq / mult) * mult
+
+def pack_one_with_inverse(t, pattern):
+    packed, packed_shape = pack([t], pattern)
+
+    def inverse(out, inv_pattern = None):
+        inv_pattern = default(inv_pattern, pattern)
+        return unpack(out, packed_shape, inv_pattern)[0]
+
+    return packed, inverse
+
+# classes
+
+# improvised attention as memory module
+# todo - expand if see signal in experiments (update: not seeing it)
+
+class MemoryAttention(Module):
+    def __init__(
+        self,
+        dim
+    ):
+        super().__init__()
+        self.weights = nn.ParameterList([
+            nn.Parameter(torch.randn(dim, dim)), # queries
+            nn.Parameter(torch.randn(dim, dim)), # keys
+            nn.Parameter(torch.randn(dim, dim)), # values weight 1
+            nn.Parameter(torch.randn(dim, dim)), # values weight 2
+        ])
+
+    def forward(self, x):
+
+        assert x.shape[-2] > 1, 'chunk size needs to be greater than 1 for using attention as memory'
+
+        wq, wk, wv1, wv2 = self.weights
+
+        q = x @ wq
+        k = x @ wk
+        v = x @ wv1
+
+        hidden = F.scaled_dot_product_attention(
+            q, k, v,
+            is_causal = True
+        )
+
+        return F.silu(hidden) @ wv2
+
+# main neural memory
+
+def default_loss_fn(pred, target):
+    return (pred - target).pow(2).mean(dim = -1).sum()
+
+class NeuralMemory(Module):
+    def __init__(
+        self,
+        dim,
+        chunk_size = 1,
+        dim_head = None,
+        heads = 1,
+        model: MemoryAttention | None = None,
+        store_memory_loss_fn: Callable = default_loss_fn,
+        pre_rmsnorm = True,
+        post_rmsnorm = True,
+        use_accelerated_scan = False,
+        default_model_kwargs: dict = dict()
+    ):
+        super().__init__()
+
+        # norms
+
+        self.retrieve_norm = nn.RMSNorm(dim) if pre_rmsnorm else nn.Identity()
+        self.store_norm = nn.RMSNorm(dim) if pre_rmsnorm else nn.Identity()
+
+        self.post_rmsnorm = nn.RMSNorm(dim) if post_rmsnorm else nn.Identity()
+
+        # maybe multi-headed
+
+        dim_head = default(dim_head, dim)
+        dim_inner = dim_head * heads
+
+        self.split_heads = Rearrange('b n (h d) -> (b h) n d', h = heads)
+        self.merge_heads = Rearrange('(b h) n d -> b n (h d)', h = heads)
+        self.combine_heads = LinearNoBias(dim_inner, dim) if heads > 1 else nn.Identity()
+
+        # memory mlp
+
+        if not exists(model):
+            model = MemoryAttention(dim_head, **default_model_kwargs)
+
+        assert not exists(next(model.buffers(), None)), 'model cannot have buffers for now'
+
+        # the memory is the weights of the model
+
+        self.memory_model = model
+
+        # the chunk size within the paper where adaptive step, momentum, weight decay are shared
+
+        self.chunk_size = chunk_size
+
+        # prepare function for per sample gradients from model above, using torch.func
+
+        def forward_and_loss(params, inputs, target):
+            pred = functional_call(self.memory_model, params, inputs)
+            loss = self.store_memory_loss_fn(pred, target) # simple mse loss in paper - eq (12) - |M(k) - v|²
+            return loss
+
+        self.per_sample_grad_fn = vmap(grad(forward_and_loss), in_dims = (None, 0, 0))
+
+        # queries for retrieving from the model
+
+        self.to_queries = LinearNoBias(dim, dim_inner)
+
+        # keys and values for storing to the model
+
+        self.to_keys_values = LinearNoBias(dim, dim_inner * 2)
+        self.store_memory_loss_fn = store_memory_loss_fn
+
+        # learned adaptive learning rate and momentum
+        # todo - explore mlp layerwise learned lr / momentum
+
+        self.to_momentum = nn.Sequential(
+            Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size),
+            LinearNoBias(dim, heads),
+            Rearrange('b n h -> (b h) n 1')
+        )
+
+        self.to_adaptive_step = nn.Sequential(
+            Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size),
+            LinearNoBias(dim, heads),
+            Rearrange('b n h -> (b h) n')
+        )
+
+        # weight decay factor
+
+        self.to_decay_factor = nn.Sequential(
+            Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size),
+            LinearNoBias(dim, heads),
+            Rearrange('b n h -> (b h) n 1')
+        )
+
+        # maybe use accelerated scan
+
+        self.use_accelerated_scan = use_accelerated_scan
+
+    def init_weights_and_momentum(self):
+        params = TensorDict(dict(self.memory_model.named_parameters()))
+
+        init_weights = params.clone().zero_()
+        init_momentum = params.clone().zero_()
+
+        return init_weights, init_momentum
+
+    def store_memories(
+        self,
+        seq,
+        past_state: tuple[dict[str, Tensor], dict[str, Tensor]]
+    ):
+
+        seq = self.store_norm(seq)
+
+        # curtail sequence by multiple of the chunk size
+        # only a complete chunk of the sequence provides the memory for the next chunk
+
+        seq_len, chunk_size = seq.shape[-2], self.chunk_size
+        round_down_seq_len = round_down_multiple(seq_len, self.chunk_size)
+
+        seq = seq[:, :round_down_seq_len]
+
+        # curr weights + past weights, in the case that the initial weights are learned
+
+        curr_weights = TensorDict(dict(self.memory_model.named_parameters()))
+
+        past_state = tuple(TensorDict(d) for d in past_state)
+        past_weights, past_momentum = past_state
+
+        curr_weights = curr_weights + past_weights
+
+        # pack batch and sequence dimension
+
+        adaptive_lr = (self.to_adaptive_step(seq).sigmoid() * -15).exp() # from 1. - 1e-7
+
+        adaptive_momentum = self.to_momentum(seq).sigmoid()
+        decay_factor = self.to_decay_factor(seq).sigmoid()
+
+        # keys and values
+
+        keys, values = self.to_keys_values(seq).chunk(2, dim = -1)
+
+        # maybe multi head
+
+        keys, values = map(self.split_heads, (keys, values))
+
+        batch = keys.shape[0]
+
+        # take care of chunking
+
+        keys, values = tuple(rearrange(t, 'b (n c) d -> (b n) c d', c = self.chunk_size) for t in (keys, values))
+
+        # get grads and extra auxiliary loss (for backwarding through qkv projection in base neural memory module)
+
+        grads = self.per_sample_grad_fn(dict(curr_weights), keys, values)
+
+        grads = TensorDict(grads)
+
+        # restore batch and sequence dimension
+
+        grads = grads.apply(lambda t: rearrange(t, '(b n) ... -> b n ...', b = batch))
+
+        # multiply gradients with learned adaptive step size
+
+        surprises = grads.apply(lambda t: einx.multiply('b n ..., b n -> b n ...', t, -adaptive_lr))
+
+        # determine scan function
+
+        def default_associative_scan(gates, inputs):
+            _, outputs = associative_scan(binary_operator, (gates, inputs))
+            return outputs
+
+        if self.use_accelerated_scan:
+            from accelerated_scan.triton import scan as triton_scan
+            from accelerated_scan.warp import scan as warp_scan
+
+            scan = triton_scan if seq.is_cuda else warp_scan
+
+            def accelerate_scan_fn(gates, inputs):
+                gates = gates.expand_as(inputs)
+                gates, inputs = tuple(rearrange(t, 'b n d -> b d n') for t in (gates, inputs))
+
+                seq_len = gates.shape[-1]
+                next_power_two_seq_len = 2 ** max(5, int(math.ceil(math.log2(seq_len))))
+
+                gates = F.pad(gates, (0, next_power_two_seq_len - seq_len))
+                inputs = F.pad(inputs, (0, next_power_two_seq_len - seq_len))
+
+                outputs = scan(gates, inputs)
+
+                outputs = outputs[..., :seq_len]
+                outputs = rearrange(outputs, 'b d n -> b n d')
+                return outputs
+
+            scan_fn = accelerate_scan_fn
+        else:
+            scan_fn = default_associative_scan
+
+        # momentum + weight decay - momentum is the new contribution, as most linear RNNs have learned forgetting gates
+
+        next_momentum = TensorDict()
+        updates = TensorDict()
+
+        for param_name, surprise in surprises.items():
+
+            surprise, inverse_pack = pack_one_with_inverse(surprise, 'b n *')
+
+            # derive momentum with associative scan - eq (10)
+
+            momentum = scan_fn(adaptive_momentum, surprise) # momentum is S / surprise in the paper
+
+            # use associative scan again for learned forgetting (weight decay) - eq (13)
+
+            update = scan_fn(1. - decay_factor, momentum) # cumulative weight update, with learned decay / forgetting applied
+
+            updates[param_name] = inverse_pack(update)
+            next_momentum[param_name] = inverse_pack(momentum)
+
+        # compute the next weight per batch
+
+        last_update = updates.apply(lambda t: t[:, -1])
+
+        next_state = (curr_weights + last_update, next_momentum)
+
+        return updates, next_state
+
+    def retrieve_memories(
+        self,
+        seq,
+        past_weights: dict[str, Tensor] | None = None,
+    ):
+        chunk_size = self.chunk_size
+        seq_len = seq.shape[1]
+
+        seq = self.retrieve_norm(seq)
+
+        assert seq_len > chunk_size
+
+        seq = seq[:, chunk_size:]
+        curtailed_seq_len = seq.shape[-2]
+
+        next_seq_len = round_up_multiple(curtailed_seq_len + 1, chunk_size)
+
+        padding = next_seq_len - curtailed_seq_len
+
+        seq = pad_at_dim(seq, (0, padding), dim = 1)
+
+        # the parameters of the memory model stores the memories of the key / values
+        # when the MLP has only 1 weight matrix, it is equivalent to `kv` fast weight memories from linear attention literature (recall fetching of memories is q @ (kv)) / schmidhuber's paper
+
+        curr_weights = TensorDict(dict(self.memory_model.named_parameters()))
+
+        if exists(past_weights):
+            past_weights = TensorDict(past_weights)
+            assert past_weights.keys() == curr_weights.keys()
+
+            curr_weights = curr_weights + past_weights
+
+        # sequence Float['b n d'] to queries
+
+        queries = self.to_queries(seq)
+
+        # maybe multihead
+
+        queries = self.split_heads(queries)
+
+        batch = queries.shape[0]
+
+        # fetch values from memory model
+
+        curr_weights = curr_weights.apply(lambda t: rearrange(t, 'b n ... -> (b n) ...'))
+        queries = rearrange(queries, 'b (n c) d -> (b n) c d', c = chunk_size)
+
+        # forward functional call
+
+        values = functional_call(self.memory_model, dict(curr_weights), queries)
+
+        # reconstitute batch dimension
+
+        values = rearrange(values, '(b n) c d -> b (n c) d', b = batch)
+
+        # maybe merge heads and combine
+
+        values = self.merge_heads(values)
+
+        values = self.combine_heads(values)
+
+        # post norm, somehow could not stabilize this without it, not in paper
+
+        values = self.post_rmsnorm(values)
+
+        # restore
+
+        values = pad_at_dim(values, (chunk_size, 0), dim = 1, value = 0.) # todo, used a learned null memory embedding instead of 0s for retrieving from empty neural memory
+        values = values[:, :-padding]
+
+        return values
+
+    def forward(
+        self,
+        seq,
+        store_seq = None,
+        past_state: tuple[dict[str, Tensor], dict[str, Tensor]] | None = None,
+        return_next_memories = False
+    ):
+        batch, seq_len = seq.shape[:2]
+
+        if seq_len <= self.chunk_size:
+            return torch.zeros_like(seq)
+
+        if exists(past_state):
+            past_state = tuple(TensorDict(d) for d in past_state)
+
+        if not exists(past_state):
+            past_state = self.init_weights_and_momentum()
+
+        store_seq = default(store_seq, seq)
+
+        updates, next_memories = self.store_memories(store_seq, past_state)
+
+        past_weights, _ = past_state
+
+        retrieved = self.retrieve_memories(seq, past_weights + updates)
+
+        if not return_next_memories:
+            return retrieved
+
+        return retrieved, next_memories
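For orientation, a minimal usage sketch of the module added above (not part of the wheel). The import path titans_pytorch.titans_attn_memory is taken from the RECORD below; dim, chunk_size, batch and sequence length are arbitrary example values. Storing runs two chained associative scans per parameter (momentum, eq (10), then decayed accumulation, eq (13)), and retrieval queries the updated weights; note that the file's own comments flag the attention-as-memory variant as experimental.

# Sketch only - assumes titans-pytorch 0.0.18 plus its dependencies (torch, tensordict, einx, einops) are installed.
import torch
from titans_pytorch.titans_attn_memory import NeuralMemory

mem = NeuralMemory(
    dim = 384,        # example width, not prescribed by the package
    chunk_size = 64   # adaptive step / momentum / decay are shared per chunk
)

seq = torch.randn(2, 1024, 384)   # (batch, seq, dim); seq length must exceed chunk_size

retrieved = mem(seq)              # store then retrieve; same shape as the input
assert retrieved.shape == seq.shape

# optionally also return the updated memory state to carry across calls
retrieved, next_state = mem(seq, return_next_memories = True)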
{titans_pytorch-0.0.16.dist-info → titans_pytorch-0.0.18.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.0.16
+Version: 0.0.18
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
titans_pytorch-0.0.18.dist-info/RECORD ADDED
@@ -0,0 +1,8 @@
+titans_pytorch/__init__.py,sha256=nB0873FZ_OyCda3qFeWTdpO4LbbrXDEvtAefVLzh6g0,71
+titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+titans_pytorch/titans.py,sha256=-Xv3ufD2vhprNFliuu1lGx27nx7AvHi6yFG2g9eHaqY,12295
+titans_pytorch/titans_attn_memory.py,sha256=Rwx_-riGeISBefZg5Kjic8jzmmRRys-u93D2Kgb7Mos,12691
+titans_pytorch-0.0.18.dist-info/METADATA,sha256=YX0EPMqVioQjAVxoI3CAKV8nWgwZZ0tw4djgud4bEqs,3811
+titans_pytorch-0.0.18.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.0.18.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.0.18.dist-info/RECORD,,
titans_pytorch-0.0.16.dist-info/RECORD DELETED
@@ -1,7 +0,0 @@
-titans_pytorch/__init__.py,sha256=nB0873FZ_OyCda3qFeWTdpO4LbbrXDEvtAefVLzh6g0,71
-titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/titans.py,sha256=HYm0R_1w3s8MNPsyE2qAVpHGqTBX_AoWtjzxRfF1Ams,12269
-titans_pytorch-0.0.16.dist-info/METADATA,sha256=mzSgA4okWWSh97fncN2UKJaTVa3PWOHYVaFkQ1Ker0w,3811
-titans_pytorch-0.0.16.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.0.16.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.0.16.dist-info/RECORD,,