titans-pytorch 0.0.23__py3-none-any.whl → 0.0.25__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- titans_pytorch/titans.py +21 -12
- {titans_pytorch-0.0.23.dist-info → titans_pytorch-0.0.25.dist-info}/METADATA +1 -1
- titans_pytorch-0.0.25.dist-info/RECORD +8 -0
- titans_pytorch-0.0.23.dist-info/RECORD +0 -8
- {titans_pytorch-0.0.23.dist-info → titans_pytorch-0.0.25.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.0.23.dist-info → titans_pytorch-0.0.25.dist-info}/licenses/LICENSE +0 -0
titans_pytorch/titans.py
CHANGED
@@ -73,6 +73,17 @@ def softclamp_grad_norm(t, max_value):
     t = t * (clamped_norm / norm)
     return inverse(t)
 
+# multi head rmsnorm
+
+class MultiheadRMSNorm(Module):
+    def __init__(self, dim, heads):
+        super().__init__()
+        self.rmsnorm = nn.RMSNorm(dim, elementwise_affine = False)
+        self.gamma = nn.Parameter(torch.zeros(heads, 1, dim))
+
+    def forward(self, x):
+        return self.rmsnorm(x) * (self.gamma + 1.)
+
 # classes
 
 class MemoryMLP(Module):
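The new module applies an unscaled RMSNorm followed by a learned per-head gain. Below is a minimal sketch of what it computes, outside the diff; it assumes PyTorch >= 2.4 (where nn.RMSNorm was added), and the tensor shapes are illustrative, not taken from the package.

import torch
from torch import nn

# one gain vector per head, no bias; zeros init means
# (gamma + 1.) starts out as an identity gain
rmsnorm = nn.RMSNorm(64, elementwise_affine = False)
gamma = nn.Parameter(torch.zeros(4, 1, 64))    # (heads, 1, dim_head)

x = torch.randn(2, 4, 128, 64)                 # (batch, heads, seq, dim_head)
out = rmsnorm(x) * (gamma + 1.)                # gamma broadcasts over batch and seq
assert out.shape == x.shape

Initializing gamma at zero makes the module start as a plain RMSNorm, so training begins from a neutral gain rather than a random one.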
@@ -125,26 +136,28 @@ class NeuralMemory(Module):
         )
     ):
         super().__init__()
+        dim_head = default(dim_head, dim)
 
         # norms
 
         self.retrieve_norm = nn.RMSNorm(dim) if pre_rmsnorm else nn.Identity()
         self.store_norm = nn.RMSNorm(dim) if pre_rmsnorm else nn.Identity()
 
-        self.post_rmsnorm = nn.RMSNorm(dim) if post_rmsnorm else nn.Identity()
+        self.multihead_rmsnorm = MultiheadRMSNorm(dim_head, heads) if post_rmsnorm else nn.Identity()
 
         # maybe multi-headed
 
-        dim_head = default(dim_head, dim)
         dim_inner = dim_head * heads
 
+        self.heads = heads
+
         self.split_heads = Rearrange('b n (h d) -> (b h) n d', h = heads)
-        self.merge_heads = Rearrange('(b h) n d -> b n (h d)', h = heads)
+        self.merge_heads = Rearrange('b h n d -> b n (h d)')
         self.combine_heads = LinearNoBias(dim_inner, dim) if heads > 1 else nn.Identity()
 
         self.retrieve_gate = nn.Sequential(
             LinearNoBias(dim, heads),
-            Rearrange('b n h -> (b h) n 1'),
+            Rearrange('b n h -> b h n 1'),
             nn.Sigmoid()
         ) if heads > 1 else None
 
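The merge_heads change means retrieved values now keep an explicit head axis (b h n d) instead of folding heads into the batch ((b h) n d), so merging no longer needs h passed in. A sketch with einops showing the two layouts produce the same merge (the old pattern is reconstructed from context; sizes are illustrative):

import torch
from einops import rearrange

b, h, n, d = 2, 4, 8, 16
values = torch.randn(b, h, n, d)

# new: heads stay an explicit axis until the final merge
merged = rearrange(values, 'b h n d -> b n (h d)')

# old: heads folded into batch, so merging required h
folded = rearrange(values, 'b h n d -> (b h) n d')
merged_old = rearrange(folded, '(b h) n d -> b n (h d)', h = h)

assert torch.equal(merged, merged_old)

The same layout shift shows up in retrieve_gate, whose Rearrange now targets 'b h n 1' so the sigmoid gate broadcasts against (b, h, n, d) values.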
@@ -364,7 +377,7 @@ class NeuralMemory(Module):
         past_weights: dict[str, Tensor] | None = None,
     ):
         chunk_size = self.chunk_size
-        seq_len = seq.shape[-2]
+        batch, seq_len = seq.shape[:2]
 
         seq = self.retrieve_norm(seq)
 
@@ -398,8 +411,6 @@ class NeuralMemory(Module):
 
         queries = self.split_heads(queries)
 
-        batch = queries.shape[0]
-
         # fetch values from memory model
 
         curr_weights = curr_weights.apply(lambda t: rearrange(t, 'b n ... -> (b n) ...'))
@@ -411,7 +422,9 @@ class NeuralMemory(Module):
 
         # reconstitute batch dimension
 
-        values = rearrange(values, '(b n) c d -> b (n c) d', b = batch)
+        values = rearrange(values, '(b h n) c d -> b h (n c) d', b = batch, h = self.heads)
+
+        values = self.multihead_rmsnorm(values)
 
         # maybe gate
 
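The retrieval path now reconstitutes batch, heads, and chunks in one step before the new per-head norm. A sketch of just the rearrange, with illustrative dimension sizes:

import torch
from einops import rearrange

b, h, n, c, d = 2, 4, 3, 5, 16    # batch, heads, num chunks, chunk size, dim_head

# the memory model returns values with batch, heads and chunks flattened together
values = torch.randn(b * h * n, c, d)

# split the flat leading axis back out, then stitch chunks into a sequence
values = rearrange(values, '(b h n) c d -> b h (n c) d', b = b, h = h)
assert values.shape == (b, h, n * c, d)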
@@ -424,10 +437,6 @@ class NeuralMemory(Module):
 
         values = self.combine_heads(values)
 
-        # post norm, somehow could not stabilize this without it, not in paper
-
-        values = self.post_rmsnorm(values)
-
         # restore, pad with empty memory embed
 
         empty_memory_embeds = self.init_empty_memory_embed(values.shape[0], chunk_size - 1)
titans_pytorch-0.0.25.dist-info/RECORD
ADDED
@@ -0,0 +1,8 @@
+titans_pytorch/__init__.py,sha256=nB0873FZ_OyCda3qFeWTdpO4LbbrXDEvtAefVLzh6g0,71
+titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+titans_pytorch/titans.py,sha256=Kx_tl_QkeDccvkMwPZ0xQ_saYjZfbKzDNPTTSHNWYcA,14304
+titans_pytorch/titans_attn_memory.py,sha256=Rwx_-riGeISBefZg5Kjic8jzmmRRys-u93D2Kgb7Mos,12691
+titans_pytorch-0.0.25.dist-info/METADATA,sha256=NeRAldZl9fN7bc3YzL44kDzH2rsq5SMBZ-7RjGs_B0g,3811
+titans_pytorch-0.0.25.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.0.25.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.0.25.dist-info/RECORD,,
titans_pytorch-0.0.23.dist-info/RECORD
DELETED
@@ -1,8 +0,0 @@
-titans_pytorch/__init__.py,sha256=nB0873FZ_OyCda3qFeWTdpO4LbbrXDEvtAefVLzh6g0,71
-titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/titans.py,sha256=AU2mf3RkClSRIP0IUwnnqsA5O1udNYGbTRb0lVBLA78,14024
-titans_pytorch/titans_attn_memory.py,sha256=Rwx_-riGeISBefZg5Kjic8jzmmRRys-u93D2Kgb7Mos,12691
-titans_pytorch-0.0.23.dist-info/METADATA,sha256=U80_8U_mwaQqwWKPWlu76-O3-CTCfrZ7t_HdE3Zl_qE,3811
-titans_pytorch-0.0.23.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.0.23.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.0.23.dist-info/RECORD,,
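For reference, a wheel's RECORD lists each installed file as path,sha256=<digest>,<size-in-bytes> (RECORD lists itself with hash and size left empty). Between these two versions only the titans.py entry changes in substance (14024 → 14304 bytes), consistent with the source diff above; the remaining differences are the renamed dist-info paths and METADATA's digest, which changes with the version bump.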
{titans_pytorch-0.0.23.dist-info → titans_pytorch-0.0.25.dist-info}/WHEEL
File without changes
{titans_pytorch-0.0.23.dist-info → titans_pytorch-0.0.25.dist-info}/licenses/LICENSE
File without changes