titans-pytorch 0.0.16__tar.gz → 0.0.18__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: titans-pytorch
- Version: 0.0.16
+ Version: 0.0.18
  Summary: Titans
  Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
  Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -1,6 +1,6 @@
  [project]
  name = "titans-pytorch"
- version = "0.0.16"
+ version = "0.0.18"
  description = "Titans"
  authors = [
      { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -269,7 +269,7 @@ class NeuralMemory(Module):
                  gates = F.pad(gates, (0, next_power_two_seq_len - seq_len))
                  inputs = F.pad(inputs, (0, next_power_two_seq_len - seq_len))

-                 outputs = scan(gates, inputs)
+                 outputs = scan(gates.contiguous(), inputs.contiguous())

                  outputs = outputs[..., :seq_len]
                  outputs = rearrange(outputs, 'b d n -> b n d')
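The substantive change in the hunk above is the pair of `.contiguous()` calls: the preceding `rearrange(t, 'b n d -> b d n')` returns a permuted, non-contiguous view, and the change suggests the accelerated_scan kernels expect densely strided inputs, so the views are materialized before being handed to the kernel. A minimal illustration of the distinction, using only PyTorch and einops (independent of accelerated_scan, shapes illustrative):

import torch
from einops import rearrange

x = torch.randn(2, 1024, 64)            # (batch, seq, dim)
y = rearrange(x, 'b n d -> b d n')      # permuted view: same storage, transposed strides

print(y.is_contiguous())                # False
print(y.contiguous().is_contiguous())   # True - .contiguous() copies into a dense row-major layout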
@@ -0,0 +1,419 @@
+ from __future__ import annotations
+ import math
+ from functools import partial
+
+ import torch
+ from torch import nn, Tensor
+ import torch.nn.functional as F
+ from torch.nn import Linear, Module
+ from torch.func import functional_call, vmap, grad
+
+ from tensordict import TensorDict
+
+ from titans_pytorch.associative_scan import (
+     associative_scan,
+     binary_operator,
+     pad_at_dim
+ )
+
+ import einx
+ from einops import rearrange, pack, unpack
+ from einops.layers.torch import Rearrange, Reduce
+
+ """
+ ein notation:
+ b - batch
+ n - sequence
+ d - feature dimension
+ c - intra-chunk
+ """
+
+ # constants
+
+ LinearNoBias = partial(Linear, bias = False)
+
+ # functions
+
+ def exists(v):
+     return v is not None
+
+ def default(v, d):
+     return v if exists(v) else d
+
+ def round_down_multiple(seq, mult):
+     return seq // mult * mult
+
+ def round_up_multiple(seq, mult):
+     return math.ceil(seq / mult) * mult
+
+ def pack_one_with_inverse(t, pattern):
+     packed, packed_shape = pack([t], pattern)
+
+     def inverse(out, inv_pattern = None):
+         inv_pattern = default(inv_pattern, pattern)
+         return unpack(out, packed_shape, inv_pattern)[0]
+
+     return packed, inverse
+
+ # classes
+
+ # improvised attention as memory module
+ # todo - expand if see signal in experiments (update: not seeing it)
+
+ class MemoryAttention(Module):
+     def __init__(
+         self,
+         dim
+     ):
+         super().__init__()
+         self.weights = nn.ParameterList([
+             nn.Parameter(torch.randn(dim, dim)), # queries
+             nn.Parameter(torch.randn(dim, dim)), # keys
+             nn.Parameter(torch.randn(dim, dim)), # values weight 1
+             nn.Parameter(torch.randn(dim, dim)), # values weight 2
+         ])
+
+     def forward(self, x):
+
+         assert x.shape[-2] > 1, 'chunk size needs to be greater than 1 for using attention as memory'
+
+         wq, wk, wv1, wv2 = self.weights
+
+         q = x @ wq
+         k = x @ wk
+         v = x @ wv1
+
+         hidden = F.scaled_dot_product_attention(
+             q, k, v,
+             is_causal = True
+         )
+
+         return F.silu(hidden) @ wv2
+
+ # main neural memory
+
+ def default_loss_fn(pred, target):
+     return (pred - target).pow(2).mean(dim = -1).sum()
+
+ class NeuralMemory(Module):
+     def __init__(
+         self,
+         dim,
+         chunk_size = 1,
+         dim_head = None,
+         heads = 1,
+         model: MemoryAttention | None = None,
+         store_memory_loss_fn: Callable = default_loss_fn,
+         pre_rmsnorm = True,
+         post_rmsnorm = True,
+         use_accelerated_scan = False,
+         default_model_kwargs: dict = dict()
+     ):
+         super().__init__()
+
+         # norms
+
+         self.retrieve_norm = nn.RMSNorm(dim) if pre_rmsnorm else nn.Identity()
+         self.store_norm = nn.RMSNorm(dim) if pre_rmsnorm else nn.Identity()
+
+         self.post_rmsnorm = nn.RMSNorm(dim) if post_rmsnorm else nn.Identity()
+
+         # maybe multi-headed
+
+         dim_head = default(dim_head, dim)
+         dim_inner = dim_head * heads
+
+         self.split_heads = Rearrange('b n (h d) -> (b h) n d', h = heads)
+         self.merge_heads = Rearrange('(b h) n d -> b n (h d)', h = heads)
+         self.combine_heads = LinearNoBias(dim_inner, dim) if heads > 1 else nn.Identity()
+
+         # memory mlp
+
+         if not exists(model):
+             model = MemoryAttention(dim_head, **default_model_kwargs)
+
+         assert not exists(next(model.buffers(), None)), 'model cannot have buffers for now'
+
+         # the memory is the weights of the model
+
+         self.memory_model = model
+
+         # the chunk size within the paper where adaptive step, momentum, weight decay are shared
+
+         self.chunk_size = chunk_size
+
+         # prepare function for per sample gradients from model above, using torch.func
+
+         def forward_and_loss(params, inputs, target):
+             pred = functional_call(self.memory_model, params, inputs)
+             loss = self.store_memory_loss_fn(pred, target) # simple mse loss in paper - eq (12) - |M(k) - v|²
+             return loss
+
+         self.per_sample_grad_fn = vmap(grad(forward_and_loss), in_dims = (None, 0, 0))
+
+         # queries for retrieving from the model
+
+         self.to_queries = LinearNoBias(dim, dim_inner)
+
+         # keys and values for storing to the model
+
+         self.to_keys_values = LinearNoBias(dim, dim_inner * 2)
+         self.store_memory_loss_fn = store_memory_loss_fn
+
+         # learned adaptive learning rate and momentum
+         # todo - explore mlp layerwise learned lr / momentum
+
+         self.to_momentum = nn.Sequential(
+             Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size),
+             LinearNoBias(dim, heads),
+             Rearrange('b n h -> (b h) n 1')
+         )
+
+         self.to_adaptive_step = nn.Sequential(
+             Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size),
+             LinearNoBias(dim, heads),
+             Rearrange('b n h -> (b h) n')
+         )
+
+         # weight decay factor
+
+         self.to_decay_factor = nn.Sequential(
+             Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size),
+             LinearNoBias(dim, heads),
+             Rearrange('b n h -> (b h) n 1')
+         )
+
+         # maybe use accelerated scan
+
+         self.use_accelerated_scan = use_accelerated_scan
+
+     def init_weights_and_momentum(self):
+         params = TensorDict(dict(self.memory_model.named_parameters()))
+
+         init_weights = params.clone().zero_()
+         init_momentum = params.clone().zero_()
+
+         return init_weights, init_momentum
+
+     def store_memories(
+         self,
+         seq,
+         past_state: tuple[dict[str, Tensor], dict[str, Tensor]]
+     ):
+
+         seq = self.store_norm(seq)
+
+         # curtail sequence by multiple of the chunk size
+         # only a complete chunk of the sequence provides the memory for the next chunk
+
+         seq_len, chunk_size = seq.shape[-2], self.chunk_size
+         round_down_seq_len = round_down_multiple(seq_len, self.chunk_size)
+
+         seq = seq[:, :round_down_seq_len]
+
+         # curr weights + past weights, in the case that the initial weights are learned
+
+         curr_weights = TensorDict(dict(self.memory_model.named_parameters()))
+
+         past_state = tuple(TensorDict(d) for d in past_state)
+         past_weights, past_momentum = past_state
+
+         curr_weights = curr_weights + past_weights
+
+         # pack batch and sequence dimension
+
+         adaptive_lr = (self.to_adaptive_step(seq).sigmoid() * -15).exp() # from 1. - 1e-7
+
+         adaptive_momentum = self.to_momentum(seq).sigmoid()
+         decay_factor = self.to_decay_factor(seq).sigmoid()
+
+         # keys and values
+
+         keys, values = self.to_keys_values(seq).chunk(2, dim = -1)
+
+         # maybe multi head
+
+         keys, values = map(self.split_heads, (keys, values))
+
+         batch = keys.shape[0]
+
+         # take care of chunking
+
+         keys, values = tuple(rearrange(t, 'b (n c) d -> (b n) c d', c = self.chunk_size) for t in (keys, values))
+
+         # get grads and extra auxiliary loss (for backwarding through qkv projection in base neural memory module)
+
+         grads = self.per_sample_grad_fn(dict(curr_weights), keys, values)
+
+         grads = TensorDict(grads)
+
+         # restore batch and sequence dimension
+
+         grads = grads.apply(lambda t: rearrange(t, '(b n) ... -> b n ...', b = batch))
+
+         # multiply gradients with learned adaptive step size
+
+         surprises = grads.apply(lambda t: einx.multiply('b n ..., b n -> b n ...', t, -adaptive_lr))
+
+         # determine scan function
+
+         def default_associative_scan(gates, inputs):
+             _, outputs = associative_scan(binary_operator, (gates, inputs))
+             return outputs
+
+         if self.use_accelerated_scan:
+             from accelerated_scan.triton import scan as triton_scan
+             from accelerated_scan.warp import scan as warp_scan
+
+             scan = triton_scan if seq.is_cuda else warp_scan
+
+             def accelerate_scan_fn(gates, inputs):
+                 gates = gates.expand_as(inputs)
+                 gates, inputs = tuple(rearrange(t, 'b n d -> b d n') for t in (gates, inputs))
+
+                 seq_len = gates.shape[-1]
+                 next_power_two_seq_len = 2 ** max(5, int(math.ceil(math.log2(seq_len))))
+
+                 gates = F.pad(gates, (0, next_power_two_seq_len - seq_len))
+                 inputs = F.pad(inputs, (0, next_power_two_seq_len - seq_len))
+
+                 outputs = scan(gates, inputs)
+
+                 outputs = outputs[..., :seq_len]
+                 outputs = rearrange(outputs, 'b d n -> b n d')
+                 return outputs
+
+             scan_fn = accelerate_scan_fn
+         else:
+             scan_fn = default_associative_scan
+
+         # momentum + weight decay - momentum is the new contribution, as most linear RNNs have learned forgetting gates
+
+         next_momentum = TensorDict()
+         updates = TensorDict()
+
+         for param_name, surprise in surprises.items():
+
+             surprise, inverse_pack = pack_one_with_inverse(surprise, 'b n *')
+
+             # derive momentum with associative scan - eq (10)
+
+             momentum = scan_fn(adaptive_momentum, surprise) # momentum is S / surprise in the paper
+
+             # use associative scan again for learned forgetting (weight decay) - eq (13)
+
+             update = scan_fn(1. - decay_factor, momentum) # momentum is S / surprise in the paper
+
+             updates[param_name] = inverse_pack(update)
+             next_momentum[param_name] = inverse_pack(momentum)
+
+         # compute the next weight per batch
+
+         last_update = updates.apply(lambda t: t[:, -1])
+
+         next_state = (curr_weights + last_update, next_momentum)
+
+         return updates, next_state
+
+     def retrieve_memories(
+         self,
+         seq,
+         past_weights: dict[str, Tensor] | None = None,
+     ):
+         chunk_size = self.chunk_size
+         seq_len = seq.shape[1]
+
+         seq = self.retrieve_norm(seq)
+
+         assert seq_len > chunk_size
+
+         seq = seq[:, chunk_size:]
+         curtailed_seq_len = seq.shape[-2]
+
+         next_seq_len = round_up_multiple(curtailed_seq_len + 1, chunk_size)
+
+         padding = next_seq_len - curtailed_seq_len
+
+         seq = pad_at_dim(seq, (0, padding), dim = 1)
+
+         # the parameters of the memory model stores the memories of the key / values
+         # when the MLP has only 1 weight matrix, it is equivalent to `kv` fast weight memories from linear attention literature (recall fetching of memories is q @ (kv)) / schmidhuber's paper
+
+         curr_weights = TensorDict(dict(self.memory_model.named_parameters()))
+
+         if exists(past_weights):
+             past_weights = TensorDict(past_weights)
+             assert past_weights.keys() == curr_weights.keys()
+
+             curr_weights = curr_weights + past_weights
+
+         # sequence Float['b n d'] to queries
+
+         queries = self.to_queries(seq)
+
+         # maybe multihead
+
+         queries = self.split_heads(queries)
+
+         batch = queries.shape[0]
+
+         # fetch values from memory model
+
+         curr_weights = curr_weights.apply(lambda t: rearrange(t, 'b n ... -> (b n) ...'))
+         queries = rearrange(queries, 'b (n c) d -> (b n) c d', c = chunk_size)
+
+         # forward functional call
+
+         values = functional_call(self.memory_model, dict(curr_weights), queries)
+
+         # reconstitute batch dimension
+
+         values = rearrange(values, '(b n) c d -> b (n c) d', b = batch)
+
+         # maybe merge heads and combine
+
+         values = self.merge_heads(values)
+
+         values = self.combine_heads(values)
+
+         # post norm, somehow could not stabilize this without it, not in paper
+
+         values = self.post_rmsnorm(values)
+
+         # restore
+
+         values = pad_at_dim(values, (chunk_size, 0), dim = 1, value = 0.) # todo, used a learned null memory embedding instead of 0s for retrieving from empty neural memory
+         values = values[:, :-padding]
+
+         return values
+
+     def forward(
+         self,
+         seq,
+         store_seq = None,
+         past_state: tuple[dict[str, Tensor], dict[str, Tensor]] | None = None,
+         return_next_memories = False
+     ):
+         batch, seq_len = seq.shape[:2]
+
+         if seq_len <= self.chunk_size:
+             return torch.zeros_like(seq)
+
+         if exists(past_state):
+             past_state = tuple(TensorDict(d) for d in past_state)
+
+         if not exists(past_state):
+             past_state = self.init_weights_and_momentum()
+
+         store_seq = default(store_seq, seq)
+
+         updates, next_memories = self.store_memories(store_seq, past_state)
+
+         past_weights, _ = past_state
+
+         retrieved = self.retrieve_memories(seq, past_weights + updates)
+
+         if not return_next_memories:
+             return retrieved
+
+         return retrieved, next_memories
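For orientation, here is a minimal usage sketch of the NeuralMemory module defined in the file added above. The file's path within the package is not shown in this diff, so the class is simply assumed to be in scope; argument names follow the constructor above and the shapes are illustrative only.

import torch

# assumes the NeuralMemory class from the added file above is in scope
mem = NeuralMemory(dim = 64, chunk_size = 16)    # default memory model is the MemoryAttention above

seq = torch.randn(1, 128, 64)                    # (batch, seq, dim); seq must be longer than chunk_size

retrieved = mem(seq)                             # same shape as seq
retrieved, next_memories = mem(seq, return_next_memories = True)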
@@ -15,7 +15,6 @@ from taylor_series_linear_attention import TaylorSeriesLinearAttn
 
  from titans_pytorch.titans import (
      NeuralMemory,
-     MemoryAttention,
      MemoryMLP
  )
 
File without changes
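The storage step of the NeuralMemory module added above hinges on torch.func: `functional_call` runs the memory model under a supplied parameter dict, `grad` differentiates the memorization loss with respect to those parameters, and `vmap(..., in_dims = (None, 0, 0))` maps that gradient over chunks while sharing the parameters. Below is a self-contained sketch of just that mechanism, with a plain linear layer standing in for MemoryAttention / MemoryMLP and illustrative shapes.

import torch
from torch import nn
from torch.func import functional_call, vmap, grad

memory_model = nn.Linear(16, 16, bias = False)   # stand-in for the memory model
params = dict(memory_model.named_parameters())

def forward_and_loss(params, keys, values):
    pred = functional_call(memory_model, params, keys)
    return (pred - values).pow(2).mean(dim = -1).sum()   # eq (12) style loss: |M(k) - v|^2

# one gradient per chunk, parameters shared across chunks
per_sample_grad_fn = vmap(grad(forward_and_loss), in_dims = (None, 0, 0))

keys = torch.randn(8, 4, 16)     # (num chunks, chunk size, dim)
values = torch.randn(8, 4, 16)

grads = per_sample_grad_fn(params, keys, values)
print(grads['weight'].shape)     # torch.Size([8, 16, 16]) - one "surprise" gradient per chunk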