titans-pytorch 0.0.23__tar.gz → 0.0.24__tar.gz
- {titans_pytorch-0.0.23 → titans_pytorch-0.0.24}/PKG-INFO +1 -1
- {titans_pytorch-0.0.23 → titans_pytorch-0.0.24}/pyproject.toml +1 -1
- {titans_pytorch-0.0.23 → titans_pytorch-0.0.24}/titans_pytorch/titans.py +20 -11
- {titans_pytorch-0.0.23 → titans_pytorch-0.0.24}/train.py +1 -1
- {titans_pytorch-0.0.23 → titans_pytorch-0.0.24}/.github/workflows/python-publish.yml +0 -0
- {titans_pytorch-0.0.23 → titans_pytorch-0.0.24}/.github/workflows/test.yaml +0 -0
- {titans_pytorch-0.0.23 → titans_pytorch-0.0.24}/.gitignore +0 -0
- {titans_pytorch-0.0.23 → titans_pytorch-0.0.24}/LICENSE +0 -0
- {titans_pytorch-0.0.23 → titans_pytorch-0.0.24}/README.md +0 -0
- {titans_pytorch-0.0.23 → titans_pytorch-0.0.24}/data/README.md +0 -0
- {titans_pytorch-0.0.23 → titans_pytorch-0.0.24}/data/enwik8.gz +0 -0
- {titans_pytorch-0.0.23 → titans_pytorch-0.0.24}/fig1.png +0 -0
- {titans_pytorch-0.0.23 → titans_pytorch-0.0.24}/fig2.png +0 -0
- {titans_pytorch-0.0.23 → titans_pytorch-0.0.24}/requirements.txt +0 -0
- {titans_pytorch-0.0.23 → titans_pytorch-0.0.24}/tests/test_titans.py +0 -0
- {titans_pytorch-0.0.23 → titans_pytorch-0.0.24}/titans_pytorch/__init__.py +0 -0
- {titans_pytorch-0.0.23 → titans_pytorch-0.0.24}/titans_pytorch/associative_scan.py +0 -0
- {titans_pytorch-0.0.23 → titans_pytorch-0.0.24}/titans_pytorch/titans_attn_memory.py +0 -0
titans_pytorch/titans.py

```diff
@@ -73,6 +73,17 @@ def softclamp_grad_norm(t, max_value):
     t = t * (clamped_norm / norm)
     return inverse(t)
 
+# multi head rmsnorm
+
+class MultiheadRMSNorm(Module):
+    def __init__(self, dim, heads):
+        super().__init__()
+        self.rmsnorm = nn.RMSNorm(dim, elementwise_affine = False)
+        self.gamma = nn.Parameter(torch.zeros(heads, 1, dim))
+
+    def forward(self, x):
+        return self.rmsnorm(x) * (self.gamma + 1.)
+
 # classes
 
 class MemoryMLP(Module):
```
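The newly added MultiheadRMSNorm is self-contained, so it can be sanity-checked in isolation. Below is a minimal sketch of how it behaves; the shapes are illustrative and not taken from the package, and nn.RMSNorm requires torch >= 2.4. Since gamma is zero-initialized, the module starts out as a plain un-gated RMSNorm.

```python
import torch
from torch import nn
from torch.nn import Module

class MultiheadRMSNorm(Module):
    def __init__(self, dim, heads):
        super().__init__()
        # shared RMSNorm over the feature dim, with no learned scale of its own
        self.rmsnorm = nn.RMSNorm(dim, elementwise_affine = False)
        # one learned scale per head, zero-init so (gamma + 1.) starts at identity
        self.gamma = nn.Parameter(torch.zeros(heads, 1, dim))

    def forward(self, x):
        # x: (batch, heads, seq, dim) - gamma broadcasts over batch and seq
        return self.rmsnorm(x) * (self.gamma + 1.)

# illustrative usage, sizes are made up
norm = MultiheadRMSNorm(dim = 64, heads = 4)
x = torch.randn(2, 4, 128, 64)   # (batch, heads, seq, dim_head)
assert norm(x).shape == x.shape
```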
```diff
@@ -131,20 +142,22 @@ class NeuralMemory(Module):
         self.retrieve_norm = nn.RMSNorm(dim) if pre_rmsnorm else nn.Identity()
         self.store_norm = nn.RMSNorm(dim) if pre_rmsnorm else nn.Identity()
 
-        self.
+        self.multihead_rmsnorm = MultiheadRMSNorm(dim_head, heads) if post_rmsnorm else nn.Identity()
 
         # maybe multi-headed
 
         dim_head = default(dim_head, dim)
         dim_inner = dim_head * heads
 
+        self.heads = heads
+
         self.split_heads = Rearrange('b n (h d) -> (b h) n d', h = heads)
-        self.merge_heads = Rearrange('
+        self.merge_heads = Rearrange('b h n d -> b n (h d)')
         self.combine_heads = LinearNoBias(dim_inner, dim) if heads > 1 else nn.Identity()
 
         self.retrieve_gate = nn.Sequential(
             LinearNoBias(dim, heads),
-            Rearrange('b n h ->
+            Rearrange('b n h -> b h n 1'),
             nn.Sigmoid()
         ) if heads > 1 else None
 
```
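The layout change in this hunk is that heads are no longer folded back into the batch on the way out: merge_heads now consumes an explicit head axis ('b h n d -> b n (h d)'), and the retrieve gate is reshaped to 'b h n 1' so it broadcasts per head. A quick shape check of that broadcasting, using purely illustrative sizes, might look like:

```python
import torch
from einops import rearrange

# illustrative sizes, not taken from the package
batch, heads, seq, dim_head = 2, 4, 16, 64

values = torch.randn(batch, heads, seq, dim_head)   # per-head retrieved values
gate_logits = torch.randn(batch, seq, heads)        # what LinearNoBias(dim, heads) would emit

# new retrieve_gate layout: 'b n h -> b h n 1', so the sigmoid gate
# broadcasts over each head's feature dimension
gate = rearrange(gate_logits, 'b n h -> b h n 1').sigmoid()
values = values * gate

# new merge_heads: heads are merged only after the per-head gate
out = rearrange(values, 'b h n d -> b n (h d)')
assert out.shape == (batch, seq, heads * dim_head)
```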
```diff
@@ -364,7 +377,7 @@ class NeuralMemory(Module):
         past_weights: dict[str, Tensor] | None = None,
     ):
         chunk_size = self.chunk_size
-        seq_len = seq.shape[
+        batch, seq_len = seq.shape[:2]
 
         seq = self.retrieve_norm(seq)
 
```
```diff
@@ -398,8 +411,6 @@ class NeuralMemory(Module):
 
         queries = self.split_heads(queries)
 
-        batch = queries.shape[0]
-
         # fetch values from memory model
 
         curr_weights = curr_weights.apply(lambda t: rearrange(t, 'b n ... -> (b n) ...'))
```
```diff
@@ -411,7 +422,9 @@ class NeuralMemory(Module):
 
         # reconstitute batch dimension
 
-        values = rearrange(values, '(b n) c d -> b (n c) d', b = batch)
+        values = rearrange(values, '(b h n) c d -> b h (n c) d', b = batch, h = self.heads)
+
+        values = self.multihead_rmsnorm(values)
 
         # maybe gate
 
```
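This is where the new per-head normalization lands in the retrieve path: the flattened (batch · heads · chunks) axis is unpacked back into separate batch and head axes so MultiheadRMSNorm can apply its per-head gamma before the heads are gated and merged. A small sketch of the reshape, again with illustrative sizes only:

```python
import torch
from einops import rearrange

# illustrative sizes only, not taken from the package
batch, heads, num_chunks, chunk_size, dim_head = 2, 4, 3, 8, 64

# after the memory model runs, values arrive flattened as
# ((batch * heads * num_chunks), chunk_size, dim_head)
values = torch.randn(batch * heads * num_chunks, chunk_size, dim_head)

# 0.0.23: '(b n) c d -> b (n c) d' merged heads straight back into the batch
# 0.0.24: keep the head axis so the per-head norm (and gate) can be applied first
values = rearrange(values, '(b h n) c d -> b h (n c) d', b = batch, h = heads)
assert values.shape == (batch, heads, num_chunks * chunk_size, dim_head)
```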
```diff
@@ -424,10 +437,6 @@ class NeuralMemory(Module):
 
         values = self.combine_heads(values)
 
-        # post norm, somehow could not stabilize this without it, not in paper
-
-        values = self.post_rmsnorm(values)
-
         # restore, pad with empty memory embed
 
         empty_memory_embeds = self.init_empty_memory_embed(values.shape[0], chunk_size - 1)
```