titans-pytorch 0.0.16__tar.gz → 0.0.18__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {titans_pytorch-0.0.16 → titans_pytorch-0.0.18}/PKG-INFO +1 -1
- {titans_pytorch-0.0.16 → titans_pytorch-0.0.18}/pyproject.toml +1 -1
- {titans_pytorch-0.0.16 → titans_pytorch-0.0.18}/titans_pytorch/titans.py +1 -1
- titans_pytorch-0.0.18/titans_pytorch/titans_attn_memory.py +419 -0
- {titans_pytorch-0.0.16 → titans_pytorch-0.0.18}/train.py +0 -1
- {titans_pytorch-0.0.16 → titans_pytorch-0.0.18}/.github/workflows/python-publish.yml +0 -0
- {titans_pytorch-0.0.16 → titans_pytorch-0.0.18}/.github/workflows/test.yaml +0 -0
- {titans_pytorch-0.0.16 → titans_pytorch-0.0.18}/.gitignore +0 -0
- {titans_pytorch-0.0.16 → titans_pytorch-0.0.18}/LICENSE +0 -0
- {titans_pytorch-0.0.16 → titans_pytorch-0.0.18}/README.md +0 -0
- {titans_pytorch-0.0.16 → titans_pytorch-0.0.18}/data/README.md +0 -0
- {titans_pytorch-0.0.16 → titans_pytorch-0.0.18}/data/enwik8.gz +0 -0
- {titans_pytorch-0.0.16 → titans_pytorch-0.0.18}/fig1.png +0 -0
- {titans_pytorch-0.0.16 → titans_pytorch-0.0.18}/fig2.png +0 -0
- {titans_pytorch-0.0.16 → titans_pytorch-0.0.18}/requirements.txt +0 -0
- {titans_pytorch-0.0.16 → titans_pytorch-0.0.18}/tests/test_titans.py +0 -0
- {titans_pytorch-0.0.16 → titans_pytorch-0.0.18}/titans_pytorch/__init__.py +0 -0
- {titans_pytorch-0.0.16 → titans_pytorch-0.0.18}/titans_pytorch/associative_scan.py +0 -0
@@ -269,7 +269,7 @@ class NeuralMemory(Module):
|
|
269
269
|
gates = F.pad(gates, (0, next_power_two_seq_len - seq_len))
|
270
270
|
inputs = F.pad(inputs, (0, next_power_two_seq_len - seq_len))
|
271
271
|
|
272
|
-
outputs = scan(gates, inputs)
|
272
|
+
outputs = scan(gates.contiguous(), inputs.contiguous())
|
273
273
|
|
274
274
|
outputs = outputs[..., :seq_len]
|
275
275
|
outputs = rearrange(outputs, 'b d n -> b n d')
|
@@ -0,0 +1,419 @@
|
|
1
|
+
from __future__ import annotations
import math
from functools import partial
from typing import Callable

import torch
from torch import nn, Tensor
import torch.nn.functional as F
from torch.nn import Linear, Module
from torch.func import functional_call, vmap, grad

from tensordict import TensorDict

from titans_pytorch.associative_scan import (
    associative_scan,
    binary_operator,
    pad_at_dim
)

import einx
from einops import rearrange, pack, unpack
from einops.layers.torch import Rearrange, Reduce
|
22
|
+
|
23
|
+
"""
|
24
|
+
ein notation:
|
25
|
+
b - batch
|
26
|
+
n - sequence
|
27
|
+
d - feature dimension
|
28
|
+
c - intra-chunk
|
29
|
+
"""
|
30
|
+
|
31
|
+
# constants
|
32
|
+
|
33
|
+
# factory for bias-free nn.Linear layers; called exactly like nn.Linear(in_features, out_features)
LinearNoBias = partial(Linear, bias=False)
|
34
|
+
|
35
|
+
# functions
|
36
|
+
|
37
|
+
def exists(v):
    """Return True unless *v* is None (falsy values such as 0 or '' still exist)."""
    return not (v is None)
|
39
|
+
|
40
|
+
def default(v, d):
    """Return *v*, or the fallback *d* when *v* is None."""
    if v is None:
        return d
    return v
|
42
|
+
|
43
|
+
def round_down_multiple(seq, mult):
    """Round *seq* down to the nearest multiple of *mult* (floor semantics)."""
    quotient, _ = divmod(seq, mult)
    return quotient * mult
|
45
|
+
|
46
|
+
def round_up_multiple(seq, mult):
    """Round *seq* up to the nearest multiple of *mult*.

    For integer arguments the ceiling division is done exactly in integer
    arithmetic; going through `seq / mult` (float) silently loses precision
    once values exceed 2**53. Non-integer inputs keep the original
    `math.ceil`-based behavior.
    """
    if isinstance(seq, int) and isinstance(mult, int):
        # -(-a // b) is exact integer ceil division
        return -(-seq // mult) * mult

    return math.ceil(seq / mult) * mult
|
48
|
+
|
49
|
+
def pack_one_with_inverse(t, pattern):
    """Pack a single tensor with einops `pack` and return a matching unpacker.

    Returns `(packed, inverse)`. `inverse(out, inv_pattern=None)` unpacks `out`
    using the shape recorded while packing `t`, with `inv_pattern` falling back
    to `pattern` when not given.
    """
    packed, packed_shapes = pack([t], pattern)

    def inverse(out, inv_pattern = None):
        unpack_pattern = pattern if inv_pattern is None else inv_pattern
        unpacked = unpack(out, packed_shapes, unpack_pattern)
        return unpacked[0]

    return packed, inverse
|
57
|
+
|
58
|
+
# classes
|
59
|
+
|
60
|
+
# improvised attention as memory module
|
61
|
+
# todo - expand if experiments show a signal (update: no signal seen so far)
|
62
|
+
|
63
|
+
class MemoryAttention(Module):
    """Causal single-head attention block whose weights serve as the neural memory.

    Holds four (dim, dim) matrices, drawn from a standard normal in this order:
    query projection, key projection, first value projection, and output
    projection applied after SiLU.
    """

    def __init__(
        self,
        dim
    ):
        super().__init__()
        # drawn in order: queries, keys, values weight 1, values weight 2
        self.weights = nn.ParameterList([nn.Parameter(torch.randn(dim, dim)) for _ in range(4)])

    def forward(self, x):
        # causal attention over a single position is degenerate, so at least two are required
        assert x.shape[-2] > 1, 'chunk size needs to be greater than 1 for using attention as memory'

        w_query, w_key, w_value_in, w_value_out = self.weights

        queries, keys, values = (x @ w for w in (w_query, w_key, w_value_in))

        attended = F.scaled_dot_product_attention(queries, keys, values, is_causal = True)

        activated = F.silu(attended)
        return activated @ w_value_out
|
92
|
+
|
93
|
+
# main neural memory
|
94
|
+
|
95
|
+
def default_loss_fn(pred, target):
    """Squared error averaged over the last (feature) dim, then summed over all other entries."""
    squared_error = (pred - target) ** 2
    per_position = squared_error.mean(dim = -1)
    return per_position.sum()
|
97
|
+
|
98
|
+
class NeuralMemory(Module):
    """Neural memory whose stored "memories" are updates to the weights of `memory_model`.

    Storing: the (normed) sequence is projected to keys / values and split into
    chunks of `chunk_size`. Per-chunk gradients of `store_memory_loss_fn(model(keys), values)`
    are scaled by a learned step size. They are then accumulated with a learned momentum
    and a learned weight decay, computed as two associative scans over the chunks.

    Retrieving: the sequence is projected to queries and run through `memory_model`
    with the weights updated by the preceding chunks. The first chunk reads zeros.

    The default memory model is `MemoryAttention(dim_head)`. It asserts on a chunk
    length of 1, so it needs `chunk_size > 1`.
    """

    def __init__(
        self,
        dim,
        chunk_size = 1,
        dim_head = None,
        heads = 1,
        model: MemoryAttention | None = None,
        store_memory_loss_fn: Callable = default_loss_fn,
        pre_rmsnorm = True,
        post_rmsnorm = True,
        use_accelerated_scan = False,
        default_model_kwargs: dict | None = None
    ):
        """
        Args:
            dim: feature dimension of the input sequence.
            chunk_size: granularity at which memories are stored / retrieved. Every token in a
                chunk shares one adaptive step size, momentum and decay value.
            dim_head: per-head dimension (defaults to `dim`).
            heads: number of memory heads. Heads are folded into the batch dimension.
            model: module whose parameters act as the memory (defaults to MemoryAttention).
                Must not have buffers.
            store_memory_loss_fn: loss(pred, target) whose gradient drives memory updates.
            pre_rmsnorm: RMSNorm the sequence before storing and before retrieving.
            post_rmsnorm: RMSNorm the retrieved values.
            use_accelerated_scan: use the `accelerated_scan` package kernels instead of
                the built-in associative scan.
            default_model_kwargs: extra kwargs for the default MemoryAttention model.
                None means no extra kwargs.
        """
        super().__init__()

        # None instead of a shared mutable `dict()` default
        default_model_kwargs = default(default_model_kwargs, dict())

        # norms

        self.retrieve_norm = nn.RMSNorm(dim) if pre_rmsnorm else nn.Identity()
        self.store_norm = nn.RMSNorm(dim) if pre_rmsnorm else nn.Identity()

        self.post_rmsnorm = nn.RMSNorm(dim) if post_rmsnorm else nn.Identity()

        # maybe multi-headed

        dim_head = default(dim_head, dim)
        dim_inner = dim_head * heads

        self.split_heads = Rearrange('b n (h d) -> (b h) n d', h = heads)
        self.merge_heads = Rearrange('(b h) n d -> b n (h d)', h = heads)
        self.combine_heads = LinearNoBias(dim_inner, dim) if heads > 1 else nn.Identity()

        # memory model

        if not exists(model):
            model = MemoryAttention(dim_head, **default_model_kwargs)

        assert not exists(next(model.buffers(), None)), 'model cannot have buffers for now'

        # the memory is the weights of the model

        self.memory_model = model

        # the chunk size within the paper where adaptive step, momentum, weight decay are shared

        self.chunk_size = chunk_size

        # prepare function for per sample gradients from model above, using torch.func

        def forward_and_loss(params, inputs, target):
            pred = functional_call(self.memory_model, params, inputs)
            loss = self.store_memory_loss_fn(pred, target) # simple mse loss in paper - eq (12) - |M(k) - v|²
            return loss

        # params are shared (in_dims None); inputs / targets are mapped over their first dim (one chunk each)
        self.per_sample_grad_fn = vmap(grad(forward_and_loss), in_dims = (None, 0, 0))

        # queries for retrieving from the model

        self.to_queries = LinearNoBias(dim, dim_inner)

        # keys and values for storing to the model

        self.to_keys_values = LinearNoBias(dim, dim_inner * 2)
        self.store_memory_loss_fn = store_memory_loss_fn

        # learned adaptive learning rate and momentum, one value per chunk and head
        # todo - explore mlp layerwise learned lr / momentum

        self.to_momentum = nn.Sequential(
            Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size),
            LinearNoBias(dim, heads),
            Rearrange('b n h -> (b h) n 1')
        )

        self.to_adaptive_step = nn.Sequential(
            Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size),
            LinearNoBias(dim, heads),
            Rearrange('b n h -> (b h) n')
        )

        # weight decay factor

        self.to_decay_factor = nn.Sequential(
            Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size),
            LinearNoBias(dim, heads),
            Rearrange('b n h -> (b h) n 1')
        )

        # maybe use accelerated scan

        self.use_accelerated_scan = use_accelerated_scan

    def init_weights_and_momentum(self):
        """Return zero-filled (weights, momentum) TensorDicts keyed like memory_model's parameters.

        The weights in a memory state are deltas on top of memory_model's own parameters,
        so the initial state is all zeros.
        """
        params = TensorDict(dict(self.memory_model.named_parameters()))

        init_weights = params.clone().zero_()
        init_momentum = params.clone().zero_()

        return init_weights, init_momentum

    def store_memories(
        self,
        seq,
        past_state: tuple[dict[str, Tensor], dict[str, Tensor]]
    ):
        """Compute per-chunk weight updates for `seq`.

        Returns `(updates, next_state)`:
          - `updates` holds, per parameter, the cumulative update after each chunk, with
            leading (batch * heads, num_chunks) dims.
          - `next_state` is `(weights, momentum)`, where the weights are deltas on top of
            memory_model's parameters.

        Only complete chunks are stored; trailing tokens that do not fill a chunk are dropped.
        """
        seq = self.store_norm(seq)

        # curtail sequence by multiple of the chunk size
        # only a complete chunk of the sequence provides the memory for the next chunk

        round_down_seq_len = round_down_multiple(seq.shape[-2], self.chunk_size)

        seq = seq[:, :round_down_seq_len]

        # curr weights + past weights, in the case that the initial weights are learned

        curr_weights = TensorDict(dict(self.memory_model.named_parameters()))

        past_state = tuple(TensorDict(d) for d in past_state)
        past_weights, past_momentum = past_state

        # NOTE(review): past_momentum is not fed into the momentum scan below, so momentum restarts
        # from zero on every call - confirm whether it should seed the scan

        curr_weights = curr_weights + past_weights

        # learned per-chunk, per-head step size, momentum and decay

        adaptive_lr = (self.to_adaptive_step(seq).sigmoid() * -15).exp() # lr in (exp(-15), 1), i.e. roughly (3e-7, 1)

        adaptive_momentum = self.to_momentum(seq).sigmoid()
        decay_factor = self.to_decay_factor(seq).sigmoid()

        # keys and values

        keys, values = self.to_keys_values(seq).chunk(2, dim = -1)

        # maybe multi head

        keys, values = map(self.split_heads, (keys, values))

        batch = keys.shape[0]

        # take care of chunking - each chunk becomes one sample for the per-sample gradient

        keys, values = tuple(rearrange(t, 'b (n c) d -> (b n) c d', c = self.chunk_size) for t in (keys, values))

        # get grads and extra auxiliary loss (for backwarding through qkv projection in base neural memory module)

        grads = self.per_sample_grad_fn(dict(curr_weights), keys, values)

        grads = TensorDict(grads)

        # restore batch and sequence dimension

        grads = grads.apply(lambda t: rearrange(t, '(b n) ... -> b n ...', b = batch))

        # multiply gradients with learned adaptive step size

        surprises = grads.apply(lambda t: einx.multiply('b n ..., b n -> b n ...', t, -adaptive_lr))

        # determine scan function

        def default_associative_scan(gates, inputs):
            _, outputs = associative_scan(binary_operator, (gates, inputs))
            return outputs

        if self.use_accelerated_scan:
            from accelerated_scan.triton import scan as triton_scan
            from accelerated_scan.warp import scan as warp_scan

            scan = triton_scan if seq.is_cuda else warp_scan

            def accelerate_scan_fn(gates, inputs):
                gates = gates.expand_as(inputs)
                gates, inputs = tuple(rearrange(t, 'b n d -> b d n') for t in (gates, inputs))

                seq_len = gates.shape[-1]
                next_power_two_seq_len = 2 ** max(5, int(math.ceil(math.log2(seq_len))))

                gates = F.pad(gates, (0, next_power_two_seq_len - seq_len))
                inputs = F.pad(inputs, (0, next_power_two_seq_len - seq_len))

                # expand_as + rearrange yield non-contiguous views; hand the scan kernel contiguous
                # tensors (the same change this release makes to the scan call in titans.py)
                outputs = scan(gates.contiguous(), inputs.contiguous())

                outputs = outputs[..., :seq_len]
                outputs = rearrange(outputs, 'b d n -> b n d')
                return outputs

            scan_fn = accelerate_scan_fn
        else:
            scan_fn = default_associative_scan

        # momentum + weight decay - momentum is the new contribution, as most linear RNNs have learned forgetting gates

        next_momentum = TensorDict()
        updates = TensorDict()

        for param_name, surprise in surprises.items():

            surprise, inverse_pack = pack_one_with_inverse(surprise, 'b n *')

            # derive momentum with associative scan - eq (10)

            momentum = scan_fn(adaptive_momentum, surprise) # momentum is S / surprise in the paper

            # use associative scan again for learned forgetting (weight decay) - eq (13)

            update = scan_fn(1. - decay_factor, momentum) # decayed running sum of the momentum

            updates[param_name] = inverse_pack(update)
            next_momentum[param_name] = inverse_pack(momentum)

        # compute the next weight per batch

        last_update = updates.apply(lambda t: t[:, -1])

        # state weights are deltas on top of memory_model's parameters: they start at zero
        # (init_weights_and_momentum), and both store_memories and retrieve_memories add the
        # parameters back. Adding curr_weights (which already contains the parameters) here
        # would count the parameters twice on the next call.
        # NOTE(review): last_update keeps the (batch * heads) dim but drops the chunk dim;
        # confirm shapes when feeding next_state back in as past_state
        next_state = (past_weights + last_update, next_momentum)

        return updates, next_state

    def retrieve_memories(
        self,
        seq,
        past_weights: dict[str, Tensor] | None = None,
    ):
        """Read from memory for every position of `seq`.

        Each position reads memory updated only by complete chunks strictly before its own
        chunk. The first `chunk_size` positions therefore get zeros. The output has the same
        sequence length as `seq`.

        `past_weights` must carry leading (batch * heads, num_chunks) dims, e.g. the per-chunk
        `updates` from `store_memories`. They are flattened with `'b n ... -> (b n) ...'`,
        so the None default only works if the parameters themselves happen to fit that pattern.
        """
        chunk_size = self.chunk_size
        seq_len = seq.shape[1]

        seq = self.retrieve_norm(seq)

        assert seq_len > chunk_size

        # the first chunk has no earlier memory to read; the rest is right-padded so the query
        # chunks line up with the stored chunks (padding is always >= 1)

        seq = seq[:, chunk_size:]
        curtailed_seq_len = seq.shape[-2]

        next_seq_len = round_up_multiple(curtailed_seq_len + 1, chunk_size)

        padding = next_seq_len - curtailed_seq_len

        seq = pad_at_dim(seq, (0, padding), dim = 1)

        # the parameters of the memory model stores the memories of the key / values
        # when the MLP has only 1 weight matrix, it is equivalent to `kv` fast weight memories from linear attention literature (recall fetching of memories is q @ (kv)) / schmidhuber's paper

        curr_weights = TensorDict(dict(self.memory_model.named_parameters()))

        if exists(past_weights):
            past_weights = TensorDict(past_weights)
            assert past_weights.keys() == curr_weights.keys()

            curr_weights = curr_weights + past_weights

        # sequence Float['b n d'] to queries

        queries = self.to_queries(seq)

        # maybe multihead

        queries = self.split_heads(queries)

        batch = queries.shape[0]

        # fetch values from memory model - query chunk i reads the weights after stored chunk i

        curr_weights = curr_weights.apply(lambda t: rearrange(t, 'b n ... -> (b n) ...'))
        queries = rearrange(queries, 'b (n c) d -> (b n) c d', c = chunk_size)

        # forward functional call

        values = functional_call(self.memory_model, dict(curr_weights), queries)

        # reconstitute batch dimension

        values = rearrange(values, '(b n) c d -> b (n c) d', b = batch)

        # maybe merge heads and combine

        values = self.merge_heads(values)

        values = self.combine_heads(values)

        # post norm, somehow could not stabilize this without it, not in paper

        values = self.post_rmsnorm(values)

        # restore: zeros for the first chunk, then drop the right padding

        values = pad_at_dim(values, (chunk_size, 0), dim = 1, value = 0.) # todo, used a learned null memory embedding instead of 0s for retrieving from empty neural memory
        values = values[:, :-padding]

        return values

    def forward(
        self,
        seq,
        store_seq = None,
        past_state: tuple[dict[str, Tensor], dict[str, Tensor]] | None = None,
        return_next_memories = False
    ):
        """Store `store_seq` (defaults to `seq`) into memory, then retrieve for `seq`.

        Returns the retrieved values. When `return_next_memories` is True, returns
        `(retrieved, next_memories)` instead, where `next_memories` is the
        `(weights, momentum)` state.
        """
        seq_len = seq.shape[1]

        # a sequence no longer than one chunk has no earlier chunk to read from
        if seq_len <= self.chunk_size and not return_next_memories:
            return torch.zeros_like(seq)

        if exists(past_state):
            past_state = tuple(TensorDict(d) for d in past_state)
        else:
            past_state = self.init_weights_and_momentum()

        if seq_len <= self.chunk_size:
            # keep the (retrieved, next_memories) contract; nothing is stored on this path,
            # so the incoming (or freshly initialized) state is passed through unchanged
            return torch.zeros_like(seq), past_state

        store_seq = default(store_seq, seq)

        updates, next_memories = self.store_memories(store_seq, past_state)

        past_weights, _ = past_state

        # retrieval reads the incoming weight deltas plus the per-chunk updates just computed
        retrieved = self.retrieve_memories(seq, past_weights + updates)

        if not return_next_memories:
            return retrieved

        return retrieved, next_memories
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|