titans-pytorch 0.0.22__py3-none-any.whl → 0.0.24__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- titans_pytorch/titans.py +41 -19
- {titans_pytorch-0.0.22.dist-info → titans_pytorch-0.0.24.dist-info}/METADATA +1 -1
- titans_pytorch-0.0.24.dist-info/RECORD +8 -0
- titans_pytorch-0.0.22.dist-info/RECORD +0 -8
- {titans_pytorch-0.0.22.dist-info → titans_pytorch-0.0.24.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.0.22.dist-info → titans_pytorch-0.0.24.dist-info}/licenses/LICENSE +0 -0
titans_pytorch/titans.py
CHANGED
@@ -73,6 +73,17 @@ def softclamp_grad_norm(t, max_value):
     t = t * (clamped_norm / norm)
     return inverse(t)
 
+# multi head rmsnorm
+
+class MultiheadRMSNorm(Module):
+    def __init__(self, dim, heads):
+        super().__init__()
+        self.rmsnorm = nn.RMSNorm(dim, elementwise_affine = False)
+        self.gamma = nn.Parameter(torch.zeros(heads, 1, dim))
+
+    def forward(self, x):
+        return self.rmsnorm(x) * (self.gamma + 1.)
+
 # classes
 
 class MemoryMLP(Module):
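For reference, a minimal standalone sketch of how the new per-head norm behaves. The class body mirrors the diff above; the toy dimensions and the use of `nn.Module` in place of the package's `Module` import are assumptions made only to keep the snippet self-contained (`nn.RMSNorm` requires torch >= 2.4).

```python
# Sketch, not package code: per-head RMSNorm with a zero-initialized gain so (gamma + 1) starts at identity.
import torch
from torch import nn

class MultiheadRMSNorm(nn.Module):
    def __init__(self, dim, heads):
        super().__init__()
        self.rmsnorm = nn.RMSNorm(dim, elementwise_affine = False)
        self.gamma = nn.Parameter(torch.zeros(heads, 1, dim))  # one gain per head, broadcast over batch and sequence

    def forward(self, x):
        # x: (batch, heads, seq, dim)
        return self.rmsnorm(x) * (self.gamma + 1.)

norm = MultiheadRMSNorm(dim = 64, heads = 4)   # toy sizes, assumed
out = norm(torch.randn(2, 4, 128, 64))
print(out.shape)  # torch.Size([2, 4, 128, 64])
```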
@@ -100,11 +111,11 @@ class MemoryMLP(Module):
 
 # main neural memory
 
-def default_adaptive_step_transform(adaptive_step, max_lr = 1e-
+def default_adaptive_step_transform(adaptive_step, max_lr = 1e-2):
     return adaptive_step.sigmoid() * max_lr
 
 def default_loss_fn(pred, target):
-    return (pred - target).pow(2).mean(dim = -1)
+    return (pred - target).pow(2).mean(dim = -1)
 
 class NeuralMemory(Module):
     def __init__(
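The default adaptive step transform squashes each raw per-token projection through a sigmoid and scales by `max_lr`, so every learned step size stays in (0, max_lr). A quick illustrative check with arbitrary logits (toy values, not package code):

```python
import torch

raw = torch.tensor([-4., 0., 4.])   # arbitrary pre-activation step logits
print(torch.sigmoid(raw) * 1e-2)    # roughly tensor([0.0002, 0.0050, 0.0098]) - bounded above by max_lr = 1e-2
```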
@@ -131,17 +142,25 @@ class NeuralMemory(Module):
         self.retrieve_norm = nn.RMSNorm(dim) if pre_rmsnorm else nn.Identity()
         self.store_norm = nn.RMSNorm(dim) if pre_rmsnorm else nn.Identity()
 
-        self.
+        self.multihead_rmsnorm = MultiheadRMSNorm(dim_head, heads) if post_rmsnorm else nn.Identity()
 
         # maybe multi-headed
 
         dim_head = default(dim_head, dim)
         dim_inner = dim_head * heads
 
+        self.heads = heads
+
         self.split_heads = Rearrange('b n (h d) -> (b h) n d', h = heads)
-        self.merge_heads = Rearrange('
+        self.merge_heads = Rearrange('b h n d -> b n (h d)')
         self.combine_heads = LinearNoBias(dim_inner, dim) if heads > 1 else nn.Identity()
 
+        self.retrieve_gate = nn.Sequential(
+            LinearNoBias(dim, heads),
+            Rearrange('b n h -> b h n 1'),
+            nn.Sigmoid()
+        ) if heads > 1 else None
+
         # memory mlp
 
         if not exists(model):
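A rough shape walk-through of the new per-head retrieval gate: each token is projected to one logit per head, moved ahead of the sequence axis with a trailing broadcast dim, and squashed to (0, 1). In this standalone sketch `nn.Linear(..., bias = False)` stands in for the package's `LinearNoBias`, and the sizes are arbitrary assumptions.

```python
import torch
from torch import nn
from einops.layers.torch import Rearrange

dim, heads, seq = 64, 4, 128   # toy sizes, assumed

retrieve_gate = nn.Sequential(
    nn.Linear(dim, heads, bias = False),   # one gate logit per head per token
    Rearrange('b n h -> b h n 1'),         # heads ahead of sequence, singleton dim for broadcasting
    nn.Sigmoid()                           # gate values in (0, 1)
)

seq_input = torch.randn(2, seq, dim)        # normed input sequence
values = torch.randn(2, heads, seq, dim)    # retrieved per-head memories
gated = values * retrieve_gate(seq_input)   # (b, h, n, 1) broadcasts across the feature dim
print(gated.shape)  # torch.Size([2, 4, 128, 64])
```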
@@ -159,12 +178,13 @@ class NeuralMemory(Module):
 
         # prepare function for per sample gradients from model above, using torch.func
 
-        def forward_and_loss(params, inputs, target):
+        def forward_and_loss(params, inputs, loss_weights, target):
             pred = functional_call(self.memory_model, params, inputs)
             loss = self.store_memory_loss_fn(pred, target) # simple mse loss in paper - eq (12) - |M(k) - v|²
-
+            loss = loss * loss_weights
+            return loss.sum()
 
-        self.per_sample_grad_fn = vmap(grad(forward_and_loss), in_dims = (None, 0, 0))
+        self.per_sample_grad_fn = vmap(grad(forward_and_loss), in_dims = (None, 0, 0, 0))
 
         # queries for retrieving from the model
 
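A toy illustration of this per-sample gradient setup with `torch.func`: `grad` differentiates the weighted, summed loss with respect to the parameter dict, and `vmap` with `in_dims = (None, 0, 0, 0)` shares the parameters while mapping over the leading chunk dimension of inputs, loss weights, and targets. The linear model and all shapes here are assumptions for a self-contained sketch, not the package's memory model.

```python
import torch
from torch.func import functional_call, grad, vmap

model = torch.nn.Linear(8, 8, bias = False)     # stand-in for the memory MLP
params = dict(model.named_parameters())

def forward_and_loss(params, inputs, loss_weights, target):
    pred = functional_call(model, params, inputs)
    loss = (pred - target).pow(2).mean(dim = -1)   # per-token loss, shape (chunk_size,)
    loss = loss * loss_weights                     # adaptive lr folded in as a per-token loss weight
    return loss.sum()

per_sample_grad_fn = vmap(grad(forward_and_loss), in_dims = (None, 0, 0, 0))

keys    = torch.randn(6, 4, 8)   # (batch * chunks, chunk_size, dim)
values  = torch.randn(6, 4, 8)
weights = torch.rand(6, 4)       # per-token adaptive step sizes

grads = per_sample_grad_fn(params, keys, weights, values)
print(grads['weight'].shape)     # torch.Size([6, 8, 8]) - one gradient per chunk
```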
@@ -190,7 +210,6 @@ class NeuralMemory(Module):
         )
 
         self.to_adaptive_step = nn.Sequential(
-            Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size),
             LinearNoBias(dim, heads),
             Rearrange('b n h -> (b h) n')
         )
@@ -271,9 +290,11 @@ class NeuralMemory(Module):
 
         keys, values = tuple(rearrange(t, 'b (n c) d -> (b n) c d', c = self.chunk_size) for t in (keys, values))
 
+        adaptive_lr = rearrange(adaptive_lr, 'b (n c) -> (b n) c', c = self.chunk_size)
+
         # get grads and extra auxiliary loss (for backwarding through qkv projection in base neural memory module)
 
-        grads = self.per_sample_grad_fn(dict(curr_weights), keys, values)
+        grads = self.per_sample_grad_fn(dict(curr_weights), keys, adaptive_lr, values)
 
         grads = TensorDict(grads)
 
@@ -286,9 +307,9 @@ class NeuralMemory(Module):
 
         grads = grads.apply(lambda t: rearrange(t, '(b n) ... -> b n ...', b = batch))
 
-        #
+        # negative gradients, adaptive lr already applied as loss weight
 
-        surprises = grads.apply(lambda t:
+        surprises = grads.apply(lambda t: -t)
 
         # determine scan function
 
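Because the adaptive lr now enters as a multiplicative loss weight, the resulting gradient is already scaled by it (the derivative of w·ℓ with respect to θ is w·dℓ/dθ), so the surprise reduces to the plain negated gradient. A tiny autograd check of that equivalence, on toy tensors rather than package code:

```python
import torch

theta = torch.randn(4, requires_grad = True)
x, w = torch.randn(4), torch.tensor(0.3)

(w * (theta * x).pow(2).sum()).backward()      # weight applied to the loss
g_weighted = theta.grad.clone()

theta.grad = None
((theta * x).pow(2).sum()).backward()          # unweighted loss
assert torch.allclose(g_weighted, w * theta.grad)   # same as scaling the gradient afterwards
```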
@@ -356,7 +377,7 @@ class NeuralMemory(Module):
         past_weights: dict[str, Tensor] | None = None,
     ):
         chunk_size = self.chunk_size
-        seq_len = seq.shape[
+        batch, seq_len = seq.shape[:2]
 
         seq = self.retrieve_norm(seq)
 
@@ -390,8 +411,6 @@ class NeuralMemory(Module):
 
         queries = self.split_heads(queries)
 
-        batch = queries.shape[0]
-
         # fetch values from memory model
 
         curr_weights = curr_weights.apply(lambda t: rearrange(t, 'b n ... -> (b n) ...'))
@@ -403,7 +422,14 @@ class NeuralMemory(Module):
 
         # reconstitute batch dimension
 
-        values = rearrange(values, '(b n) c d -> b (n c) d', b = batch)
+        values = rearrange(values, '(b h n) c d -> b h (n c) d', b = batch, h = self.heads)
+
+        values = self.multihead_rmsnorm(values)
+
+        # maybe gate
+
+        if exists(self.retrieve_gate):
+            values = values * self.retrieve_gate(seq)
 
         # maybe merge heads and combine
 
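A shape walk-through of the retrieval tail, using the rearrange patterns from the hunks above on toy tensors (the head count and dimensions are arbitrary assumptions):

```python
import torch
from einops import rearrange

batch, heads, chunks, chunk_size, dim_head = 2, 4, 8, 16, 32   # toy sizes, assumed

# memory model output, flattened over (batch, heads, chunks)
values = torch.randn(batch * heads * chunks, chunk_size, dim_head)

values = rearrange(values, '(b h n) c d -> b h (n c) d', b = batch, h = heads)
print(values.shape)  # torch.Size([2, 4, 128, 32]) - per-head sequences, ready for MultiheadRMSNorm and the gate

values = rearrange(values, 'b h n d -> b n (h d)')   # merge_heads
print(values.shape)  # torch.Size([2, 128, 128])
```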
@@ -411,10 +437,6 @@ class NeuralMemory(Module):
 
         values = self.combine_heads(values)
 
-        # post norm, somehow could not stabilize this without it, not in paper
-
-        values = self.post_rmsnorm(values)
-
         # restore, pad with empty memory embed
 
         empty_memory_embeds = self.init_empty_memory_embed(values.shape[0], chunk_size - 1)
titans_pytorch-0.0.24.dist-info/RECORD
ADDED
@@ -0,0 +1,8 @@
+titans_pytorch/__init__.py,sha256=nB0873FZ_OyCda3qFeWTdpO4LbbrXDEvtAefVLzh6g0,71
+titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+titans_pytorch/titans.py,sha256=3UxRZl_uwQBly11jQAWjfnNzHSoOUKiw-Ux2lXu2ilI,14304
+titans_pytorch/titans_attn_memory.py,sha256=Rwx_-riGeISBefZg5Kjic8jzmmRRys-u93D2Kgb7Mos,12691
+titans_pytorch-0.0.24.dist-info/METADATA,sha256=WGxo4oVx9HCq7LvSH8u_isp1tjxVXb3Ao_GrgjdFzSo,3811
+titans_pytorch-0.0.24.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.0.24.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.0.24.dist-info/RECORD,,
titans_pytorch-0.0.22.dist-info/RECORD
DELETED
@@ -1,8 +0,0 @@
-titans_pytorch/__init__.py,sha256=nB0873FZ_OyCda3qFeWTdpO4LbbrXDEvtAefVLzh6g0,71
-titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/titans.py,sha256=s7KVqId7hhHw1Ck77FJUXfKC5rpwS4N7Nw2mKFtP8-s,13677
-titans_pytorch/titans_attn_memory.py,sha256=Rwx_-riGeISBefZg5Kjic8jzmmRRys-u93D2Kgb7Mos,12691
-titans_pytorch-0.0.22.dist-info/METADATA,sha256=acTS0vWh84sGtTIEBzdTsxNi-S4V-Cr8NL4U4Vg0eY0,3811
-titans_pytorch-0.0.22.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.0.22.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.0.22.dist-info/RECORD,,
{titans_pytorch-0.0.22.dist-info → titans_pytorch-0.0.24.dist-info}/WHEEL
File without changes

{titans_pytorch-0.0.22.dist-info → titans_pytorch-0.0.24.dist-info}/licenses/LICENSE
File without changes