titans-pytorch 0.0.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- titans_pytorch/__init__.py +3 -0
- titans_pytorch/associative_scan.py +90 -0
- titans_pytorch/titans.py +408 -0
- titans_pytorch-0.0.14.dist-info/METADATA +111 -0
- titans_pytorch-0.0.14.dist-info/RECORD +7 -0
- titans_pytorch-0.0.14.dist-info/WHEEL +4 -0
- titans_pytorch-0.0.14.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,90 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
from typing import Callable
|
3
|
+
|
4
|
+
import torch
|
5
|
+
from torch import Tensor
|
6
|
+
import torch.nn.functional as F
|
7
|
+
|
8
|
+
# taken from S5-pytorch repository
|
9
|
+
# https://github.com/i404788/s5-pytorch/blob/74e2fdae00b915a62c914bf3615c0b8a4279eb84/s5/jax_compat.py#L51-L134
|
10
|
+
|
11
|
+
# helper functions
|
12
|
+
|
13
|
+
def pad_at_dim(t, pad, dim = -1, value = 0.):
    """Pad tensor `t` along the single dimension `dim`.

    `pad` is a `(before, after)` pair. `F.pad` consumes its padding tuple
    starting from the LAST dimension, so every dimension to the right of
    `dim` is given a `(0, 0)` pair first, which makes `pad` land on `dim`.
    `value` is the constant used to fill the new positions.
    """
    if dim < 0:
        num_trailing_dims = -dim - 1
    else:
        num_trailing_dims = t.ndim - dim - 1

    untouched = (0, 0) * num_trailing_dims
    full_pad = (*untouched, *pad)
    return F.pad(t, full_pad, value = value)
|
17
|
+
|
18
|
+
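# combine operator for the gated linear recurrence h_t = a_t * h_{t-1} + kv_t
# applying step (a_i, kv_i) and then step (a_j, kv_j) equals the single step (a_j * a_i, kv_j + a_j * kv_i)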
|
19
|
+
|
20
|
+
@torch.jit.script
def binary_operator(
    a: tuple[Tensor, Tensor],
    b: tuple[Tensor, Tensor]
):
    """Compose two (gate, state) recurrence steps, `a` applied first, then `b`."""
    gate_left, state_left = a
    gate_right, state_right = b

    # h -> g_l * h + s_l followed by h -> g_r * h + s_r is h -> (g_r * g_l) * h + (s_r + g_r * s_l)
    combined_gate = gate_right * gate_left
    combined_state = torch.addcmul(state_right, gate_right, state_left)
    return combined_gate, combined_state
|
28
|
+
|
29
|
+
# Pytorch impl. of jax.lax.associative_scan
|
30
|
+
# made specifically for axis of 1 (sequence of tokens for autoregressive modeling)
|
31
|
+
|
32
|
+
def associative_scan(
    operator: Callable,
    elems: tuple[Tensor, Tensor]
):
    """Inclusive parallel prefix scan of `elems` along dimension 1.

    PyTorch port of `jax.lax.associative_scan`, specialised to axis 1 (the
    sequence axis for autoregressive modeling) instead of JAX's axis 0.

    Args:
        operator: associative binary function. It takes two sequences of tensors
            with the same structure as `elems` and returns their combination,
            e.g. `binary_operator` for a gated linear recurrence.
        elems: tensors that all have the same size along dimension 1.

    Returns:
        The scanned tensors, one per entry of `elems`, each keeping its input's
        size along dimension 1. This is a list, or `elems` itself when that size
        is below 2.

    Raises:
        ValueError: if the inputs differ in size along dimension 1.
    """
    num_elems = int(elems[0].shape[1])

    # the message previously said "first dimension" (inherited from JAX's axis 0), but dimension 1 is what is checked
    if not all(int(elem.shape[1]) == num_elems for elem in elems[1:]):
        raise ValueError('Array inputs to associative_scan must have the same '
                         'size along dimension 1. (saw: {})'
                         .format([elem.shape for elem in elems]))

    def _scan(elems):
        """Recursively scan `elems` using the odd / even decomposition of jax.lax.associative_scan."""
        num_elems = elems[0].shape[1]

        if num_elems < 2:
            return elems

        # combine adjacent pairs of elements: (0, 1), (2, 3), ...

        reduced_elems = operator(
            [elem[:, :-1:2] for elem in elems],
            [elem[:, 1::2] for elem in elems])

        # recursively scan the half-length sequence - these are the results at odd positions 1, 3, 5, ...

        odd_elems = _scan(reduced_elems)

        # results at even positions 2, 4, ... combine the preceding odd result with the original element

        if num_elems % 2 == 0:
            even_elems = operator(
                [e[:, :-1] for e in odd_elems],
                [e[:, 2::2] for e in elems])
        else:
            even_elems = operator(
                odd_elems,
                [e[:, 2::2] for e in elems])

        # the first element of a scan is the same as the first element of the original `elems`

        even_elems = [
            torch.cat([elem[:, :1], result], dim=1)
            for (elem, result) in zip(elems, even_elems)]

        return list(map(_interleave, even_elems, odd_elems))

    return _scan(elems)
|
79
|
+
|
80
|
+
def _interleave(a, b):
|
81
|
+
a_axis_len, b_axis_len = a.shape[1], b.shape[1]
|
82
|
+
output_axis_len = a_axis_len + b_axis_len
|
83
|
+
|
84
|
+
if (a_axis_len == (b_axis_len + 1)):
|
85
|
+
b = pad_at_dim(b, (0, 1), dim = 1)
|
86
|
+
|
87
|
+
stacked = torch.stack([a, b], dim=2)
|
88
|
+
interleaved = torch.flatten(stacked, start_dim=1, end_dim=2)
|
89
|
+
|
90
|
+
return interleaved[:, :output_axis_len]
|
titans_pytorch/titans.py
ADDED
@@ -0,0 +1,408 @@
|
|
1
|
+
from __future__ import annotations
import math
from functools import partial
from typing import Callable

import torch
from torch import nn, Tensor
import torch.nn.functional as F
from torch.nn import Linear, Module
from torch.func import functional_call, vmap, grad_and_value

from tensordict import TensorDict

from titans_pytorch.associative_scan import (
    associative_scan,
    binary_operator,
    pad_at_dim
)

import einx
from einops import rearrange, pack, unpack
from einops.layers.torch import Rearrange, Reduce
|
22
|
+
|
23
|
+
"""
|
24
|
+
ein notation:
|
25
|
+
b - batch
|
26
|
+
n - sequence
|
27
|
+
d - feature dimension
|
28
|
+
c - intra-chunk
h - heads
|
29
|
+
"""
|
30
|
+
|
31
|
+
# constants
|
32
|
+
|
33
|
+
# bias-free linear layer factory: LinearNoBias(dim_in, dim_out) builds nn.Linear(dim_in, dim_out, bias = False)
LinearNoBias = partial(nn.Linear, bias = False)
|
34
|
+
|
35
|
+
# functions
|
36
|
+
|
37
|
+
def exists(v):
    """Return True for any value other than None (falsy values such as 0 still count)."""
    return not (v is None)
|
39
|
+
|
40
|
+
def default(v, d):
    """Return `v`, or the fallback `d` when `v` is None."""
    if v is None:
        return d
    return v
|
42
|
+
|
43
|
+
def round_down_multiple(seq, mult):
    """Round `seq` down to a multiple of `mult` using floor division."""
    num_whole = seq // mult
    return num_whole * mult
|
45
|
+
|
46
|
+
def round_up_multiple(seq, mult):
    """Round `seq` up to a multiple of `mult` (true division followed by ceil)."""
    num_blocks = math.ceil(seq / mult)
    return num_blocks * mult
|
48
|
+
|
49
|
+
def pack_one_with_inverse(t, pattern):
    """Pack a single tensor with einops `pack` and return it with an inverse closure.

    Returns `(packed, inverse)`. `inverse(out, inv_pattern = None)` unpacks `out`
    using the shapes recorded at pack time and `inv_pattern` (falling back to
    `pattern`), and returns the single resulting tensor.
    """
    packed, packed_shapes = pack([t], pattern)

    def inverse(out, inv_pattern = None):
        unpack_pattern = pattern if inv_pattern is None else inv_pattern
        unpacked = unpack(out, packed_shapes, unpack_pattern)
        return unpacked[0]

    return packed, inverse
|
57
|
+
|
58
|
+
# classes
|
59
|
+
|
60
|
+
class MLP(Module):
    """Bias-free MLP used as the default neural memory model.

    It holds `depth` square (dim x dim) weight matrices. SiLU is applied
    between consecutive matmuls, but not before the first or after the last.
    The weights live in a ParameterList; NeuralMemory substitutes per-chunk
    weights for them through `torch.func.functional_call`.
    """

    def __init__(
        self,
        dim,
        depth
    ):
        super().__init__()
        self.weights = nn.ParameterList([nn.Parameter(torch.randn(dim, dim)) for _ in range(depth)])

        # unit-variance randn entries scale activations by ~sqrt(dim) per matmul, which compounds over `depth`
        # layers. Xavier-uniform keeps the chained matmuls well scaled.
        for weight in self.weights:
            nn.init.xavier_uniform_(weight)

    def forward(
        self,
        x
    ):
        """Apply the MLP to `x`; its last dimension must equal `dim`."""
        for ind, weight in enumerate(self.weights):
            is_first = ind == 0

            if not is_first:
                x = F.silu(x)

            x = x @ weight

        return x
|
82
|
+
|
83
|
+
# main neural memory
|
84
|
+
|
85
|
+
def default_loss_fn(pred, target):
    """Squared error, averaged over the last (feature) dimension and summed over everything else."""
    squared_error = (pred - target) ** 2
    per_position = squared_error.mean(dim = -1)
    return per_position.sum()
|
87
|
+
|
88
|
+
class NeuralMemory(Module):
    """Neural memory whose contents are the weights of a small model (`memory_model`).

    Storing: for every chunk of the input, the gradient of a key -> value loss
    w.r.t. the memory weights is scaled by a learned step size. It is then
    accumulated with learned momentum and learned decay via associative scans.

    Retrieving: queries are run through the memory model using the updated
    weights of the corresponding chunk.
    """

    def __init__(
        self,
        dim,
        chunk_size = 1,
        dim_head = None,
        heads = 1,
        model: Module | None = None,
        store_memory_loss_fn: Callable = default_loss_fn,
        pre_rmsnorm = True,
        post_rmsnorm = True,
        use_accelerated_scan = False,
        default_mlp_kwargs: dict = dict(
            depth = 4
        )
    ):
        """
        Args:
            dim: input / output feature dimension.
            chunk_size: tokens per chunk. Step size, momentum and decay are shared
                within a chunk, and one gradient is computed per chunk.
            dim_head: feature dimension per memory head (defaults to `dim`).
            heads: number of memory heads.
            model: custom memory model, which must not have buffers. Defaults to
                `MLP(dim_head, **default_mlp_kwargs)`.
            store_memory_loss_fn: `(pred, target) -> scalar` loss minimised when storing.
            pre_rmsnorm: RMSNorm the input before storing and before retrieving.
            post_rmsnorm: RMSNorm the retrieved values.
            use_accelerated_scan: use the `accelerated_scan` package instead of the
                pure PyTorch associative scan.
            default_mlp_kwargs: keyword arguments for the default `MLP` (only read).
        """
        super().__init__()

        # norms

        self.retrieve_norm = nn.RMSNorm(dim) if pre_rmsnorm else nn.Identity()
        self.store_norm = nn.RMSNorm(dim) if pre_rmsnorm else nn.Identity()

        self.post_rmsnorm = nn.RMSNorm(dim) if post_rmsnorm else nn.Identity()

        # maybe multi-headed

        dim_head = default(dim_head, dim)
        dim_inner = dim_head * heads

        self.split_heads = Rearrange('b n (h d) -> (b h) n d', h = heads)
        self.merge_heads = Rearrange('(b h) n d -> b n (h d)', h = heads)

        # a projection back to `dim` is needed whenever the merged head width differs from `dim`
        # (previously heads == 1 with dim_head != dim skipped it, returning the wrong width and breaking RMSNorm(dim))
        needs_combine = heads > 1 or dim_inner != dim
        self.combine_heads = LinearNoBias(dim_inner, dim) if needs_combine else nn.Identity()

        # memory mlp

        if not exists(model):
            model = MLP(dim_head, **default_mlp_kwargs)

        assert not exists(next(model.buffers(), None)), 'model cannot have buffers for now'

        # the memory is the weights of the model

        self.memory_model = model

        # the chunk size within the paper where adaptive step, momentum, weight decay are shared

        self.chunk_size = chunk_size

        # prepare function for per sample gradients from model above, using torch.func

        def forward_and_loss(params, inputs, target):
            pred = functional_call(self.memory_model, params, inputs)
            loss = self.store_memory_loss_fn(pred, target) # simple mse loss in paper - eq (12) - |M(k) - v|²
            return loss

        # vmap over the leading (chunk) dimension of inputs / targets while sharing params -> one grad and loss per chunk
        self.per_sample_grad_and_value_fn = vmap(grad_and_value(forward_and_loss), in_dims = (None, 0, 0))

        # queries for retrieving from the model

        self.to_queries = LinearNoBias(dim, dim_inner)

        # keys and values for storing to the model

        self.to_keys_values = LinearNoBias(dim, dim_inner * 2)
        self.store_memory_loss_fn = store_memory_loss_fn

        # learned adaptive learning rate and momentum - each chunk is mean-pooled, then projected to one scalar per head
        # todo - explore mlp layerwise learned lr / momentum

        self.to_momentum = nn.Sequential(
            Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size),
            LinearNoBias(dim, heads),
            Rearrange('b n h -> (b h) n 1')
        )

        self.to_adaptive_step = nn.Sequential(
            Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size),
            LinearNoBias(dim, heads),
            Rearrange('b n h -> (b h) n')
        )

        # weight decay factor

        self.to_decay_factor = nn.Sequential(
            Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size),
            LinearNoBias(dim, heads),
            Rearrange('b n h -> (b h) n 1')
        )

        # maybe use accelerated scan

        self.use_accelerated_scan = use_accelerated_scan

    def init_weights_and_momentum(self):
        """Return zero-filled (weights, momentum) TensorDicts, one entry per memory-model parameter."""
        params = TensorDict(dict(self.memory_model.named_parameters()))

        init_weights = params.clone().zero_()
        init_momentum = params.clone().zero_()

        return init_weights, init_momentum

    def store_memories(
        self,
        seq,
        past_state: tuple[dict[str, Tensor], dict[str, Tensor]]
    ):
        """Compute the memory-weight updates induced by `seq`.

        Only the longest prefix of `seq` made of whole chunks is used.

        Returns:
            updates: TensorDict of scanned weight updates, one per chunk; each entry has
                shape ((batch * heads), num_chunks, *param_shape).
            next_state: (current weights + last chunk's update, per-chunk momentum).
            aux loss: mean store loss divided by chunk size. It is kept for backpropagating
                through the key / value projections.
        """
        seq = self.store_norm(seq)

        # curtail sequence by multiple of the chunk size
        # only a complete chunk of the sequence provides the memory for the next chunk

        seq_len, chunk_size = seq.shape[-2], self.chunk_size
        round_down_seq_len = round_down_multiple(seq_len, chunk_size)

        seq = seq[:, :round_down_seq_len]

        # curr weights + past weights, in the case that the initial weights are learned

        curr_weights = TensorDict(dict(self.memory_model.named_parameters()))

        past_state = tuple(TensorDict(d) for d in past_state)
        past_weights, past_momentum = past_state

        # NOTE(review): `past_momentum` is not used below, so momentum does not carry over between calls - confirm intended

        curr_weights = curr_weights + past_weights

        # learned per-chunk step size, squashed into (e^-15, 1) ~ (3e-7, 1)

        adaptive_lr = (self.to_adaptive_step(seq).sigmoid() * -15).exp()

        adaptive_momentum = self.to_momentum(seq).sigmoid()
        decay_factor = self.to_decay_factor(seq).sigmoid()

        # keys and values

        keys, values = self.to_keys_values(seq).chunk(2, dim = -1)

        # maybe multi head

        keys, values = map(self.split_heads, (keys, values))

        batch = keys.shape[0]

        # take care of chunking - fold chunks into the batch dimension

        keys, values = tuple(rearrange(t, 'b (n c) d -> (b n) c d', c = chunk_size) for t in (keys, values))

        # get grads and extra auxiliary loss (for backwarding through qkv projection in base neural memory module)

        grads, aux_store_loss = self.per_sample_grad_and_value_fn(dict(curr_weights), keys, values)

        grads = TensorDict(grads)

        # restore batch and sequence (chunk) dimension

        grads = grads.apply(lambda t: rearrange(t, '(b n) ... -> b n ...', b = batch))

        # multiply gradients with learned adaptive step size

        surprises = grads.apply(lambda t: einx.multiply('b n ..., b n -> b n ...', t, -adaptive_lr))

        # determine scan function

        def default_associative_scan(gates, inputs):
            _, outputs = associative_scan(binary_operator, (gates, inputs))
            return outputs

        if self.use_accelerated_scan:
            from accelerated_scan.triton import scan as triton_scan
            from accelerated_scan.warp import scan as warp_scan

            scan = triton_scan if seq.is_cuda else warp_scan

            def accelerate_scan_fn(gates, inputs):
                gates = gates.expand_as(inputs)
                gates, inputs = tuple(rearrange(t, 'b n d -> b d n') for t in (gates, inputs))

                # right-pad the time axis to a power of two (at least 32), presumably required by the kernels, then trim

                seq_len = gates.shape[-1]
                next_power_two_seq_len = 2 ** max(5, int(math.ceil(math.log2(seq_len))))

                gates = F.pad(gates, (0, next_power_two_seq_len - seq_len))
                inputs = F.pad(inputs, (0, next_power_two_seq_len - seq_len))

                outputs = scan(gates, inputs)

                outputs = outputs[..., :seq_len]
                outputs = rearrange(outputs, 'b d n -> b n d')
                return outputs

            scan_fn = accelerate_scan_fn
        else:
            scan_fn = default_associative_scan

        # momentum + weight decay - momentum is the new contribution, as most linear RNNs have learned forgetting gates

        next_momentum = TensorDict()
        updates = TensorDict()

        for param_name, surprise in surprises.items():

            surprise, inverse_pack = pack_one_with_inverse(surprise, 'b n *')

            # derive momentum with associative scan - eq (10)

            momentum = scan_fn(adaptive_momentum, surprise) # momentum is S / surprise in the paper

            # use associative scan again for learned forgetting (weight decay) - eq (13)

            update = scan_fn(1. - decay_factor, momentum)

            updates[param_name] = inverse_pack(update)
            next_momentum[param_name] = inverse_pack(momentum)

        # compute the next weight per batch

        last_update = updates.apply(lambda t: t[:, -1])

        # NOTE(review): `curr_weights` already includes the memory model's own parameters, and both this method and
        # `retrieve_memories` add those parameters to the incoming past weights again. Feeding `next_state` back in
        # therefore appears to count them twice - confirm the intended convention before chaining calls.

        next_state = (curr_weights + last_update, next_momentum)

        # with the default loss (summed over the chunk) dividing by chunk size gives a per-token average

        return updates, next_state, aux_store_loss.mean() / chunk_size

    def retrieve_memories(
        self,
        seq,
        past_weights: dict[str, Tensor] | None = None,
    ):
        """Read from memory with queries projected from `seq`.

        The first `chunk_size - 1` output positions are zeros. From position
        `chunk_size - 1` on, each position is read with the weights updated through
        the last complete chunk ending at or before it. The output has the same
        sequence length as `seq`.

        Args:
            seq: input sequence, at least `chunk_size` long.
            past_weights: per-chunk weight offsets added to the memory model's
                parameters, as produced in `forward` from `store_memories` updates.
        """
        chunk_size = self.chunk_size
        seq_len = seq.shape[1]

        seq = self.retrieve_norm(seq)

        assert seq_len >= chunk_size

        # shift so each query chunk lines up with the weights of the chunk it may read from

        seq = seq[:, (chunk_size - 1):]
        curtailed_seq_len = seq.shape[-2]

        next_seq_len = round_up_multiple(curtailed_seq_len, chunk_size)

        padding = next_seq_len - curtailed_seq_len

        seq = pad_at_dim(seq, (0, padding), dim = 1)

        # the parameters of the memory model stores the memories of the key / values
        # when the MLP has only 1 weight matrix, it is equivalent to `kv` fast weight memories from linear attention literature (recall fetching of memories is q @ (kv)) / schmidhuber's paper

        curr_weights = TensorDict(dict(self.memory_model.named_parameters()))

        if exists(past_weights):
            past_weights = TensorDict(past_weights)
            assert past_weights.keys() == curr_weights.keys()

            curr_weights = curr_weights + past_weights

        # sequence Float['b n d'] to queries

        queries = self.to_queries(seq)

        # maybe multihead

        queries = self.split_heads(queries)

        batch = queries.shape[0]

        # fetch values from memory model - fold chunks into the batch so each chunk uses its own weights

        curr_weights = curr_weights.apply(lambda t: rearrange(t, 'b n ... -> (b n) ...'))
        queries = rearrange(queries, 'b (n c) d -> (b n) c d', c = chunk_size)

        # forward functional call

        values = functional_call(self.memory_model, dict(curr_weights), queries)

        # reconstitute batch dimension

        values = rearrange(values, '(b n) c d -> b (n c) d', b = batch)

        # maybe merge heads and combine

        values = self.merge_heads(values)

        values = self.combine_heads(values)

        # post norm, somehow could not stabilize this without it, not in paper

        values = self.post_rmsnorm(values)

        # restore - prepend zeros for the positions before the first chunk completes

        values = pad_at_dim(values, (chunk_size - 1, 0), dim = 1, value = 0.) # todo, used a learned null memory embedding instead of 0s for retrieving from empty neural memory

        # drop the right padding. `values[:, :-padding]` returned an EMPTY tensor whenever padding == 0
        # (e.g. the default chunk_size = 1), so slice to the original length instead

        values = values[:, :seq_len]

        return values

    def forward(
        self,
        seq,
        past_state: tuple[dict[str, Tensor], dict[str, Tensor]] | None = None,
        return_next_memories = False
    ):
        """Store `seq` into memory, then retrieve from it.

        Returns the retrieved values (same shape as `seq`). When
        `return_next_memories` is True, it instead returns
        `(retrieved, next_memories, aux_kv_mse_loss)`.

        A sequence shorter than `chunk_size` stores nothing and retrieves zeros.
        With `return_next_memories`, it also returns the past state (freshly
        initialised if none was given) unchanged and a zero auxiliary loss, so the
        return structure matches the normal path.
        """
        seq_len = seq.shape[1]

        too_short = seq_len < self.chunk_size

        if too_short and not return_next_memories:
            return torch.zeros_like(seq)

        if exists(past_state):
            past_state = tuple(TensorDict(d) for d in past_state)
        else:
            past_state = self.init_weights_and_momentum()

        if too_short:
            return torch.zeros_like(seq), past_state, seq.new_zeros(())

        updates, next_memories, aux_kv_mse_loss = self.store_memories(seq, past_state)

        past_weights, _ = past_state

        retrieved = self.retrieve_memories(seq, past_weights + updates)

        if not return_next_memories:
            return retrieved

        return retrieved, next_memories, aux_kv_mse_loss
|
@@ -0,0 +1,111 @@
|
|
1
|
+
Metadata-Version: 2.4
|
2
|
+
Name: titans-pytorch
|
3
|
+
Version: 0.0.14
|
4
|
+
Summary: Titans
|
5
|
+
Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
|
6
|
+
Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
|
7
|
+
Author-email: Phil Wang <lucidrains@gmail.com>
|
8
|
+
License: MIT License
|
9
|
+
|
10
|
+
Copyright (c) 2025 Phil Wang
|
11
|
+
|
12
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
13
|
+
of this software and associated documentation files (the "Software"), to deal
|
14
|
+
in the Software without restriction, including without limitation the rights
|
15
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
16
|
+
copies of the Software, and to permit persons to whom the Software is
|
17
|
+
furnished to do so, subject to the following conditions:
|
18
|
+
|
19
|
+
The above copyright notice and this permission notice shall be included in all
|
20
|
+
copies or substantial portions of the Software.
|
21
|
+
|
22
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
23
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
24
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
25
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
26
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
27
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
28
|
+
SOFTWARE.
|
29
|
+
License-File: LICENSE
|
30
|
+
Keywords: artificial intelligence,deep learning,linear attention,neural memory module,test time training
|
31
|
+
Classifier: Development Status :: 4 - Beta
|
32
|
+
Classifier: Intended Audience :: Developers
|
33
|
+
Classifier: License :: OSI Approved :: MIT License
|
34
|
+
Classifier: Programming Language :: Python :: 3.9
|
35
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
36
|
+
Requires-Python: >=3.9
|
37
|
+
Requires-Dist: accelerated-scan>=0.2.0
|
38
|
+
Requires-Dist: einops>=0.8.0
|
39
|
+
Requires-Dist: einx>=0.3.0
|
40
|
+
Requires-Dist: ninja
|
41
|
+
Requires-Dist: tensordict
|
42
|
+
Requires-Dist: torch>=2.2
|
43
|
+
Provides-Extra: examples
|
44
|
+
Requires-Dist: local-attention>=1.10.1; extra == 'examples'
|
45
|
+
Requires-Dist: taylor-series-linear-attention; extra == 'examples'
|
46
|
+
Requires-Dist: tqdm; extra == 'examples'
|
47
|
+
Requires-Dist: wandb; extra == 'examples'
|
48
|
+
Provides-Extra: test
|
49
|
+
Requires-Dist: pytest; extra == 'test'
|
50
|
+
Description-Content-Type: text/markdown
|
51
|
+
|
52
|
+
<img src="./fig2.png" width="400px"></img>
|
53
|
+
|
54
|
+
<img src="./fig1.png" width="400px"></img>
|
55
|
+
|
56
|
+
## Titans - Pytorch (wip)
|
57
|
+
|
58
|
+
Unofficial implementation of [Titans](https://arxiv.org/abs/2501.00663) in PyTorch. It will also contain some explorations into architectures beyond the paper's simple 1-4 layer MLP for the neural memory module, if it works well to any degree.
|
59
|
+
|
60
|
+
## Install
|
61
|
+
|
62
|
+
```bash
|
63
|
+
$ pip install titans-pytorch
|
64
|
+
```
|
65
|
+
|
66
|
+
## Usage
|
67
|
+
|
68
|
+
```python
|
69
|
+
import torch
|
70
|
+
from titans_pytorch import NeuralMemory
|
71
|
+
|
72
|
+
mem = NeuralMemory(
|
73
|
+
dim = 384,
|
74
|
+
chunk_size = 64,
|
75
|
+
pre_rmsnorm = True
|
76
|
+
).cuda()
|
77
|
+
|
78
|
+
seq = torch.randn(2, 1024, 384).cuda()
|
79
|
+
retrieved = mem(seq)
|
80
|
+
|
81
|
+
assert seq.shape == retrieved.shape
|
82
|
+
```
|
83
|
+
|
84
|
+
## Experiments
|
85
|
+
|
86
|
+
```bash
|
87
|
+
$ pip install .[examples]
|
88
|
+
```
|
89
|
+
|
90
|
+
For the SOTA linear attention, you will also need to run
|
91
|
+
|
92
|
+
```bash
|
93
|
+
$ pip install -r requirements.txt
|
94
|
+
```
|
95
|
+
|
96
|
+
Then modify `train.py` and run it to query nature
|
97
|
+
|
98
|
+
```bash
|
99
|
+
$ python train.py
|
100
|
+
```
|
101
|
+
|
102
|
+
## Citations
|
103
|
+
|
104
|
+
```bibtex
|
105
|
+
@inproceedings{Behrouz2024TitansLT,
|
106
|
+
title = {Titans: Learning to Memorize at Test Time},
|
107
|
+
author = {Ali Behrouz and Peilin Zhong and Vahab S. Mirrokni},
|
108
|
+
year = {2024},
|
109
|
+
url = {https://api.semanticscholar.org/CorpusID:275212078}
|
110
|
+
}
|
111
|
+
```
|
@@ -0,0 +1,7 @@
|
|
1
|
+
titans_pytorch/__init__.py,sha256=QKuJPCOJCdgtaPeKoHEkYkiQe65_LV9_8-cIMbBPU30,55
|
2
|
+
titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
|
3
|
+
titans_pytorch/titans.py,sha256=O4WO2I6GhxyZobqbgfFx01saFEKUxhA0BBwt70m2yeQ,12306
|
4
|
+
titans_pytorch-0.0.14.dist-info/METADATA,sha256=HKDSJ3sWc54sN1_fOEYU7i5TjiQff49vcZG9G8EU6z4,3598
|
5
|
+
titans_pytorch-0.0.14.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
6
|
+
titans_pytorch-0.0.14.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
7
|
+
titans_pytorch-0.0.14.dist-info/RECORD,,
|
@@ -0,0 +1,21 @@
|
|
1
|
+
MIT License
|
2
|
+
|
3
|
+
Copyright (c) 2025 Phil Wang
|
4
|
+
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
7
|
+
in the Software without restriction, including without limitation the rights
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
10
|
+
furnished to do so, subject to the following conditions:
|
11
|
+
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
13
|
+
copies or substantial portions of the Software.
|
14
|
+
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21
|
+
SOFTWARE.
|