titans-pytorch 0.0.16__tar.gz → 0.0.17__tar.gz
Sign up to get free protection for your applications and to get access to all the features.
- {titans_pytorch-0.0.16 → titans_pytorch-0.0.17}/PKG-INFO +1 -1
- {titans_pytorch-0.0.16 → titans_pytorch-0.0.17}/pyproject.toml +1 -1
- titans_pytorch-0.0.17/titans_pytorch/titans_attn_memory.py +419 -0
- {titans_pytorch-0.0.16 → titans_pytorch-0.0.17}/.github/workflows/python-publish.yml +0 -0
- {titans_pytorch-0.0.16 → titans_pytorch-0.0.17}/.github/workflows/test.yaml +0 -0
- {titans_pytorch-0.0.16 → titans_pytorch-0.0.17}/.gitignore +0 -0
- {titans_pytorch-0.0.16 → titans_pytorch-0.0.17}/LICENSE +0 -0
- {titans_pytorch-0.0.16 → titans_pytorch-0.0.17}/README.md +0 -0
- {titans_pytorch-0.0.16 → titans_pytorch-0.0.17}/data/README.md +0 -0
- {titans_pytorch-0.0.16 → titans_pytorch-0.0.17}/data/enwik8.gz +0 -0
- {titans_pytorch-0.0.16 → titans_pytorch-0.0.17}/fig1.png +0 -0
- {titans_pytorch-0.0.16 → titans_pytorch-0.0.17}/fig2.png +0 -0
- {titans_pytorch-0.0.16 → titans_pytorch-0.0.17}/requirements.txt +0 -0
- {titans_pytorch-0.0.16 → titans_pytorch-0.0.17}/tests/test_titans.py +0 -0
- {titans_pytorch-0.0.16 → titans_pytorch-0.0.17}/titans_pytorch/__init__.py +0 -0
- {titans_pytorch-0.0.16 → titans_pytorch-0.0.17}/titans_pytorch/associative_scan.py +0 -0
- {titans_pytorch-0.0.16 → titans_pytorch-0.0.17}/titans_pytorch/titans.py +0 -0
- {titans_pytorch-0.0.16 → titans_pytorch-0.0.17}/train.py +0 -0
@@ -0,0 +1,419 @@
|
|
1
|
+
from __future__ import annotations

import math
from functools import partial
from typing import Callable

import torch
from torch import nn, Tensor
import torch.nn.functional as F
from torch.nn import Linear, Module
from torch.func import functional_call, vmap, grad

from tensordict import TensorDict

import einx
from einops import rearrange, pack, unpack
from einops.layers.torch import Rearrange, Reduce

from titans_pytorch.associative_scan import (
    associative_scan,
    binary_operator,
    pad_at_dim
)
|
22
|
+
|
23
|
+
"""
|
24
|
+
ein notation:
|
25
|
+
b - batch
|
26
|
+
n - sequence
|
27
|
+
d - feature dimension
|
28
|
+
c - intra-chunk
|
29
|
+
"""
|
30
|
+
|
31
|
+
# constants
|
32
|
+
|
33
|
+
# `Linear` constructor with the bias term switched off; every projection built in this file goes through it
LinearNoBias = partial(Linear, bias = False)
|
34
|
+
|
35
|
+
# functions
|
36
|
+
|
37
|
+
def exists(v):
    """Return True for any value other than `None` (falsy values such as 0 or '' still count)."""
    is_missing = v is None
    return not is_missing
|
39
|
+
|
40
|
+
def default(v, d):
    """Return `v`, falling back to `d` only when `v` is `None`."""
    if v is None:
        return d

    return v
|
42
|
+
|
43
|
+
def round_down_multiple(seq, mult):
    """Round `seq` down to the nearest multiple of `mult`."""
    whole_multiples = seq // mult
    return whole_multiples * mult
|
45
|
+
|
46
|
+
def round_up_multiple(seq, mult):
    """Round `seq` up to the nearest multiple of `mult`."""
    multiples_needed = math.ceil(seq / mult)
    return multiples_needed * mult
|
48
|
+
|
49
|
+
def pack_one_with_inverse(t, pattern):
    """Pack a single tensor with einops `pack` and return it together with its undo function.

    Returns a tuple `(packed, inverse)`. `inverse(out, inv_pattern = None)` unpacks `out`
    using the shape recorded while packing `t`; it uses `pattern` unless a different
    `inv_pattern` is supplied.
    """
    packed, recorded_shape = pack([t], pattern)

    def inverse(out, inv_pattern = None):
        # fall back to the pattern used for packing when no override is given
        unpack_pattern = pattern if inv_pattern is None else inv_pattern
        pieces = unpack(out, recorded_shape, unpack_pattern)
        return pieces[0]

    return packed, inverse
|
57
|
+
|
58
|
+
# classes
|
59
|
+
|
60
|
+
# improvised attention as memory module
|
61
|
+
# todo - expand if see signal in experiments (update: not seeing it)
|
62
|
+
|
63
|
+
class MemoryAttention(Module):
    """Single causal attention block used as the memory model.

    Holds four square `dim x dim` weight matrices in `self.weights`, in this order:
    query projection, key projection, first value projection, second (output) value projection.
    """

    def __init__(
        self,
        dim
    ):
        super().__init__()

        # order matters: forward unpacks these as (queries, keys, values weight 1, values weight 2)
        num_weights = 4

        self.weights = nn.ParameterList([
            nn.Parameter(torch.randn(dim, dim)) for _ in range(num_weights)
        ])

    def forward(self, x):
        """Causally attend over the second-to-last axis of `x`, then apply SiLU and the output weight."""

        chunk_len = x.shape[-2]
        assert chunk_len > 1, 'chunk size needs to be greater than 1 for using attention as memory'

        to_q, to_k, to_v, to_out = self.weights

        attended = F.scaled_dot_product_attention(
            x @ to_q,
            x @ to_k,
            x @ to_v,
            is_causal = True
        )

        activated = F.silu(attended)
        return activated @ to_out
|
92
|
+
|
93
|
+
# main neural memory
|
94
|
+
|
95
|
+
def default_loss_fn(pred, target):
    """Squared error averaged over the last dimension, then summed over everything else."""
    squared_error = (pred - target).pow(2)
    per_position = squared_error.mean(dim = -1)
    return per_position.sum()
|
97
|
+
|
98
|
+
class NeuralMemory(Module):
    """Neural memory whose "memories" are the weights of an inner memory model.

    Storing: keys and values are projected from the (normed) sequence, split into chunks of
    `chunk_size`, and the gradient of `store_memory_loss_fn(model(keys), values)` with respect
    to the memory model weights is taken per chunk. Those gradients, scaled by a learned step
    size, are accumulated with a learned momentum and a learned decay through two scans,
    producing one weight update per chunk.

    Retrieving: queries are projected from the (normed) sequence and passed through the memory
    model using the per-chunk updated weights.
    """

    def __init__(
        self,
        dim,
        chunk_size = 1,
        dim_head = None,
        heads = 1,
        model: MemoryAttention | None = None,
        store_memory_loss_fn: Callable = default_loss_fn,
        pre_rmsnorm = True,
        post_rmsnorm = True,
        use_accelerated_scan = False,
        default_model_kwargs: dict | None = None
    ):
        """
        Args:
            dim: feature dimension of the incoming sequences.
            chunk_size: number of consecutive positions that share one step size / momentum /
                decay value and form one sample for the per-sample gradient.
            dim_head: feature dimension per head, defaults to `dim`.
            heads: number of memory heads; heads are folded into the batch dimension.
            model: the memory model. Defaults to `MemoryAttention(dim_head, **default_model_kwargs)`.
                Must not have any buffers.
            store_memory_loss_fn: callable `(pred, target) -> scalar loss` used when storing.
            pre_rmsnorm: whether to RMSNorm the sequence before storing and before retrieving.
            post_rmsnorm: whether to RMSNorm the retrieved values.
            use_accelerated_scan: use the scans from the `accelerated_scan` package
                (imported lazily in `store_memories`) instead of `associative_scan`.
            default_model_kwargs: extra keyword arguments for the default memory model.
        """
        super().__init__()

        # `None` instead of a shared mutable `dict()` default

        default_model_kwargs = default(default_model_kwargs, dict())

        # norms

        self.retrieve_norm = nn.RMSNorm(dim) if pre_rmsnorm else nn.Identity()
        self.store_norm = nn.RMSNorm(dim) if pre_rmsnorm else nn.Identity()

        self.post_rmsnorm = nn.RMSNorm(dim) if post_rmsnorm else nn.Identity()

        # maybe multi-headed

        dim_head = default(dim_head, dim)
        dim_inner = dim_head * heads

        self.split_heads = Rearrange('b n (h d) -> (b h) n d', h = heads)
        self.merge_heads = Rearrange('(b h) n d -> b n (h d)', h = heads)
        self.combine_heads = LinearNoBias(dim_inner, dim) if heads > 1 else nn.Identity()

        # memory model

        if not exists(model):
            model = MemoryAttention(dim_head, **default_model_kwargs)

        assert not exists(next(model.buffers(), None)), 'model cannot have buffers for now'

        # the memory is the weights of the model

        self.memory_model = model

        # the chunk size within the paper where adaptive step, momentum, weight decay are shared

        self.chunk_size = chunk_size

        # prepare function for per sample gradients from model above, using torch.func
        # `self.store_memory_loss_fn` is assigned further down; it is only looked up when the function is called

        def forward_and_loss(params, inputs, target):
            pred = functional_call(self.memory_model, params, inputs)
            loss = self.store_memory_loss_fn(pred, target) # simple mse loss in paper - eq (12) - |M(k) - v|²
            return loss

        # params are shared across samples (None), inputs and targets are mapped over their first dimension

        self.per_sample_grad_fn = vmap(grad(forward_and_loss), in_dims = (None, 0, 0))

        # queries for retrieving from the model

        self.to_queries = LinearNoBias(dim, dim_inner)

        # keys and values for storing to the model

        self.to_keys_values = LinearNoBias(dim, dim_inner * 2)
        self.store_memory_loss_fn = store_memory_loss_fn

        # learned adaptive learning rate and momentum
        # each is computed from the mean of a chunk, one value per head
        # todo - explore mlp layerwise learned lr / momentum

        self.to_momentum = nn.Sequential(
            Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size),
            LinearNoBias(dim, heads),
            Rearrange('b n h -> (b h) n 1')
        )

        self.to_adaptive_step = nn.Sequential(
            Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size),
            LinearNoBias(dim, heads),
            Rearrange('b n h -> (b h) n')
        )

        # weight decay factor

        self.to_decay_factor = nn.Sequential(
            Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size),
            LinearNoBias(dim, heads),
            Rearrange('b n h -> (b h) n 1')
        )

        # maybe use accelerated scan

        self.use_accelerated_scan = use_accelerated_scan

    def init_weights_and_momentum(self):
        """Return the initial `(weights, momentum)` state: zeroed copies of the memory model parameters."""
        params = TensorDict(dict(self.memory_model.named_parameters()))

        init_weights = params.clone().zero_()
        init_momentum = params.clone().zero_()

        return init_weights, init_momentum

    def store_memories(
        self,
        seq,
        past_state: tuple[dict[str, Tensor], dict[str, Tensor]]
    ):
        """Compute the per-chunk weight updates that store `seq` into the memory.

        Args:
            seq: sequence to store, indexed as (batch, sequence, feature).
            past_state: `(weights, momentum)` from a previous call or from `init_weights_and_momentum`.

        Returns:
            `(updates, next_state)` where `updates` maps each parameter name to its update at every
            chunk, and `next_state` is `(current weights + update of the last chunk, momentum at every chunk)`.
        """

        seq = self.store_norm(seq)

        # curtail sequence by multiple of the chunk size
        # only a complete chunk of the sequence provides the memory for the next chunk

        seq_len = seq.shape[-2]
        round_down_seq_len = round_down_multiple(seq_len, self.chunk_size)

        seq = seq[:, :round_down_seq_len]

        # curr weights + past weights, in the case that the initial weights are learned

        curr_weights = TensorDict(dict(self.memory_model.named_parameters()))

        past_state = tuple(TensorDict(d) for d in past_state)

        # NOTE(review): the past momentum is unpacked but never read below, so momentum is not carried over between calls
        past_weights, past_momentum = past_state

        curr_weights = curr_weights + past_weights

        # learned step size per chunk and head, ranges from 1. down to exp(-15) (about 3e-7)

        adaptive_lr = (self.to_adaptive_step(seq).sigmoid() * -15).exp()

        # learned momentum and decay gates per chunk and head, each in (0, 1)

        adaptive_momentum = self.to_momentum(seq).sigmoid()
        decay_factor = self.to_decay_factor(seq).sigmoid()

        # keys and values

        keys, values = self.to_keys_values(seq).chunk(2, dim = -1)

        # maybe multi head

        keys, values = map(self.split_heads, (keys, values))

        batch = keys.shape[0]

        # take care of chunking - every chunk becomes its own sample for the per sample gradients

        keys, values = tuple(rearrange(t, 'b (n c) d -> (b n) c d', c = self.chunk_size) for t in (keys, values))

        # gradients of the store loss with respect to the memory weights, one set per chunk

        grads = self.per_sample_grad_fn(dict(curr_weights), keys, values)

        grads = TensorDict(grads)

        # restore batch and sequence dimension

        grads = grads.apply(lambda t: rearrange(t, '(b n) ... -> b n ...', b = batch))

        # multiply gradients with learned adaptive step size, negated for descent

        surprises = grads.apply(lambda t: einx.multiply('b n ..., b n -> b n ...', t, -adaptive_lr))

        # determine scan function

        def default_associative_scan(gates, inputs):
            _, outputs = associative_scan(binary_operator, (gates, inputs))
            return outputs

        if self.use_accelerated_scan:
            from accelerated_scan.triton import scan as triton_scan
            from accelerated_scan.warp import scan as warp_scan

            scan = triton_scan if seq.is_cuda else warp_scan

            def accelerate_scan_fn(gates, inputs):
                gates = gates.expand_as(inputs)
                gates, inputs = tuple(rearrange(t, 'b n d -> b d n') for t in (gates, inputs))

                # right pad the sequence to a power of two, with 2 ** 5 as the minimum length

                seq_len = gates.shape[-1]
                next_power_two_seq_len = 2 ** max(5, int(math.ceil(math.log2(seq_len))))

                gates = F.pad(gates, (0, next_power_two_seq_len - seq_len))
                inputs = F.pad(inputs, (0, next_power_two_seq_len - seq_len))

                outputs = scan(gates, inputs)

                # remove the padding again

                outputs = outputs[..., :seq_len]
                outputs = rearrange(outputs, 'b d n -> b n d')
                return outputs

            scan_fn = accelerate_scan_fn
        else:
            scan_fn = default_associative_scan

        # momentum + weight decay - momentum is the new contribution, as most linear RNNs have learned forgetting gates

        next_momentum = TensorDict()
        updates = TensorDict()

        for param_name, surprise in surprises.items():

            # flatten all parameter dimensions into one so the scan sees (batch, chunks, features)

            surprise, inverse_pack = pack_one_with_inverse(surprise, 'b n *')

            # derive momentum with associative scan - eq (10)

            momentum = scan_fn(adaptive_momentum, surprise) # momentum is S / surprise in the paper

            # use associative scan again for learned forgetting (weight decay) - eq (13)

            update = scan_fn(1. - decay_factor, momentum)

            updates[param_name] = inverse_pack(update)
            next_momentum[param_name] = inverse_pack(momentum)

        # compute the next weight per batch, using the update of the last chunk

        last_update = updates.apply(lambda t: t[:, -1])

        next_state = (curr_weights + last_update, next_momentum)

        return updates, next_state

    def retrieve_memories(
        self,
        seq,
        past_weights: dict[str, Tensor] | None = None,
    ):
        """Retrieve values for `seq` from the memory model, using per-chunk weights.

        The first `chunk_size` positions of the output are zeros; the remaining positions are
        queried in chunks, and the output has the same sequence length as `seq`.

        Args:
            seq: sequence to retrieve for, indexed as (batch, sequence, feature). Must be longer than `chunk_size`.
            past_weights: weights added onto the memory model parameters, keyed like its named parameters.

        NOTE(review): the weights are flattened with 'b n ... -> (b n) ...', which expects leading
        (batch, chunk) dimensions. With `past_weights = None` only the raw parameters are used and
        they do not appear to carry those dimensions - confirm before relying on the default.
        """
        chunk_size = self.chunk_size
        seq_len = seq.shape[1]

        seq = self.retrieve_norm(seq)

        assert seq_len > chunk_size

        # drop the first chunk, then right pad to a multiple of the chunk size
        # the `+ 1` guarantees at least one padded position, so `padding` is never 0 for the slice at the end

        seq = seq[:, chunk_size:]
        curtailed_seq_len = seq.shape[-2]

        next_seq_len = round_up_multiple(curtailed_seq_len + 1, chunk_size)

        padding = next_seq_len - curtailed_seq_len

        seq = pad_at_dim(seq, (0, padding), dim = 1)

        # the parameters of the memory model stores the memories of the key / values
        # when the MLP has only 1 weight matrix, it is equivalent to `kv` fast weight memories from linear attention literature (recall fetching of memories is q @ (kv)) / schmidhuber's paper

        curr_weights = TensorDict(dict(self.memory_model.named_parameters()))

        if exists(past_weights):
            past_weights = TensorDict(past_weights)
            assert past_weights.keys() == curr_weights.keys()

            curr_weights = curr_weights + past_weights

        # sequence Float['b n d'] to queries

        queries = self.to_queries(seq)

        # maybe multihead

        queries = self.split_heads(queries)

        batch = queries.shape[0]

        # fetch values from memory model
        # weights and queries are both flattened to one entry per (batch, chunk)

        curr_weights = curr_weights.apply(lambda t: rearrange(t, 'b n ... -> (b n) ...'))
        queries = rearrange(queries, 'b (n c) d -> (b n) c d', c = chunk_size)

        # forward functional call

        values = functional_call(self.memory_model, dict(curr_weights), queries)

        # reconstitute batch dimension

        values = rearrange(values, '(b n) c d -> b (n c) d', b = batch)

        # maybe merge heads and combine

        values = self.merge_heads(values)

        values = self.combine_heads(values)

        # post norm, somehow could not stabilize this without it, not in paper

        values = self.post_rmsnorm(values)

        # restore - zeros for the dropped first chunk, then strip the right padding

        values = pad_at_dim(values, (chunk_size, 0), dim = 1, value = 0.) # todo, used a learned null memory embedding instead of 0s for retrieving from empty neural memory
        values = values[:, :-padding]

        return values

    def forward(
        self,
        seq,
        store_seq = None,
        past_state: tuple[dict[str, Tensor], dict[str, Tensor]] | None = None,
        return_next_memories = False
    ):
        """Store `store_seq` (defaults to `seq`) into the memory and retrieve for `seq`.

        Args:
            seq: sequence to retrieve for, indexed as (batch, sequence, feature).
            store_seq: sequence to store, defaults to `seq`.
            past_state: optional `(weights, momentum)` state, defaults to `init_weights_and_momentum()`.
            return_next_memories: whether to also return the next `(weights, momentum)` state.

        Returns:
            The retrieved values, or `(retrieved, next_memories)` when `return_next_memories` is True.
            When `seq` is not longer than one chunk, the retrieved values are all zeros and the
            returned state is the incoming (or freshly initialized) state, unchanged.
        """
        seq_len = seq.shape[1]

        if exists(past_state):
            past_state = tuple(TensorDict(d) for d in past_state)
        else:
            past_state = self.init_weights_and_momentum()

        # nothing to store or retrieve until the sequence exceeds one chunk
        # still honor `return_next_memories`, so callers can always unpack two values when they ask for them

        if seq_len <= self.chunk_size:
            retrieved = torch.zeros_like(seq)

            if not return_next_memories:
                return retrieved

            return retrieved, past_state

        store_seq = default(store_seq, seq)

        updates, next_memories = self.store_memories(store_seq, past_state)

        past_weights, _ = past_state

        retrieved = self.retrieve_memories(seq, past_weights + updates)

        if not return_next_memories:
            return retrieved

        return retrieved, next_memories
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|