titans-pytorch 0.0.23__py3-none-any.whl → 0.0.25__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- titans_pytorch/titans.py +21 -12
- {titans_pytorch-0.0.23.dist-info → titans_pytorch-0.0.25.dist-info}/METADATA +1 -1
- titans_pytorch-0.0.25.dist-info/RECORD +8 -0
- titans_pytorch-0.0.23.dist-info/RECORD +0 -8
- {titans_pytorch-0.0.23.dist-info → titans_pytorch-0.0.25.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.0.23.dist-info → titans_pytorch-0.0.25.dist-info}/licenses/LICENSE +0 -0
titans_pytorch/titans.py
CHANGED
@@ -73,6 +73,17 @@ def softclamp_grad_norm(t, max_value):
|
|
73
73
|
t = t * (clamped_norm / norm)
|
74
74
|
return inverse(t)
|
75
75
|
|
76
|
+
# multi head rmsnorm
|
77
|
+
|
78
|
+
class MultiheadRMSNorm(Module):
|
79
|
+
def __init__(self, dim, heads):
|
80
|
+
super().__init__()
|
81
|
+
self.rmsnorm = nn.RMSNorm(dim, elementwise_affine = False)
|
82
|
+
self.gamma = nn.Parameter(torch.zeros(heads, 1, dim))
|
83
|
+
|
84
|
+
def forward(self, x):
|
85
|
+
return self.rmsnorm(x) * (self.gamma + 1.)
|
86
|
+
|
76
87
|
# classes
|
77
88
|
|
78
89
|
class MemoryMLP(Module):
|
@@ -125,26 +136,28 @@ class NeuralMemory(Module):
|
|
125
136
|
)
|
126
137
|
):
|
127
138
|
super().__init__()
|
139
|
+
dim_head = default(dim_head, dim)
|
128
140
|
|
129
141
|
# norms
|
130
142
|
|
131
143
|
self.retrieve_norm = nn.RMSNorm(dim) if pre_rmsnorm else nn.Identity()
|
132
144
|
self.store_norm = nn.RMSNorm(dim) if pre_rmsnorm else nn.Identity()
|
133
145
|
|
134
|
-
self.post_rmsnorm = nn.RMSNorm(dim) if post_rmsnorm else nn.Identity()
|
146
|
+
self.multihead_rmsnorm = MultiheadRMSNorm(dim_head, heads) if post_rmsnorm else nn.Identity()
|
135
147
|
|
136
148
|
# maybe multi-headed
|
137
149
|
|
138
|
-
dim_head = default(dim_head, dim)
|
139
150
|
dim_inner = dim_head * heads
|
140
151
|
|
152
|
+
self.heads = heads
|
153
|
+
|
141
154
|
self.split_heads = Rearrange('b n (h d) -> (b h) n d', h = heads)
|
142
|
-
self.merge_heads = Rearrange('(b h) n d -> b n (h d)', h = heads)
|
155
|
+
self.merge_heads = Rearrange('b h n d -> b n (h d)')
|
143
156
|
self.combine_heads = LinearNoBias(dim_inner, dim) if heads > 1 else nn.Identity()
|
144
157
|
|
145
158
|
self.retrieve_gate = nn.Sequential(
|
146
159
|
LinearNoBias(dim, heads),
|
147
|
-
Rearrange('b n h -> (b h) n 1'),
|
160
|
+
Rearrange('b n h -> b h n 1'),
|
148
161
|
nn.Sigmoid()
|
149
162
|
) if heads > 1 else None
|
150
163
|
|
@@ -364,7 +377,7 @@ class NeuralMemory(Module):
|
|
364
377
|
past_weights: dict[str, Tensor] | None = None,
|
365
378
|
):
|
366
379
|
chunk_size = self.chunk_size
|
367
|
-
seq_len = seq.shape[
|
380
|
+
batch, seq_len = seq.shape[:2]
|
368
381
|
|
369
382
|
seq = self.retrieve_norm(seq)
|
370
383
|
|
@@ -398,8 +411,6 @@ class NeuralMemory(Module):
|
|
398
411
|
|
399
412
|
queries = self.split_heads(queries)
|
400
413
|
|
401
|
-
batch = queries.shape[0]
|
402
|
-
|
403
414
|
# fetch values from memory model
|
404
415
|
|
405
416
|
curr_weights = curr_weights.apply(lambda t: rearrange(t, 'b n ... -> (b n) ...'))
|
@@ -411,7 +422,9 @@ class NeuralMemory(Module):
|
|
411
422
|
|
412
423
|
# reconstitute batch dimension
|
413
424
|
|
414
|
-
values = rearrange(values, '(b n) c d -> b (n c) d', b = batch)
|
425
|
+
values = rearrange(values, '(b h n) c d -> b h (n c) d', b = batch, h = self.heads)
|
426
|
+
|
427
|
+
values = self.multihead_rmsnorm(values)
|
415
428
|
|
416
429
|
# maybe gate
|
417
430
|
|
@@ -424,10 +437,6 @@ class NeuralMemory(Module):
|
|
424
437
|
|
425
438
|
values = self.combine_heads(values)
|
426
439
|
|
427
|
-
# post norm, somehow could not stabilize this without it, not in paper
|
428
|
-
|
429
|
-
values = self.post_rmsnorm(values)
|
430
|
-
|
431
440
|
# restore, pad with empty memory embed
|
432
441
|
|
433
442
|
empty_memory_embeds = self.init_empty_memory_embed(values.shape[0], chunk_size - 1)
|
@@ -0,0 +1,8 @@
|
|
1
|
+
titans_pytorch/__init__.py,sha256=nB0873FZ_OyCda3qFeWTdpO4LbbrXDEvtAefVLzh6g0,71
|
2
|
+
titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
|
3
|
+
titans_pytorch/titans.py,sha256=Kx_tl_QkeDccvkMwPZ0xQ_saYjZfbKzDNPTTSHNWYcA,14304
|
4
|
+
titans_pytorch/titans_attn_memory.py,sha256=Rwx_-riGeISBefZg5Kjic8jzmmRRys-u93D2Kgb7Mos,12691
|
5
|
+
titans_pytorch-0.0.25.dist-info/METADATA,sha256=NeRAldZl9fN7bc3YzL44kDzH2rsq5SMBZ-7RjGs_B0g,3811
|
6
|
+
titans_pytorch-0.0.25.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
7
|
+
titans_pytorch-0.0.25.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
8
|
+
titans_pytorch-0.0.25.dist-info/RECORD,,
|
@@ -1,8 +0,0 @@
|
|
1
|
-
titans_pytorch/__init__.py,sha256=nB0873FZ_OyCda3qFeWTdpO4LbbrXDEvtAefVLzh6g0,71
|
2
|
-
titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
|
3
|
-
titans_pytorch/titans.py,sha256=AU2mf3RkClSRIP0IUwnnqsA5O1udNYGbTRb0lVBLA78,14024
|
4
|
-
titans_pytorch/titans_attn_memory.py,sha256=Rwx_-riGeISBefZg5Kjic8jzmmRRys-u93D2Kgb7Mos,12691
|
5
|
-
titans_pytorch-0.0.23.dist-info/METADATA,sha256=U80_8U_mwaQqwWKPWlu76-O3-CTCfrZ7t_HdE3Zl_qE,3811
|
6
|
-
titans_pytorch-0.0.23.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
7
|
-
titans_pytorch-0.0.23.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
8
|
-
titans_pytorch-0.0.23.dist-info/RECORD,,
|
File without changes
|
File without changes
|