titans_pytorch-0.0.14-py3-none-any.whl
- titans_pytorch/__init__.py +3 -0
- titans_pytorch/associative_scan.py +90 -0
- titans_pytorch/titans.py +408 -0
- titans_pytorch-0.0.14.dist-info/METADATA +111 -0
- titans_pytorch-0.0.14.dist-info/RECORD +7 -0
- titans_pytorch-0.0.14.dist-info/WHEEL +4 -0
- titans_pytorch-0.0.14.dist-info/licenses/LICENSE +21 -0
titans_pytorch/associative_scan.py
ADDED
@@ -0,0 +1,90 @@
+from __future__ import annotations
+from typing import Callable
+
+import torch
+from torch import Tensor
+import torch.nn.functional as F
+
+# taken from S5-pytorch repository
+# https://github.com/i404788/s5-pytorch/blob/74e2fdae00b915a62c914bf3615c0b8a4279eb84/s5/jax_compat.py#L51-L134
+
+# helper functions
+
+def pad_at_dim(t, pad, dim = -1, value = 0.):
+    dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
+    zeros = ((0, 0) * dims_from_right)
+    return F.pad(t, (*zeros, *pad), value = value)
+
+# the associative operator that is needed
+# combines (a_i, kv_i), (a_j, kv_j) -> (a_j * a_i, a_j * kv_i + kv_j), one step of a first order linear recurrence
+
+@torch.jit.script
+def binary_operator(
+    a: tuple[Tensor, Tensor],
+    b: tuple[Tensor, Tensor]
+):
+    a_i, kv_i = a
+    a_j, kv_j = b
+    return a_j * a_i, torch.addcmul(kv_j, a_j, kv_i)
+
+# Pytorch impl. of jax.lax.associative_scan
+# made specifically for axis of 1 (sequence of tokens for autoregressive modeling)
+
+def associative_scan(
+    operator: Callable,
+    elems: tuple[Tensor, Tensor]
+):
+    num_elems = int(elems[0].shape[1])
+
+    if not all(int(elem.shape[1]) == num_elems for elem in elems[1:]):
+        raise ValueError('Array inputs to associative_scan must have the same '
+                         'dimension along axis 1. (saw: {})'
+                         .format([elem.shape for elem in elems]))
+
+    def _scan(elems):
+        """Perform scan on `elems`."""
+        num_elems = elems[0].shape[1]
+
+        if num_elems < 2:
+            return elems
+
+        # Combine adjacent pairs of elements.
+
+        reduced_elems = operator(
+            [elem[:, :-1:2] for elem in elems],
+            [elem[:, 1::2] for elem in elems])
+
+        # Recursively compute scan for partially reduced tensors.
+
+        odd_elems = _scan(reduced_elems)
+
+        if num_elems % 2 == 0:
+            even_elems = operator(
+                [e[:, :-1] for e in odd_elems],
+                [e[:, 2::2] for e in elems])
+        else:
+            even_elems = operator(
+                odd_elems,
+                [e[:, 2::2] for e in elems])
+
+        # The first element of a scan is the same as the first element
+        # of the original `elems`.
+
+        even_elems = [
+            torch.cat([elem[:, :1], result], dim = 1)
+            for (elem, result) in zip(elems, even_elems)]
+
+        return list(map(_interleave, even_elems, odd_elems))
+
+    return _scan(elems)
+
+def _interleave(a, b):
+    a_axis_len, b_axis_len = a.shape[1], b.shape[1]
+    output_axis_len = a_axis_len + b_axis_len
+
+    if a_axis_len == (b_axis_len + 1):
+        b = pad_at_dim(b, (0, 1), dim = 1)
+
+    stacked = torch.stack([a, b], dim = 2)
+    interleaved = torch.flatten(stacked, start_dim = 1, end_dim = 2)
+
+    return interleaved[:, :output_axis_len]
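The scan above mirrors `jax.lax.associative_scan` along axis 1, with `binary_operator` encoding one step of the first order linear recurrence `h_t = a_t * h_{t-1} + kv_t`. A minimal sketch of sanity-checking it against a sequential loop (not part of the package; shapes and tolerance are arbitrary choices):

```python
import torch
from titans_pytorch.associative_scan import associative_scan, binary_operator

batch, seq_len, dim = 2, 8, 4
gates = torch.rand(batch, seq_len, dim)    # decay terms a_t in (0, 1)
inputs = torch.randn(batch, seq_len, dim)  # inputs kv_t

# parallel inclusive scan over the sequence axis, as used by NeuralMemory below
_, parallel_out = associative_scan(binary_operator, (gates, inputs))

# sequential reference: h_t = a_t * h_{t-1} + kv_t, with h_0 = 0
h = torch.zeros(batch, dim)
outs = []
for t in range(seq_len):
    h = gates[:, t] * h + inputs[:, t]
    outs.append(h)
sequential_out = torch.stack(outs, dim = 1)

assert torch.allclose(parallel_out, sequential_out, atol = 1e-5)
```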
titans_pytorch/titans.py
ADDED
@@ -0,0 +1,408 @@
+from __future__ import annotations
+import math
+from functools import partial
+from typing import Callable
+
+import torch
+from torch import nn, Tensor
+import torch.nn.functional as F
+from torch.nn import Linear, Module
+from torch.func import functional_call, vmap, grad_and_value
+
+from tensordict import TensorDict
+
+from titans_pytorch.associative_scan import (
+    associative_scan,
+    binary_operator,
+    pad_at_dim
+)
+
+import einx
+from einops import rearrange, pack, unpack
+from einops.layers.torch import Rearrange, Reduce
+
+"""
+ein notation:
+b - batch
+n - sequence
+d - feature dimension
+c - intra-chunk
+"""
+
+# constants
+
+LinearNoBias = partial(Linear, bias = False)
+
+# functions
+
+def exists(v):
+    return v is not None
+
+def default(v, d):
+    return v if exists(v) else d
+
+def round_down_multiple(seq, mult):
+    return seq // mult * mult
+
+def round_up_multiple(seq, mult):
+    return math.ceil(seq / mult) * mult
+
+def pack_one_with_inverse(t, pattern):
+    packed, packed_shape = pack([t], pattern)
+
+    def inverse(out, inv_pattern = None):
+        inv_pattern = default(inv_pattern, pattern)
+        return unpack(out, packed_shape, inv_pattern)[0]
+
+    return packed, inverse
+
+# classes
+
+class MLP(Module):
+    def __init__(
+        self,
+        dim,
+        depth
+    ):
+        super().__init__()
+        self.weights = nn.ParameterList([nn.Parameter(torch.randn(dim, dim)) for _ in range(depth)])
+
+    def forward(
+        self,
+        x
+    ):
+        for ind, weight in enumerate(self.weights):
+            is_first = ind == 0
+
+            if not is_first:
+                x = F.silu(x)
+
+            x = x @ weight
+
+        return x
+
+# main neural memory
+
+def default_loss_fn(pred, target):
+    return (pred - target).pow(2).mean(dim = -1).sum()
+
+class NeuralMemory(Module):
+    def __init__(
+        self,
+        dim,
+        chunk_size = 1,
+        dim_head = None,
+        heads = 1,
+        model: Module | None = None,
+        store_memory_loss_fn: Callable = default_loss_fn,
+        pre_rmsnorm = True,
+        post_rmsnorm = True,
+        use_accelerated_scan = False,
+        default_mlp_kwargs: dict = dict(
+            depth = 4
+        )
+    ):
+        super().__init__()
+
+        # norms
+
+        self.retrieve_norm = nn.RMSNorm(dim) if pre_rmsnorm else nn.Identity()
+        self.store_norm = nn.RMSNorm(dim) if pre_rmsnorm else nn.Identity()
+
+        self.post_rmsnorm = nn.RMSNorm(dim) if post_rmsnorm else nn.Identity()
+
+        # maybe multi-headed
+
+        dim_head = default(dim_head, dim)
+        dim_inner = dim_head * heads
+
+        self.split_heads = Rearrange('b n (h d) -> (b h) n d', h = heads)
+        self.merge_heads = Rearrange('(b h) n d -> b n (h d)', h = heads)
+        self.combine_heads = LinearNoBias(dim_inner, dim) if heads > 1 else nn.Identity()
+
+        # memory mlp
+
+        if not exists(model):
+            model = MLP(dim_head, **default_mlp_kwargs)
+
+        assert not exists(next(model.buffers(), None)), 'model cannot have buffers for now'
+
+        # the memory is the weights of the model
+
+        self.memory_model = model
+
+        # the chunk size within the paper where adaptive step, momentum, weight decay are shared
+
+        self.chunk_size = chunk_size
+
+        # prepare function for per sample gradients from model above, using torch.func
+
+        def forward_and_loss(params, inputs, target):
+            pred = functional_call(self.memory_model, params, inputs)
+            loss = self.store_memory_loss_fn(pred, target) # simple mse loss in paper - eq (12) - |M(k) - v|²
+            return loss
+
+        self.per_sample_grad_and_value_fn = vmap(grad_and_value(forward_and_loss), in_dims = (None, 0, 0))
+
+        # queries for retrieving from the model
+
+        self.to_queries = LinearNoBias(dim, dim_inner)
+
+        # keys and values for storing to the model
+
+        self.to_keys_values = LinearNoBias(dim, dim_inner * 2)
+        self.store_memory_loss_fn = store_memory_loss_fn
+
+        # learned adaptive learning rate and momentum
+        # todo - explore mlp layerwise learned lr / momentum
+
+        self.to_momentum = nn.Sequential(
+            Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size),
+            LinearNoBias(dim, heads),
+            Rearrange('b n h -> (b h) n 1')
+        )
+
+        self.to_adaptive_step = nn.Sequential(
+            Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size),
+            LinearNoBias(dim, heads),
+            Rearrange('b n h -> (b h) n')
+        )
+
+        # weight decay factor
+
+        self.to_decay_factor = nn.Sequential(
+            Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size),
+            LinearNoBias(dim, heads),
+            Rearrange('b n h -> (b h) n 1')
+        )
+
+        # maybe use accelerated scan
+
+        self.use_accelerated_scan = use_accelerated_scan
+
+    def init_weights_and_momentum(self):
+        params = TensorDict(dict(self.memory_model.named_parameters()))
+
+        init_weights = params.clone().zero_()
+        init_momentum = params.clone().zero_()
+
+        return init_weights, init_momentum
+
+    def store_memories(
+        self,
+        seq,
+        past_state: tuple[dict[str, Tensor], dict[str, Tensor]]
+    ):
+
+        seq = self.store_norm(seq)
+
+        # curtail sequence to a multiple of the chunk size
+        # only a complete chunk of the sequence provides the memory for the next chunk
+
+        seq_len, chunk_size = seq.shape[-2], self.chunk_size
+        round_down_seq_len = round_down_multiple(seq_len, self.chunk_size)
+
+        seq = seq[:, :round_down_seq_len]
+
+        # curr weights + past weights, in the case that the initial weights are learned
+
+        curr_weights = TensorDict(dict(self.memory_model.named_parameters()))
+
+        past_state = tuple(TensorDict(d) for d in past_state)
+        past_weights, past_momentum = past_state
+
+        curr_weights = curr_weights + past_weights
+
+        # learned per-chunk hyperparameters - adaptive lr ranges from ~1. down to ~3e-7 (exp of sigmoid * -15)
+
+        adaptive_lr = (self.to_adaptive_step(seq).sigmoid() * -15).exp()
+
+        adaptive_momentum = self.to_momentum(seq).sigmoid()
+        decay_factor = self.to_decay_factor(seq).sigmoid()
+
+        # keys and values
+
+        keys, values = self.to_keys_values(seq).chunk(2, dim = -1)
+
+        # maybe multi head
+
+        keys, values = map(self.split_heads, (keys, values))
+
+        batch = keys.shape[0]
+
+        # take care of chunking
+
+        keys, values = tuple(rearrange(t, 'b (n c) d -> (b n) c d', c = self.chunk_size) for t in (keys, values))
+
+        # get grads and extra auxiliary loss (for backwarding through qkv projection in base neural memory module)
+
+        grads, aux_store_loss = self.per_sample_grad_and_value_fn(dict(curr_weights), keys, values)
+
+        grads = TensorDict(grads)
+
+        # restore batch and sequence dimension
+
+        grads = grads.apply(lambda t: rearrange(t, '(b n) ... -> b n ...', b = batch))
+
+        # multiply gradients with learned adaptive step size
+
+        surprises = grads.apply(lambda t: einx.multiply('b n ..., b n -> b n ...', t, -adaptive_lr))
+
+        # determine scan function
+
+        def default_associative_scan(gates, inputs):
+            _, outputs = associative_scan(binary_operator, (gates, inputs))
+            return outputs
+
+        if self.use_accelerated_scan:
+            from accelerated_scan.triton import scan as triton_scan
+            from accelerated_scan.warp import scan as warp_scan
+
+            scan = triton_scan if seq.is_cuda else warp_scan
+
+            def accelerate_scan_fn(gates, inputs):
+                gates = gates.expand_as(inputs)
+                gates, inputs = tuple(rearrange(t, 'b n d -> b d n') for t in (gates, inputs))
+
+                # accelerated-scan expects a power-of-two sequence length, so pad up then slice back
+
+                seq_len = gates.shape[-1]
+                next_power_two_seq_len = 2 ** max(5, int(math.ceil(math.log2(seq_len))))
+
+                gates = F.pad(gates, (0, next_power_two_seq_len - seq_len))
+                inputs = F.pad(inputs, (0, next_power_two_seq_len - seq_len))
+
+                outputs = scan(gates, inputs)
+
+                outputs = outputs[..., :seq_len]
+                outputs = rearrange(outputs, 'b d n -> b n d')
+                return outputs
+
+            scan_fn = accelerate_scan_fn
+        else:
+            scan_fn = default_associative_scan
+
+        # momentum + weight decay - momentum is the new contribution, as most linear RNNs have learned forgetting gates
+
+        next_momentum = TensorDict()
+        updates = TensorDict()
+
+        for param_name, surprise in surprises.items():
+
+            surprise, inverse_pack = pack_one_with_inverse(surprise, 'b n *')
+
+            # derive momentum with associative scan - eq (10) - momentum is S / surprise in the paper
+
+            momentum = scan_fn(adaptive_momentum, surprise)
+
+            # use associative scan again for learned forgetting (weight decay) - eq (13)
+
+            update = scan_fn(1. - decay_factor, momentum)
+
+            updates[param_name] = inverse_pack(update)
+            next_momentum[param_name] = inverse_pack(momentum)
+
+        # compute the next weight per batch
+
+        last_update = updates.apply(lambda t: t[:, -1])
+
+        next_state = (curr_weights + last_update, next_momentum)
+
+        return updates, next_state, aux_store_loss.mean() / chunk_size
+
+    def retrieve_memories(
+        self,
+        seq,
+        past_weights: dict[str, Tensor] | None = None,
+    ):
+        chunk_size = self.chunk_size
+        seq_len = seq.shape[1]
+
+        seq = self.retrieve_norm(seq)
+
+        assert seq_len >= chunk_size
+
+        seq = seq[:, (chunk_size - 1):]
+        curtailed_seq_len = seq.shape[-2]
+
+        next_seq_len = round_up_multiple(curtailed_seq_len, chunk_size)
+
+        padding = next_seq_len - curtailed_seq_len
+
+        seq = pad_at_dim(seq, (0, padding), dim = 1)
+
+        # the parameters of the memory model stores the memories of the key / values
+        # when the MLP has only 1 weight matrix, it is equivalent to `kv` fast weight memories from linear attention literature (recall fetching of memories is q @ (kv)) / schmidhuber's paper
+
+        curr_weights = TensorDict(dict(self.memory_model.named_parameters()))
+
+        if exists(past_weights):
+            past_weights = TensorDict(past_weights)
+            assert past_weights.keys() == curr_weights.keys()
+
+            curr_weights = curr_weights + past_weights
+
+        # sequence Float['b n d'] to queries
+
+        queries = self.to_queries(seq)
+
+        # maybe multihead
+
+        queries = self.split_heads(queries)
+
+        batch = queries.shape[0]
+
+        # fetch values from memory model
+
+        curr_weights = curr_weights.apply(lambda t: rearrange(t, 'b n ... -> (b n) ...'))
+        queries = rearrange(queries, 'b (n c) d -> (b n) c d', c = chunk_size)
+
+        # forward functional call
+
+        values = functional_call(self.memory_model, dict(curr_weights), queries)
+
+        # reconstitute batch dimension
+
+        values = rearrange(values, '(b n) c d -> b (n c) d', b = batch)
+
+        # maybe merge heads and combine
+
+        values = self.merge_heads(values)
+
+        values = self.combine_heads(values)
+
+        # post norm, somehow could not stabilize this without it, not in paper
+
+        values = self.post_rmsnorm(values)
+
+        # restore original sequence length - the first (chunk_size - 1) positions have no memory to retrieve from yet
+
+        values = pad_at_dim(values, (chunk_size - 1, 0), dim = 1, value = 0.) # todo - use a learned null memory embedding instead of 0s for retrieving from empty neural memory
+        values = values[:, :seq_len] # slice to seq_len rather than `[:, :-padding]`, which breaks when padding == 0
+
+        return values
+
+    def forward(
+        self,
+        seq,
+        past_state: tuple[dict[str, Tensor], dict[str, Tensor]] | None = None,
+        return_next_memories = False
+    ):
+        batch, seq_len = seq.shape[:2]
+
+        if seq_len < self.chunk_size:
+            return torch.zeros_like(seq)
+
+        if exists(past_state):
+            past_state = tuple(TensorDict(d) for d in past_state)
+
+        if not exists(past_state):
+            past_state = self.init_weights_and_momentum()
+
+        updates, next_memories, aux_kv_mse_loss = self.store_memories(seq, past_state)
+
+        past_weights, _ = past_state
+
+        retrieved = self.retrieve_memories(seq, past_weights + updates)
+
+        if not return_next_memories:
+            return retrieved
+
+        return retrieved, next_memories, aux_kv_mse_loss
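When training with the module, `return_next_memories = True` surfaces the updated memory state along with the auxiliary store loss that the in-code comment says exists "for backwarding through qkv projection". A minimal sketch of folding it into a task objective, assuming that intent (the `0.1` weight and the stand-in task loss are arbitrary, not from the package):

```python
import torch
from titans_pytorch import NeuralMemory

mem = NeuralMemory(dim = 384, chunk_size = 64)

seq = torch.randn(2, 1024, 384)

# retrieved values, next (weights, momentum) state, and the kv reconstruction loss
retrieved, next_memories, aux_kv_loss = mem(seq, return_next_memories = True)

task_loss = retrieved.sum()  # stand-in for a real objective

# weighting in the auxiliary loss lets gradients reach the key / value projections
(task_loss + 0.1 * aux_kv_loss).backward()
```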
titans_pytorch-0.0.14.dist-info/METADATA
ADDED
@@ -0,0 +1,111 @@
+Metadata-Version: 2.4
+Name: titans-pytorch
+Version: 0.0.14
+Summary: Titans
+Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
+Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
+Author-email: Phil Wang <lucidrains@gmail.com>
+License: MIT License
+
+Copyright (c) 2025 Phil Wang
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+License-File: LICENSE
+Keywords: artificial intelligence,deep learning,linear attention,neural memory module,test time training
+Classifier: Development Status :: 4 - Beta
+Classifier: Intended Audience :: Developers
+Classifier: License :: OSI Approved :: MIT License
+Classifier: Programming Language :: Python :: 3.9
+Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
+Requires-Python: >=3.9
+Requires-Dist: accelerated-scan>=0.2.0
+Requires-Dist: einops>=0.8.0
+Requires-Dist: einx>=0.3.0
+Requires-Dist: ninja
+Requires-Dist: tensordict
+Requires-Dist: torch>=2.2
+Provides-Extra: examples
+Requires-Dist: local-attention>=1.10.1; extra == 'examples'
+Requires-Dist: taylor-series-linear-attention; extra == 'examples'
+Requires-Dist: tqdm; extra == 'examples'
+Requires-Dist: wandb; extra == 'examples'
+Provides-Extra: test
+Requires-Dist: pytest; extra == 'test'
+Description-Content-Type: text/markdown
+
+<img src="./fig2.png" width="400px"></img>
+
+<img src="./fig1.png" width="400px"></img>
+
+## Titans - Pytorch (wip)
+
+Unofficial implementation of [Titans](https://arxiv.org/abs/2501.00663) in Pytorch. Will also contain some explorations into architectures beyond their simple 1-4 layer MLP for the neural memory module, if it works well to any degree.
+
+## Install
+
+```bash
+$ pip install titans-pytorch
+```
+
+## Usage
+
+```python
+import torch
+from titans_pytorch import NeuralMemory
+
+mem = NeuralMemory(
+    dim = 384,
+    chunk_size = 64,
+    pre_rmsnorm = True
+).cuda()
+
+seq = torch.randn(2, 1024, 384).cuda()
+retrieved = mem(seq)
+
+assert seq.shape == retrieved.shape
+```
+
+## Experiments
+
+```bash
+$ pip install .[examples]
+```
+
+For the SOTA linear attention, you will also need to run
+
+```bash
+$ pip install -r requirements.txt
+```
+
+Then modify `train.py` and run it to query nature
+
+```bash
+$ python train.py
+```
+
+## Citations
+
+```bibtex
+@inproceedings{Behrouz2024TitansLT,
+    title   = {Titans: Learning to Memorize at Test Time},
+    author  = {Ali Behrouz and Peilin Zhong and Vahab S. Mirrokni},
+    year    = {2024},
+    url     = {https://api.semanticscholar.org/CorpusID:275212078}
+}
+```
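A note on `chunk_size`, going by the code above rather than anything the README states: sequences shorter than `chunk_size` return zeros (no complete chunk has been stored yet), and sequence length need not be a multiple of `chunk_size`. A small shape-behavior sketch (dimensions are arbitrary):

```python
import torch
from titans_pytorch import NeuralMemory

mem = NeuralMemory(dim = 16, chunk_size = 4)

short = torch.randn(1, 3, 16)        # shorter than chunk_size
assert torch.all(mem(short) == 0.)   # no complete chunk yet -> zeros

seq = torch.randn(1, 10, 16)         # not a multiple of chunk_size
assert mem(seq).shape == seq.shape   # output matches input shape
```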
titans_pytorch-0.0.14.dist-info/RECORD
ADDED
@@ -0,0 +1,7 @@
+titans_pytorch/__init__.py,sha256=QKuJPCOJCdgtaPeKoHEkYkiQe65_LV9_8-cIMbBPU30,55
+titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+titans_pytorch/titans.py,sha256=O4WO2I6GhxyZobqbgfFx01saFEKUxhA0BBwt70m2yeQ,12306
+titans_pytorch-0.0.14.dist-info/METADATA,sha256=HKDSJ3sWc54sN1_fOEYU7i5TjiQff49vcZG9G8EU6z4,3598
+titans_pytorch-0.0.14.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.0.14.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.0.14.dist-info/RECORD,,
titans_pytorch-0.0.14.dist-info/licenses/LICENSE
ADDED
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2025 Phil Wang
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.