titans-pytorch 0.4.6__py3-none-any.whl → 0.4.8__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- titans_pytorch/neural_memory.py +35 -1
- {titans_pytorch-0.4.6.dist-info → titans_pytorch-0.4.8.dist-info}/METADATA +30 -2
- titans_pytorch-0.4.8.dist-info/RECORD +8 -0
- titans_pytorch/associative_scan.py +0 -178
- titans_pytorch-0.4.6.dist-info/RECORD +0 -9
- {titans_pytorch-0.4.6.dist-info → titans_pytorch-0.4.8.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.4.6.dist-info → titans_pytorch-0.4.8.dist-info}/licenses/LICENSE +0 -0
titans_pytorch/neural_memory.py
CHANGED
@@ -15,7 +15,7 @@ from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten
 
 from tensordict import TensorDict
 
-from titans_pytorch.associative_scan import AssocScan
+from assoc_scan import AssocScan
 
 from titans_pytorch.memory_models import(
     MemoryMLP,
@@ -152,6 +152,30 @@ def softclamp_grad_norm(t, max_value):
     t = t * (clamped_norm / norm)
     return inverse(t)
 
+# spectral norming the surprise update w/ newton schulz matrix iter
+# Keller Jordan et al. from OSS w/ nanogpt, now being used for two works, Atlas and 'TTT done right'
+
+def newtonschulz5(
+    t,
+    steps = 5,
+    eps = 1e-7,
+    coefs = (3.4445, -4.7750, 2.0315)
+):
+    if t.ndim <= 3:
+        return t
+
+    t, inv_pack = pack_one_with_inverse(t, '* i j')
+    t = t / t.norm(dim = (-1, -2), keepdim = True).clamp(min = eps)
+
+    a, b, c = coefs
+
+    for _ in range(steps):
+        A = t @ t.transpose(-1, -2)
+        B = b * A + c * A @ A
+        t = a * t + B @ t
+
+    return inv_pack(t)
+
 # multi head rmsnorm
 
 class MultiheadRMSNorm(Module):
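For orientation, the snippet below is a standalone sketch (not from the package) of what the added `newtonschulz5` does: it approximately orthogonalizes each matrix in a batch with the quintic Newton-Schulz iteration. The reshape-based packing and the relaxed `ndim` guard are illustrative substitutes for the package's `pack_one_with_inverse` helper and its `ndim <= 3` early return.

```python
# standalone sketch of the Newton-Schulz orthogonalization added above (illustrative only)
import torch

def newtonschulz5_sketch(t, steps = 5, eps = 1e-7, coefs = (3.4445, -4.7750, 2.0315)):
    if t.ndim < 2:
        return t

    shape = t.shape
    t = t.reshape(-1, *shape[-2:])                                   # fold leading dims into one batch dim
    t = t / t.norm(dim = (-1, -2), keepdim = True).clamp(min = eps)  # start inside the region of convergence

    a, b, c = coefs
    for _ in range(steps):
        A = t @ t.transpose(-1, -2)
        B = b * A + c * A @ A
        t = a * t + B @ t                                            # quintic Newton-Schulz step

    return t.reshape(shape)

x = torch.randn(2, 8, 16)
y = newtonschulz5_sketch(x)
print(torch.linalg.svdvals(y))   # singular values are pushed toward 1 rather than spread out
```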
@@ -254,6 +278,7 @@ class NeuralMemory(Module):
         init_momentum_bias = None,
         init_decay_bias = None,
         accept_weight_residual = False,
+        spectral_norm_surprises = False,
         gated_transition = False,
         mem_model_norm_add_residual = True, # by default, layernorm output and add residual as proposed in TTT paper, but could be removed
         default_model_kwargs: dict = dict(
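A hedged usage sketch of the new constructor flag; aside from `spectral_norm_surprises`, the argument values follow the project README example and are assumptions, not something this diff prescribes.

```python
# illustrative only: enabling the 0.4.8 option when building the memory module
import torch
from titans_pytorch import NeuralMemory

mem = NeuralMemory(
    dim = 384,
    chunk_size = 64,
    spectral_norm_surprises = True   # new flag: Newton-Schulz norm the surprise updates
)

seq = torch.randn(2, 1024, 384)
retrieved, mem_state = mem(seq)
assert retrieved.shape == seq.shape
```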
@@ -465,6 +490,10 @@ class NeuralMemory(Module):
 
         self.max_grad_norm = max_grad_norm
 
+        # spectral norming the surprises before update, a la Muon from Jordan et al.
+
+        self.spectral_norm_surprises = spectral_norm_surprises
+
         # weight decay factor
 
         self.to_decay_factor = Sequential(
@@ -748,6 +777,11 @@ class NeuralMemory(Module):
         else:
             update = einsum(combine_momentums, momentums, 'o b n, o b n ... -> b n ...')
 
+        # maybe spectral norm surprises
+
+        if self.spectral_norm_surprises:
+            update = newtonschulz5(update)
+
         # use associative scan again for learned forgetting (weight decay) - eq (13)
 
         update = self.assoc_scan(1. - decay_factor, update, prev = last_update, remove_prev = False)
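To make the surrounding logic concrete, here is a plain-loop reference for the learned-forgetting step that the associative scan computes; the names and shapes are illustrative, not the package's internals.

```python
# reference semantics (plain loop) for the weight-decay scan of eq. (13);
# the package computes the same recurrence with an associative scan for parallelism.
import torch

def decayed_update_reference(gates, updates, prev_state):
    # gates = 1 - decay_factor, shape (batch, chunks); updates: (batch, chunks, *weight_shape)
    # the optional Newton-Schulz step would be applied to `updates` before this loop
    state = prev_state
    states = []
    for i in range(updates.shape[1]):
        g = gates[:, i].view(-1, *([1] * (prev_state.ndim - 1)))   # broadcast gate over weight dims
        state = g * state + updates[:, i]                          # W_t = (1 - a_t) * W_{t-1} + u_t
        states.append(state)
    return torch.stack(states, dim = 1)

b, n, d = 2, 4, 8
prev = torch.zeros(b, d, d)
updates = torch.randn(b, n, d, d)
decay = torch.rand(b, n)
all_states = decayed_update_reference(1. - decay, updates, prev)   # (b, n, d, d)
```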
{titans_pytorch-0.4.6.dist-info → titans_pytorch-0.4.8.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.4.6
+Version: 0.4.8
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -34,7 +34,7 @@ Classifier: License :: OSI Approved :: MIT License
 Classifier: Programming Language :: Python :: 3.9
 Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
 Requires-Python: >=3.9
-Requires-Dist:
+Requires-Dist: assoc-scan
 Requires-Dist: axial-positional-embedding>=0.3.10
 Requires-Dist: einops>=0.8.0
 Requires-Dist: einx>=0.3.0
@@ -207,3 +207,31 @@ $ python train_mac.py
     url = {https://arxiv.org/abs/2501.12352},
 }
 ```
+
+```bibtex
+@misc{jordan2024muon,
+    author = {Keller Jordan and Yuchen Jin and Vlado Boza and Jiacheng You and
+              Franz Cesista and Laker Newhouse and Jeremy Bernstein},
+    title = {Muon: An optimizer for hidden layers in neural networks},
+    year = {2024},
+    url = {https://kellerjordan.github.io/posts/muon/}
+}
+```
+
+```bibtex
+@inproceedings{Zhang2025TestTimeTD,
+    title = {Test-Time Training Done Right},
+    author = {Tianyuan Zhang and Sai Bi and Yicong Hong and Kai Zhang and Fujun Luan and Songlin Yang and Kalyan Sunkavalli and William T. Freeman and Hao Tan},
+    year = {2025},
+    url = {https://api.semanticscholar.org/CorpusID:279071244}
+}
+```
+
+```bibtex
+@inproceedings{Behrouz2025ATLASLT,
+    title = {ATLAS: Learning to Optimally Memorize the Context at Test Time},
+    author = {Ali Behrouz and Ze-Minghui Li and Praneeth Kacham and Majid Daliri and Yuan Deng and Peilin Zhong and Meisam Razaviyayn and Vahab S. Mirrokni},
+    year = {2025},
+    url = {https://api.semanticscholar.org/CorpusID:278996373}
+}
+```
titans_pytorch-0.4.8.dist-info/RECORD
ADDED
@@ -0,0 +1,8 @@
+titans_pytorch/__init__.py,sha256=sVTOuRUkaIYabFExdLY6s1qXm1UwHHz_J19H8ZV-X74,338
+titans_pytorch/mac_transformer.py,sha256=tz72141G5t3AOnxSVsOLtLptGtl8T7zROUvaTw2_XCY,26960
+titans_pytorch/memory_models.py,sha256=wnH9i9kUSoVZhEWUlj8LpBSbB400L9kLt1zP8CO45QQ,5835
+titans_pytorch/neural_memory.py,sha256=ypWXN8koY8pXt7IvlcMR1QM7cYJnWK_iYLEHy2pjx88,34277
+titans_pytorch-0.4.8.dist-info/METADATA,sha256=BbhF0oiPGdcgxrBGJziZvbvXUmS5lVlGVxpUGPmP0O8,7873
+titans_pytorch-0.4.8.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.4.8.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.4.8.dist-info/RECORD,,
titans_pytorch/associative_scan.py
DELETED
@@ -1,178 +0,0 @@
-from __future__ import annotations
-import math
-from typing import Callable
-
-import torch
-from torch import Tensor
-from torch.nn import Module
-import torch.nn.functional as F
-
-from einops import rearrange, repeat, reduce, pack, unpack
-
-# taken from S5-pytorch repository
-# https://github.com/i404788/s5-pytorch/blob/74e2fdae00b915a62c914bf3615c0b8a4279eb84/s5/jax_compat.py#L51-L134
-
-# helper functions
-
-def exists(v):
-    return v is not None
-
-def default(*args):
-    for arg in args:
-        if exists(arg):
-            return arg
-    return None
-
-def pad_at_dim(t, pad, dim = -1, value = 0.):
-    dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
-    zeros = ((0, 0) * dims_from_right)
-    return F.pad(t, (*zeros, *pad), value = value)
-
-def pack_one_with_inverse(t, pattern):
-    packed, packed_shape = pack([t], pattern)
-
-    def inverse(out, inv_pattern = None):
-        inv_pattern = default(inv_pattern, pattern)
-        return unpack(out, packed_shape, inv_pattern)[0]
-
-    return packed, inverse
-
-# the operator that is needed
-
-@torch.jit.script
-def binary_operator(
-    a: tuple[Tensor, Tensor],
-    b: tuple[Tensor, Tensor]
-):
-    a_i, kv_i = a
-    a_j, kv_j = b
-    return a_j * a_i, torch.addcmul(kv_j, a_j, kv_i)
-
-# Pytorch impl. of jax.lax.associative_scan
-# made specifically for axis of 1 (sequence of tokens for autoregressive modeling)
-
-def associative_scan(
-    operator: Callable,
-    elems: tuple[Tensor, Tensor]
-):
-    num_elems = int(elems[0].shape[1])
-
-    if not all(int(elem.shape[1]) == num_elems for elem in elems[1:]):
-        raise ValueError('Array inputs to associative_scan must have the same '
-                         'first dimension. (saw: {})'
-                         .format([elem.shape for elem in elems]))
-
-    def _scan(elems):
-        """Perform scan on `elems`."""
-        num_elems = elems[0].shape[1]
-
-        if num_elems < 2:
-            return elems
-
-        # Combine adjacent pairs of elements.
-
-        reduced_elems = operator(
-            [elem[:, :-1:2] for elem in elems],
-            [elem[:, 1::2] for elem in elems])
-
-        # Recursively compute scan for partially reduced tensors.
-
-        odd_elems = _scan(reduced_elems)
-
-        if num_elems % 2 == 0:
-            even_elems = operator(
-                [e[:, :-1] for e in odd_elems],
-                [e[:, 2::2] for e in elems])
-        else:
-            even_elems = operator(
-                odd_elems,
-                [e[:, 2::2] for e in elems])
-
-        # The first element of a scan is the same as the first element
-        # of the original `elems`.
-
-        even_elems = [
-            torch.cat([elem[:, :1], result], dim=1)
-            for (elem, result) in zip(elems, even_elems)]
-
-        return list(map(_interleave, even_elems, odd_elems))
-
-    return _scan(elems)
-
-def _interleave(a, b):
-    a_axis_len, b_axis_len = a.shape[1], b.shape[1]
-    output_axis_len = a_axis_len + b_axis_len
-
-    if (a_axis_len == (b_axis_len + 1)):
-        b = pad_at_dim(b, (0, 1), dim = 1)
-
-    stacked = torch.stack([a, b], dim=2)
-    interleaved = torch.flatten(stacked, start_dim=1, end_dim=2)
-
-    return interleaved[:, :output_axis_len]
-
-# associative scan wrapper around naive and accelerated version
-
-class AssocScan(Module):
-    def __init__(
-        self,
-        use_accelerated = False
-    ):
-        super().__init__()
-        self.use_accelerated = use_accelerated
-
-    def forward(
-        self,
-        gates,
-        inputs,
-        prev = None,
-        remove_prev = None
-    ):
-        remove_prev = default(remove_prev, exists(prev))
-
-        inputs, inverse_pack_weight_shape = pack_one_with_inverse(inputs, 'b n *')
-        gates, _ = pack_one_with_inverse(gates, 'b n *')
-
-        if exists(prev):
-            prev, _ = pack_one_with_inverse(prev, 'b *')
-
-        if exists(prev):
-            inputs, _ = pack([prev, inputs], 'b * d')
-            gates = pad_at_dim(gates, (1, 0), value = 1., dim = -2)
-
-        if not self.use_accelerated:
-            _, out = associative_scan(binary_operator, (gates, inputs))
-
-            if remove_prev:
-                out = out[:, 1:]
-
-            return inverse_pack_weight_shape(out)
-
-        from accelerated_scan.triton import scan as triton_scan
-        from accelerated_scan.warp import scan as warp_scan
-
-        scan = triton_scan if gates.is_cuda else warp_scan
-
-        def accelerate_scan_fn(gates, inputs):
-            gates = gates.expand_as(inputs)
-            gates, inputs = tuple(rearrange(t, 'b n d -> b d n') for t in (gates, inputs))
-
-            seq_len = gates.shape[-1]
-            next_power_two_seq_len = 2 ** max(5, int(math.ceil(math.log2(seq_len))))
-
-            gates = F.pad(gates, (0, next_power_two_seq_len - seq_len))
-            inputs = F.pad(inputs, (0, next_power_two_seq_len - seq_len))
-
-            outputs = scan(gates.contiguous(), inputs.contiguous())
-
-            outputs = outputs[..., :seq_len]
-            outputs = rearrange(outputs, 'b d n -> b n d')
-
-            return outputs
-
-        out = accelerate_scan_fn(gates, inputs)
-
-        if remove_prev:
-            out = out[:, 1:]
-
-        return inverse_pack_weight_shape(out)
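The deleted module does not remove functionality: the same `AssocScan` interface is now pulled in from the external `assoc-scan` dependency added in the METADATA change above. A minimal usage sketch, assuming the external package keeps the constructor and `forward` signature of the removed file:

```python
# assumption: the assoc-scan package mirrors the removed module's interface
import torch
from assoc_scan import AssocScan

scan = AssocScan(use_accelerated = False)

b, n, d = 2, 16, 32
gates = torch.rand(b, n, 1)      # per-step decay gates in [0, 1]
inputs = torch.randn(b, n, d)    # per-step contributions

out = scan(gates, inputs)        # out[:, t] = gates[:, t] * out[:, t - 1] + inputs[:, t]
assert out.shape == inputs.shape
```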
titans_pytorch-0.4.6.dist-info/RECORD
DELETED
@@ -1,9 +0,0 @@
-titans_pytorch/__init__.py,sha256=sVTOuRUkaIYabFExdLY6s1qXm1UwHHz_J19H8ZV-X74,338
-titans_pytorch/associative_scan.py,sha256=esaLbukFlgvy2aqopsqBy6KEcZ64B3rsNhG8moKdPSc,5159
-titans_pytorch/mac_transformer.py,sha256=tz72141G5t3AOnxSVsOLtLptGtl8T7zROUvaTw2_XCY,26960
-titans_pytorch/memory_models.py,sha256=wnH9i9kUSoVZhEWUlj8LpBSbB400L9kLt1zP8CO45QQ,5835
-titans_pytorch/neural_memory.py,sha256=CsB-Wd3T_N_HluRLjjwB2gPdlLymJmmhuxiRJhQuXkA,33377
-titans_pytorch-0.4.6.dist-info/METADATA,sha256=aWhQMQrjBLzUPmvtH-LY47r4ayz_ts70gGyDMMyJ6Sc,6810
-titans_pytorch-0.4.6.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.4.6.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.4.6.dist-info/RECORD,,
{titans_pytorch-0.4.6.dist-info → titans_pytorch-0.4.8.dist-info}/WHEEL
File without changes

{titans_pytorch-0.4.6.dist-info → titans_pytorch-0.4.8.dist-info}/licenses/LICENSE
File without changes