titans-pytorch 0.0.22-py3-none-any.whl → 0.0.24-py3-none-any.whl
- titans_pytorch/titans.py +41 -19
- {titans_pytorch-0.0.22.dist-info → titans_pytorch-0.0.24.dist-info}/METADATA +1 -1
- titans_pytorch-0.0.24.dist-info/RECORD +8 -0
- titans_pytorch-0.0.22.dist-info/RECORD +0 -8
- {titans_pytorch-0.0.22.dist-info → titans_pytorch-0.0.24.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.0.22.dist-info → titans_pytorch-0.0.24.dist-info}/licenses/LICENSE +0 -0
titans_pytorch/titans.py
CHANGED
@@ -73,6 +73,17 @@ def softclamp_grad_norm(t, max_value):
     t = t * (clamped_norm / norm)
     return inverse(t)
 
+# multi head rmsnorm
+
+class MultiheadRMSNorm(Module):
+    def __init__(self, dim, heads):
+        super().__init__()
+        self.rmsnorm = nn.RMSNorm(dim, elementwise_affine = False)
+        self.gamma = nn.Parameter(torch.zeros(heads, 1, dim))
+
+    def forward(self, x):
+        return self.rmsnorm(x) * (self.gamma + 1.)
+
 # classes
 
 class MemoryMLP(Module):
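
The new MultiheadRMSNorm applies a shared, affine-free RMSNorm followed by a zero-initialized per-head scale, so at initialization the `(self.gamma + 1.)` factor is exactly 1 and the module reduces to a plain RMSNorm. A minimal sketch of that property (the `(batch, heads, seq, dim_head)` layout is an assumption):

    import torch
    from torch import nn

    # mirror of the class added above
    class MultiheadRMSNorm(nn.Module):
        def __init__(self, dim, heads):
            super().__init__()
            self.rmsnorm = nn.RMSNorm(dim, elementwise_affine = False)
            self.gamma = nn.Parameter(torch.zeros(heads, 1, dim))

        def forward(self, x):
            return self.rmsnorm(x) * (self.gamma + 1.)

    norm = MultiheadRMSNorm(dim = 16, heads = 4)
    x = torch.randn(2, 4, 32, 16)  # assumed (batch, heads, seq, dim_head)
    # gamma starts at zero, so the learned scale only departs from 1 during training
    assert torch.allclose(norm(x), nn.RMSNorm(16, elementwise_affine = False)(x))
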
@@ -100,11 +111,11 @@ class MemoryMLP(Module):
 
 # main neural memory
 
-def default_adaptive_step_transform(adaptive_step, max_lr = 1e-
+def default_adaptive_step_transform(adaptive_step, max_lr = 1e-2):
     return adaptive_step.sigmoid() * max_lr
 
 def default_loss_fn(pred, target):
-    return (pred - target).pow(2).mean(dim = -1)
+    return (pred - target).pow(2).mean(dim = -1)
 
 class NeuralMemory(Module):
     def __init__(
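
default_adaptive_step_transform squashes an unbounded per-token logit into a learning rate in (0, max_lr) via a sigmoid. A quick sketch of the range it produces:

    import torch

    def default_adaptive_step_transform(adaptive_step, max_lr = 1e-2):
        return adaptive_step.sigmoid() * max_lr

    logits = torch.tensor([-10., 0., 10.])
    print(default_adaptive_step_transform(logits))
    # ≈ tensor([0.0000, 0.0050, 0.0100]) — saturates at 0 and at max_lr
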
@@ -131,17 +142,25 @@ class NeuralMemory(Module):
         self.retrieve_norm = nn.RMSNorm(dim) if pre_rmsnorm else nn.Identity()
         self.store_norm = nn.RMSNorm(dim) if pre_rmsnorm else nn.Identity()
 
-        self.post_rmsnorm = nn.RMSNorm(dim) if post_rmsnorm else nn.Identity()
+        self.multihead_rmsnorm = MultiheadRMSNorm(dim_head, heads) if post_rmsnorm else nn.Identity()
 
         # maybe multi-headed
 
         dim_head = default(dim_head, dim)
         dim_inner = dim_head * heads
 
+        self.heads = heads
+
         self.split_heads = Rearrange('b n (h d) -> (b h) n d', h = heads)
-        self.merge_heads = Rearrange('(b h) n d -> b n (h d)', h = heads)
+        self.merge_heads = Rearrange('b h n d -> b n (h d)')
         self.combine_heads = LinearNoBias(dim_inner, dim) if heads > 1 else nn.Identity()
 
+        self.retrieve_gate = nn.Sequential(
+            LinearNoBias(dim, heads),
+            Rearrange('b n h -> b h n 1'),
+            nn.Sigmoid()
+        ) if heads > 1 else None
+
         # memory mlp
 
         if not exists(model):
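
The new retrieve_gate projects each token to one sigmoid gate per head and reshapes it so it broadcasts over the retrieved values' head dimension. A shape-level sketch (LinearNoBias is assumed to be nn.Linear with bias = False, as elsewhere in the repo; sizes are arbitrary):

    import torch
    from torch import nn
    from einops.layers.torch import Rearrange

    dim, heads, dim_head = 64, 4, 16

    retrieve_gate = nn.Sequential(
        nn.Linear(dim, heads, bias = False),  # LinearNoBias stand-in
        Rearrange('b n h -> b h n 1'),
        nn.Sigmoid()
    )

    seq = torch.randn(2, 32, dim)
    values = torch.randn(2, heads, 32, dim_head)
    gated = values * retrieve_gate(seq)  # (b h n 1) broadcasts across dim_head
    print(gated.shape)                   # torch.Size([2, 4, 32, 16])
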
@@ -159,12 +178,13 @@ class NeuralMemory(Module):
 
         # prepare function for per sample gradients from model above, using torch.func
 
-        def forward_and_loss(params, inputs, target):
+        def forward_and_loss(params, inputs, loss_weights, target):
             pred = functional_call(self.memory_model, params, inputs)
             loss = self.store_memory_loss_fn(pred, target) # simple mse loss in paper - eq (12) - |M(k) - v|²
-            return loss
+            loss = loss * loss_weights
+            return loss.sum()
 
-        self.per_sample_grad_fn = vmap(grad(forward_and_loss), in_dims = (None, 0, 0))
+        self.per_sample_grad_fn = vmap(grad(forward_and_loss), in_dims = (None, 0, 0, 0))
 
         # queries for retrieving from the model
 
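
The signature change threads the per-token adaptive lr into the loss as a weight before summing, so the per-sample gradients come out pre-scaled. A self-contained sketch of the same torch.func pattern, with a plain nn.Linear standing in for the memory model (an assumption for brevity):

    import torch
    from torch import nn
    from torch.func import functional_call, grad, vmap

    model = nn.Linear(8, 8, bias = False)
    params = dict(model.named_parameters())

    def forward_and_loss(params, inputs, loss_weights, target):
        pred = functional_call(model, params, inputs)
        loss = (pred - target).pow(2).mean(dim = -1)  # per-token loss
        loss = loss * loss_weights                    # adaptive lr folded in as a weight
        return loss.sum()

    # in_dims = (None, 0, 0, 0): shared params, batched keys / lrs / values
    per_sample_grad_fn = vmap(grad(forward_and_loss), in_dims = (None, 0, 0, 0))

    keys   = torch.randn(4, 16, 8)  # (batched chunks, chunk length, dim)
    lrs    = torch.rand(4, 16)
    values = torch.randn(4, 16, 8)
    grads = per_sample_grad_fn(params, keys, lrs, values)
    print(grads['weight'].shape)    # torch.Size([4, 8, 8]) — one gradient per chunk
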
@@ -190,7 +210,6 @@ class NeuralMemory(Module):
         )
 
         self.to_adaptive_step = nn.Sequential(
-            Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size),
             LinearNoBias(dim, heads),
             Rearrange('b n h -> (b h) n')
         )
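
Dropping the Reduce means to_adaptive_step now emits one step value per token rather than one mean-pooled value per chunk. For reference, this is what the removed Reduce did (sizes are arbitrary):

    import torch
    from einops.layers.torch import Reduce

    chunk_size = 4
    x = torch.randn(2, 8, 16)  # b, (n c), d
    pooled = Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size)(x)
    print(pooled.shape)        # torch.Size([2, 2, 16]) — one value per chunk, not per token
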
@@ -271,9 +290,11 @@ class NeuralMemory(Module):
 
         keys, values = tuple(rearrange(t, 'b (n c) d -> (b n) c d', c = self.chunk_size) for t in (keys, values))
 
+        adaptive_lr = rearrange(adaptive_lr, 'b (n c) -> (b n) c', c = self.chunk_size)
+
         # get grads and extra auxiliary loss (for backwarding through qkv projection in base neural memory module)
 
-        grads = self.per_sample_grad_fn(dict(curr_weights), keys, values)
+        grads = self.per_sample_grad_fn(dict(curr_weights), keys, adaptive_lr, values)
 
         grads = TensorDict(grads)
 
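
adaptive_lr is chunked with the same pattern as keys and values, so each vmapped gradient call receives the lr slab that matches its (chunk, dim) slab of keys. A shape sketch:

    import torch
    from einops import rearrange

    chunk_size = 4
    keys        = torch.randn(2, 8, 16)
    adaptive_lr = torch.rand(2, 8)

    keys        = rearrange(keys, 'b (n c) d -> (b n) c d', c = chunk_size)
    adaptive_lr = rearrange(adaptive_lr, 'b (n c) -> (b n) c', c = chunk_size)
    print(keys.shape, adaptive_lr.shape)  # torch.Size([4, 4, 16]) torch.Size([4, 4])
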
@@ -286,9 +307,9 @@ class NeuralMemory(Module):
 
         grads = grads.apply(lambda t: rearrange(t, '(b n) ... -> b n ...', b = batch))
 
-        #
+        # negative gradients, adaptive lr already applied as loss weight
 
-        surprises = grads.apply(lambda t:
+        surprises = grads.apply(lambda t: -t)
 
         # determine scan function
 
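
Because the adaptive lr is already multiplied into the loss before the gradient is taken, the gradient itself comes out pre-scaled and the surprise reduces to plain negation. A one-tensor sanity check of that identity:

    import torch

    w = torch.rand(4)                         # per-token loss weights (the adaptive lr)
    x = torch.randn(4, requires_grad = True)
    (w * x.pow(2)).sum().backward()
    # d/dx sum(w * x^2) = w * 2x — the weight rides along in the gradient
    assert torch.allclose(x.grad, w * 2 * x)
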
@@ -356,7 +377,7 @@ class NeuralMemory(Module):
         past_weights: dict[str, Tensor] | None = None,
     ):
         chunk_size = self.chunk_size
-        seq_len = seq.shape[
+        batch, seq_len = seq.shape[:2]
 
         seq = self.retrieve_norm(seq)
 
@@ -390,8 +411,6 @@ class NeuralMemory(Module):
 
         queries = self.split_heads(queries)
 
-        batch = queries.shape[0]
-
         # fetch values from memory model
 
         curr_weights = curr_weights.apply(lambda t: rearrange(t, 'b n ... -> (b n) ...'))
@@ -403,7 +422,14 @@ class NeuralMemory(Module):
 
         # reconstitute batch dimension
 
-        values = rearrange(values, '(b n) c d -> b (n c) d', b = batch)
+        values = rearrange(values, '(b h n) c d -> b h (n c) d', b = batch, h = self.heads)
+
+        values = self.multihead_rmsnorm(values)
+
+        # maybe gate
+
+        if exists(self.retrieve_gate):
+            values = values * self.retrieve_gate(seq)
 
         # maybe merge heads and combine
 
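
The retrieved values are now regrouped with an explicit head axis so the per-head norm and gate can act before heads are merged. A shape-only walkthrough of that tail (sizes are arbitrary):

    import torch
    from einops import rearrange

    b, h, n, c, d = 2, 4, 2, 8, 16         # batch, heads, chunks, chunk size, dim_head
    values = torch.randn(b * h * n, c, d)  # one slab per (batch, head, chunk)

    values = rearrange(values, '(b h n) c d -> b h (n c) d', b = b, h = h)
    print(values.shape)                    # torch.Size([2, 4, 16, 16])

    # MultiheadRMSNorm and the retrieve gate (sketched earlier) slot in here,
    # then merge_heads flattens the head axis back out:
    values = rearrange(values, 'b h n d -> b n (h d)')
    print(values.shape)                    # torch.Size([2, 16, 64])
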
@@ -411,10 +437,6 @@ class NeuralMemory(Module):
 
         values = self.combine_heads(values)
 
-        # post norm, somehow could not stabilize this without it, not in paper
-
-        values = self.post_rmsnorm(values)
-
         # restore, pad with empty memory embed
 
         empty_memory_embeds = self.init_empty_memory_embed(values.shape[0], chunk_size - 1)

titans_pytorch-0.0.24.dist-info/RECORD
ADDED
@@ -0,0 +1,8 @@
+titans_pytorch/__init__.py,sha256=nB0873FZ_OyCda3qFeWTdpO4LbbrXDEvtAefVLzh6g0,71
+titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+titans_pytorch/titans.py,sha256=3UxRZl_uwQBly11jQAWjfnNzHSoOUKiw-Ux2lXu2ilI,14304
+titans_pytorch/titans_attn_memory.py,sha256=Rwx_-riGeISBefZg5Kjic8jzmmRRys-u93D2Kgb7Mos,12691
+titans_pytorch-0.0.24.dist-info/METADATA,sha256=WGxo4oVx9HCq7LvSH8u_isp1tjxVXb3Ao_GrgjdFzSo,3811
+titans_pytorch-0.0.24.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.0.24.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.0.24.dist-info/RECORD,,

titans_pytorch-0.0.22.dist-info/RECORD
DELETED
@@ -1,8 +0,0 @@
-titans_pytorch/__init__.py,sha256=nB0873FZ_OyCda3qFeWTdpO4LbbrXDEvtAefVLzh6g0,71
-titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/titans.py,sha256=s7KVqId7hhHw1Ck77FJUXfKC5rpwS4N7Nw2mKFtP8-s,13677
-titans_pytorch/titans_attn_memory.py,sha256=Rwx_-riGeISBefZg5Kjic8jzmmRRys-u93D2Kgb7Mos,12691
-titans_pytorch-0.0.22.dist-info/METADATA,sha256=acTS0vWh84sGtTIEBzdTsxNi-S4V-Cr8NL4U4Vg0eY0,3811
-titans_pytorch-0.0.22.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.0.22.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.0.22.dist-info/RECORD,,

{titans_pytorch-0.0.22.dist-info → titans_pytorch-0.0.24.dist-info}/WHEEL
File without changes

{titans_pytorch-0.0.22.dist-info → titans_pytorch-0.0.24.dist-info}/licenses/LICENSE
File without changes