titans-pytorch 0.0.61__tar.gz → 0.0.63__tar.gz
This diff shows the changes between two publicly released versions of this package, as they appear in their public registry. It is provided for informational purposes only.
- {titans_pytorch-0.0.61 → titans_pytorch-0.0.63}/PKG-INFO +1 -1
- titans_pytorch-0.0.63/assert_flex.py +16 -0
- {titans_pytorch-0.0.61 → titans_pytorch-0.0.63}/pyproject.toml +1 -1
- {titans_pytorch-0.0.61 → titans_pytorch-0.0.63}/titans_pytorch/__init__.py +2 -1
- {titans_pytorch-0.0.61 → titans_pytorch-0.0.63}/titans_pytorch/titans.py +51 -6
- {titans_pytorch-0.0.61 → titans_pytorch-0.0.63}/.github/workflows/python-publish.yml +0 -0
- {titans_pytorch-0.0.61 → titans_pytorch-0.0.63}/.github/workflows/test.yaml +0 -0
- {titans_pytorch-0.0.61 → titans_pytorch-0.0.63}/.gitignore +0 -0
- {titans_pytorch-0.0.61 → titans_pytorch-0.0.63}/LICENSE +0 -0
- {titans_pytorch-0.0.61 → titans_pytorch-0.0.63}/README.md +0 -0
- {titans_pytorch-0.0.61 → titans_pytorch-0.0.63}/data/README.md +0 -0
- {titans_pytorch-0.0.61 → titans_pytorch-0.0.63}/data/enwik8.gz +0 -0
- {titans_pytorch-0.0.61 → titans_pytorch-0.0.63}/fig1.png +0 -0
- {titans_pytorch-0.0.61 → titans_pytorch-0.0.63}/fig2.png +0 -0
- {titans_pytorch-0.0.61 → titans_pytorch-0.0.63}/tests/test_titans.py +0 -0
- {titans_pytorch-0.0.61 → titans_pytorch-0.0.63}/titans_pytorch/associative_scan.py +0 -0
- {titans_pytorch-0.0.61 → titans_pytorch-0.0.63}/titans_pytorch/mac_transformer.py +0 -0
- {titans_pytorch-0.0.61 → titans_pytorch-0.0.63}/train_mac.py +0 -0
titans_pytorch-0.0.63/assert_flex.py

@@ -0,0 +1,16 @@
+import torch
+from titans_pytorch.mac_transformer import SegmentedAttention
+
+attn = SegmentedAttention(
+    dim = 512,
+    segment_len = 32,
+    num_persist_mem_tokens = 1,
+    use_flex_attn = True
+).cuda()
+
+seq = torch.randn(1, 1019, 512).cuda()
+
+out_flex, _ = attn(seq)
+out_non_flex, _ = attn(seq, disable_flex_attn = True)
+
+assert torch.allclose(out_flex, out_non_flex, atol = 1e-6)
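The new assert_flex.py is a quick parity check: it runs SegmentedAttention once with flex attention enabled and once with it disabled on the same 1019-token sequence, and asserts the two outputs match. Below is a hedged sketch of the same check wrapped as a pytest test; the CUDA skip guard and the test function name are assumptions for illustration, not something shipped with the package.

# hypothetical pytest wrapper around the parity check above; the skip guard
# and test name are assumptions, not part of titans-pytorch itself
import pytest
import torch
from titans_pytorch.mac_transformer import SegmentedAttention

@pytest.mark.skipif(not torch.cuda.is_available(), reason = 'flex attention path requires CUDA')
def test_flex_attn_parity():
    attn = SegmentedAttention(
        dim = 512,
        segment_len = 32,
        num_persist_mem_tokens = 1,
        use_flex_attn = True
    ).cuda()

    seq = torch.randn(1, 1019, 512).cuda()

    out_flex, _ = attn(seq)                               # flex attention path
    out_non_flex, _ = attn(seq, disable_flex_attn = True) # regular attention path

    assert torch.allclose(out_flex, out_non_flex, atol = 1e-6)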
{titans_pytorch-0.0.61 → titans_pytorch-0.0.63}/titans_pytorch/titans.py

@@ -6,7 +6,7 @@ from functools import partial
 import torch
 from torch import nn, Tensor
 import torch.nn.functional as F
-from torch.nn import Linear, Module
+from torch.nn import Linear, Module, Parameter, ParameterList
 from torch.func import functional_call, vmap, grad

 from tensordict import TensorDict
@@ -88,7 +88,7 @@ class MultiheadRMSNorm(Module):
     def __init__(self, dim, heads):
         super().__init__()
         self.rmsnorm = nn.RMSNorm(dim, elementwise_affine = False)
-        self.gamma =
+        self.gamma = Parameter(torch.zeros(heads, 1, dim))

     def forward(self, x):
         return self.rmsnorm(x) * (self.gamma + 1.)
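Because gamma is zero-initialized and applied as (self.gamma + 1.), the per-head gain starts at exactly 1, so the module reduces to a plain non-affine RMSNorm at initialization. A minimal sketch of that property, re-declaring the class locally for illustration rather than importing it from the package:

# sketch: a zero-initialized gamma used as (gamma + 1.) leaves the output
# identical to the underlying RMSNorm at init; class copied for illustration
import torch
from torch import nn
from torch.nn import Module, Parameter

class MultiheadRMSNorm(Module):
    def __init__(self, dim, heads):
        super().__init__()
        self.rmsnorm = nn.RMSNorm(dim, elementwise_affine = False)
        self.gamma = Parameter(torch.zeros(heads, 1, dim))

    def forward(self, x):
        return self.rmsnorm(x) * (self.gamma + 1.)

norm = MultiheadRMSNorm(dim = 64, heads = 4)
x = torch.randn(2, 4, 16, 64)  # (batch, heads, seq, dim)

assert torch.allclose(norm(x), norm.rmsnorm(x))  # unit gain at init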
@@ -102,7 +102,10 @@ class MemoryMLP(Module):
         depth
     ):
         super().__init__()
-        self.weights =
+        self.weights = ParameterList([Parameter(torch.randn(dim, dim)) for _ in range(depth)])
+
+        for weight in self.weights:
+            nn.init.xavier_uniform_(weight)

     def forward(
         self,
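A quick sanity-check sketch of the new MemoryMLP constructor, assuming the class is importable from titans_pytorch.titans as laid out in this diff: the ParameterList registers one xavier-initialized (dim, dim) weight per layer, and the forward pass preserves the input shape.

# sketch: the ParameterList registers depth (dim, dim) weights, initialized
# in place by nn.init.xavier_uniform_; forward keeps the (batch, seq, dim) shape
import torch
from titans_pytorch.titans import MemoryMLP

mlp = MemoryMLP(dim = 64, depth = 2)

assert len(mlp.weights) == 2
assert all(w.shape == (64, 64) for w in mlp.weights)

out = mlp(torch.randn(1, 8, 64))
assert out.shape == (1, 8, 64)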
@@ -118,25 +121,66 @@ class MemoryMLP(Module):

         return x

+# memory mlp with factorized weights
+# so can tradeoff capacity for smaller chunk sizes
+
+class FactorizedMemoryMLP(Module):
+    def __init__(
+        self,
+        dim,
+        depth,
+        k = 32
+    ):
+        super().__init__()
+        self.weights = ParameterList([
+            ParameterList([
+                Parameter(torch.randn(dim, k)),
+                Parameter(torch.randn(k, dim)),
+            ]) for _ in range(depth)
+        ])
+
+        for weight1, weight2 in self.weights:
+            nn.init.xavier_uniform_(weight1)
+            nn.init.xavier_uniform_(weight2)
+
+    def forward(
+        self,
+        x
+    ):
+        for ind, (weight1, weight2) in enumerate(self.weights):
+            is_first = ind == 0
+
+            if not is_first:
+                x = F.silu(x)
+
+            x = x @ weight1 @ weight2
+
+        return x
+
 # improvised attention as memory module

 class MemoryAttention(Module):
     def __init__(
         self,
         dim,
-        scale = 8
+        scale = 8.,
+        expansion_factor = 2.
     ):
         super().__init__()
         self.scale = scale
+        dim_ff_hidden = int(dim * expansion_factor)

         self.weights = nn.ParameterList([
             nn.Parameter(torch.randn(dim, dim)), # queries
             nn.Parameter(torch.randn(dim, dim)), # keys
             nn.Parameter(torch.randn(dim, dim)), # values
-            nn.Parameter(torch.randn(dim,
-            nn.Parameter(torch.randn(
+            nn.Parameter(torch.randn(dim, dim_ff_hidden)), # ff w1
+            nn.Parameter(torch.randn(dim_ff_hidden, dim)), # ff w2
         ])

+        for weight in self.weights:
+            nn.init.xavier_uniform_(weight)
+
     def forward(self, x):
         wq, wk, wv, ffw1, ffw2 = self.weights

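The comment on the new class states the motivation: each FactorizedMemoryMLP layer stores two thin matrices (dim x k and k x dim) instead of one dim x dim matrix, trading memory capacity for a smaller parameter footprint per chunk. A hedged sketch comparing the two parameter counts, assuming both classes are imported from titans_pytorch.titans as shown in this diff:

# per-layer parameters: dim*dim for MemoryMLP vs 2*dim*k for the factorized
# variant, e.g. 512*512 = 262144 vs 2*512*32 = 32768 per layer at k = 32
import torch
from titans_pytorch.titans import MemoryMLP, FactorizedMemoryMLP

dim, depth, k = 512, 2, 32

dense = MemoryMLP(dim = dim, depth = depth)
factorized = FactorizedMemoryMLP(dim = dim, depth = depth, k = k)

dense_params = sum(p.numel() for p in dense.parameters())
factorized_params = sum(p.numel() for p in factorized.parameters())

assert dense_params == depth * dim * dim          # 524288
assert factorized_params == depth * 2 * dim * k   # 65536

# both act as drop-in memory models over (batch, seq, dim) activations
x = torch.randn(1, 16, dim)
assert dense(x).shape == factorized(x).shape == x.shape

On the MemoryAttention side, the feedforward weights are now sized from expansion_factor, so dim_ff_hidden = int(dim * expansion_factor), i.e. 1024 for dim = 512 with the default factor of 2.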
@@ -536,6 +580,7 @@ class NeuralMemory(Module):

         past_weights, _ = past_state

+
         retrieved = self.retrieve_memories(seq, past_weights + updates)

         if not return_aux_kv_loss:
All other files listed above (+0 -0) are unchanged between the two versions.