titans-pytorch 0.0.58__py3-none-any.whl → 0.0.62__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- titans_pytorch/__init__.py +2 -1
- titans_pytorch/mac_transformer.py +3 -3
- titans_pytorch/titans.py +64 -12
- {titans_pytorch-0.0.58.dist-info → titans_pytorch-0.0.62.dist-info}/METADATA +1 -1
- titans_pytorch-0.0.62.dist-info/RECORD +8 -0
- titans_pytorch-0.0.58.dist-info/RECORD +0 -8
- {titans_pytorch-0.0.58.dist-info → titans_pytorch-0.0.62.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.0.58.dist-info → titans_pytorch-0.0.62.dist-info}/licenses/LICENSE +0 -0
titans_pytorch/__init__.py
CHANGED
titans_pytorch/mac_transformer.py
CHANGED
@@ -24,8 +24,8 @@ def create_mac_block_mask(seq_len, window_size, persist_mem_len):

     def create_mac_mask(b, h, q_idx, kv_idx):
         is_persist_mem = kv_idx < persist_mem_len
-        causal_mask = q_idx >= (kv_idx -
-        block_diagonal = (q_idx // window_size) == ((kv_idx -
+        causal_mask = q_idx >= (kv_idx - persist_mem_len)
+        block_diagonal = (q_idx // window_size) == ((kv_idx - persist_mem_len) // window_size)
         return is_persist_mem | (~is_persist_mem & (causal_mask & block_diagonal))

     block_mask = create_block_mask(create_mac_mask, B = None, H = None, Q_LEN = seq_len, KV_LEN = seq_len + persist_mem_len, _compile = True)
@@ -489,7 +489,7 @@ class MemoryAsContextTransformer(Module):
         flex_attn_fn = None

         if use_flex_attn:
-            block_mask = create_mac_block_mask(seq_len_with_mem,
+            block_mask = create_mac_block_mask(seq_len_with_mem, segment_len + num_longterm_mem_tokens, self.num_persist_mem_tokens)
             flex_attn_fn = partial(flex_attention, block_mask = block_mask)

         # value residual
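As a quick sanity check outside of flex attention, the corrected predicate can be evaluated on broadcast index grids. The snippet below is a standalone sketch, not part of the package; `mac_mask` is a local helper and the tiny sizes are arbitrary. The key/value axis carries `persist_mem_len` extra leading positions for the persistent memories, which is why both the causal and block-diagonal terms subtract `persist_mem_len` and why `KV_LEN = seq_len + persist_mem_len` above.

    import torch

    # Standalone sketch: evaluate the corrected mask predicate on index grids.
    # Columns 0 .. persist_mem_len-1 are the always-attendable persistent memories;
    # the remaining columns follow causal, block-diagonal (windowed) attention.

    def mac_mask(seq_len, window_size, persist_mem_len):
        q_idx = torch.arange(seq_len)[:, None]
        kv_idx = torch.arange(seq_len + persist_mem_len)[None, :]

        is_persist_mem = kv_idx < persist_mem_len
        causal_mask = q_idx >= (kv_idx - persist_mem_len)
        block_diagonal = (q_idx // window_size) == ((kv_idx - persist_mem_len) // window_size)

        return is_persist_mem | (~is_persist_mem & (causal_mask & block_diagonal))

    print(mac_mask(seq_len = 6, window_size = 3, persist_mem_len = 2).int())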
titans_pytorch/titans.py
CHANGED
@@ -6,7 +6,7 @@ from functools import partial
 import torch
 from torch import nn, Tensor
 import torch.nn.functional as F
-from torch.nn import Linear, Module
+from torch.nn import Linear, Module, Parameter, ParameterList
 from torch.func import functional_call, vmap, grad

 from tensordict import TensorDict
@@ -88,7 +88,7 @@ class MultiheadRMSNorm(Module):
     def __init__(self, dim, heads):
         super().__init__()
         self.rmsnorm = nn.RMSNorm(dim, elementwise_affine = False)
-        self.gamma =
+        self.gamma = Parameter(torch.zeros(heads, 1, dim))

     def forward(self, x):
         return self.rmsnorm(x) * (self.gamma + 1.)
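The gamma parameter above is zero-initialized and applied as `(gamma + 1.)`, so the module starts out as a plain per-head RMSNorm and the learned gain only departs from 1 during training. A minimal sketch that re-creates the class exactly as shown; the `(batch, heads, seq, dim)` input layout is an assumption, not stated in this diff:

    import torch
    from torch import nn
    from torch.nn import Module, Parameter

    class MultiheadRMSNorm(Module):
        def __init__(self, dim, heads):
            super().__init__()
            self.rmsnorm = nn.RMSNorm(dim, elementwise_affine = False)
            self.gamma = Parameter(torch.zeros(heads, 1, dim))   # zero init -> gain of exactly 1 at start

        def forward(self, x):
            return self.rmsnorm(x) * (self.gamma + 1.)

    norm = MultiheadRMSNorm(dim = 4, heads = 2)
    x = torch.randn(3, 2, 5, 4)   # assumed (batch, heads, seq, dim)

    # at initialization the module is indistinguishable from an unscaled RMSNorm
    assert torch.allclose(norm(x), nn.RMSNorm(4, elementwise_affine = False)(x))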
@@ -102,7 +102,10 @@ class MemoryMLP(Module):
         depth
     ):
         super().__init__()
-        self.weights =
+        self.weights = ParameterList([Parameter(torch.randn(dim, dim)) for _ in range(depth)])
+
+        for weight in self.weights:
+            nn.init.xavier_uniform_(weight)

     def forward(
         self,
@@ -118,13 +121,50 @@ class MemoryMLP(Module):

         return x

+# memory mlp with factorized weights
+# so can tradeoff capacity for smaller chunk sizes
+
+class FactorizedMemoryMLP(Module):
+    def __init__(
+        self,
+        dim,
+        depth,
+        k = 32
+    ):
+        super().__init__()
+        self.weights = ParameterList([
+            ParameterList([
+                Parameter(torch.randn(dim, k)),
+                Parameter(torch.randn(k, dim)),
+            ]) for _ in range(depth)
+        ])
+
+        for weight1, weight2 in self.weights:
+            nn.init.xavier_uniform_(weight1)
+            nn.init.xavier_uniform_(weight2)
+
+    def forward(
+        self,
+        x
+    ):
+        for ind, (weight1, weight2) in enumerate(self.weights):
+            is_first = ind == 0
+
+            if not is_first:
+                x = F.silu(x)
+
+            x = x @ weight1 @ weight2
+
+        return x
+
 # improvised attention as memory module

 class MemoryAttention(Module):
     def __init__(
         self,
         dim,
-        scale = 8
+        scale = 8.,
+        expansion_factor = 2.
     ):
         super().__init__()
         self.scale = scale
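The factorization above replaces each dense dim x dim memory weight with a dim x k and k x dim pair, which is where the "tradeoff capacity for smaller chunk sizes" comment comes from: fewer parameters (and a rank-k bottleneck) per layer in exchange for a cheaper memory model. A back-of-the-envelope comparison based only on the shapes in the diff; dim = 384 and depth = 2 are illustrative choices, while k = 32 is the default shown above:

    # Parameter counts implied by the weight shapes in the diff.
    dim, depth, k = 384, 2, 32

    dense_params      = depth * dim * dim            # MemoryMLP: one dim x dim weight per layer
    factorized_params = depth * (dim * k + k * dim)  # FactorizedMemoryMLP: dim x k plus k x dim per layer

    print(f"MemoryMLP:           {dense_params:,} parameters")      # 294,912
    print(f"FactorizedMemoryMLP: {factorized_params:,} parameters")  # 49,152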
@@ -133,10 +173,13 @@ class MemoryAttention(Module):
             nn.Parameter(torch.randn(dim, dim)), # queries
             nn.Parameter(torch.randn(dim, dim)), # keys
             nn.Parameter(torch.randn(dim, dim)), # values
-            nn.Parameter(torch.randn(dim, dim *
-            nn.Parameter(torch.randn(dim *
+            nn.Parameter(torch.randn(dim, dim * expansion_factor)), # ff w1
+            nn.Parameter(torch.randn(dim * expansion_factor, dim)), # ff w2
         ])

+        for weight in self.weights:
+            nn.init.xavier_uniform_(weight)
+
     def forward(self, x):
         wq, wk, wv, ffw1, ffw2 = self.weights

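All three memory modules now re-initialize their `torch.randn` weights in place with `nn.init.xavier_uniform_`, replacing unit-variance Gaussian entries with a uniform distribution whose scale shrinks as the layer widens. A small sketch of the effect, using the ff w1 shape from MemoryAttention with dim = 384 and expansion_factor = 2 as the example:

    import torch
    from torch import nn

    # torch.randn gives entries with std ~1 regardless of shape; xavier_uniform_
    # rescales in place to U(-a, a) with a = sqrt(6 / (fan_in + fan_out)),
    # i.e. a resulting std of sqrt(2 / (fan_in + fan_out)).

    w = torch.randn(384, 768)      # same shape as the ff w1 weight above
    print(w.std())                 # ~1.0

    nn.init.xavier_uniform_(w)
    print(w.std())                 # ~0.0417 == sqrt(2 / (384 + 768))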
@@ -289,6 +332,8 @@ class NeuralMemory(Module):

         self.use_accelerated_scan = use_accelerated_scan

+        self.register_buffer('zero', torch.tensor(0.), persistent = False)
+
     def init_weights_and_momentum(self):
         params = TensorDict(dict(self.memory_model.named_parameters()))
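The new `zero` buffer gives `NeuralMemory` a scalar that follows the module's device and dtype, and `persistent = False` keeps it out of the state dict so existing checkpoints are unaffected; it is what the short-sequence path below returns as the auxiliary loss. A minimal sketch of the pattern (the class name here is only for illustration):

    import torch
    from torch.nn import Module

    class HasZero(Module):
        def __init__(self):
            super().__init__()
            # non-persistent buffer: moves with .to(device) / .cuda(), not saved in state_dict
            self.register_buffer('zero', torch.tensor(0.), persistent = False)

    m = HasZero()
    print(m.zero)          # tensor(0.)
    print(m.state_dict())  # OrderedDict(), nothing persisted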
@@ -306,6 +351,13 @@ class NeuralMemory(Module):
         past_state: tuple[dict[str, Tensor], dict[str, Tensor]],
         return_aux_kv_loss = False
     ):
+        seq_len = seq.shape[-2]
+
+        # handle edge case
+
+        if seq_len < self.chunk_size:
+            past_weight, _ = past_state
+            return TensorDict(past_weight).clone().zero_(), self.zero

         seq = self.store_norm(seq)
@@ -425,12 +477,10 @@ class NeuralMemory(Module):

         last_update = updates.apply(lambda t: t[:, -1])

-        next_state = (curr_weights + last_update, next_momentum)
-
         if not return_aux_kv_loss:
-            return updates
+            return updates

-        return updates,
+        return updates, aux_kv_recon_loss.mean()

     def retrieve_memories(
         self,
@@ -442,7 +492,8 @@ class NeuralMemory(Module):

         seq = self.retrieve_norm(seq)

-
+        if seq_len < self.chunk_size:
+            return self.init_empty_memory_embed(batch, seq_len)

         seq = seq[:, (chunk_size - 1):]
         curtailed_seq_len = seq.shape[-2]
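Taken together with the guard added to store_memories above, sequences shorter than `chunk_size` become a no-op on the storage side (zeroed updates plus the zero-valued aux loss) and fall back to the learned empty-memory embedding on the retrieval side, so callers no longer need to special-case short inputs. A hedged usage sketch; the constructor arguments and the returned shape are assumptions based on the class as it appears in this diff:

    import torch
    from titans_pytorch import NeuralMemory

    mem = NeuralMemory(dim = 384, chunk_size = 64)

    short_seq = torch.randn(1, 16, 384)   # 16 tokens, fewer than chunk_size = 64
    retrieved = mem(short_seq)            # handled by the new guards instead of erroring

    print(retrieved.shape)                # expected: torch.Size([1, 16, 384])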
@@ -524,10 +575,11 @@ class NeuralMemory(Module):

         store_seq = default(store_seq, seq)

-        updates,
+        updates, aux_kv_recon_loss = self.store_memories(store_seq, past_state, return_aux_kv_loss = True)

         past_weights, _ = past_state

+
         retrieved = self.retrieve_memories(seq, past_weights + updates)

         if not return_aux_kv_loss:
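With `store_memories` now returning the key/value reconstruction loss unconditionally, `forward` can surface it next to the retrieved memories. A hedged training sketch: the `return_aux_kv_loss` keyword follows the code shown above, while the 0.1 weighting and the stand-in task loss are arbitrary choices for illustration, not package defaults:

    import torch
    from titans_pytorch import NeuralMemory

    mem = NeuralMemory(dim = 384, chunk_size = 64)
    seq = torch.randn(2, 1024, 384)

    retrieved, aux_kv_recon_loss = mem(seq, return_aux_kv_loss = True)

    task_loss = retrieved.sum()                       # stand-in for a real objective
    (task_loss + 0.1 * aux_kv_recon_loss).backward()  # optionally train on the aux loss as well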
titans_pytorch-0.0.62.dist-info/RECORD
ADDED
@@ -0,0 +1,8 @@
+titans_pytorch/__init__.py,sha256=I-4oF1JPEmUvhLszEocM5cDgHYNFcNp0Q9nbDYSPFqU,195
+titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+titans_pytorch/mac_transformer.py,sha256=R9Xu-BjqAe9ZY60IGk4aNXBx_L8THsjJ4QrkbTnRNHo,15346
+titans_pytorch/titans.py,sha256=95J6UL44lOrdZSXdm7p36xC9tDeSmRBwdjig9T82PzI,17452
+titans_pytorch-0.0.62.dist-info/METADATA,sha256=08Blaa9Ehyv09rSA9uWguxbhKpbrd7C53Ya13E1VbpU,4457
+titans_pytorch-0.0.62.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.0.62.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.0.62.dist-info/RECORD,,
titans_pytorch-0.0.58.dist-info/RECORD
DELETED
@@ -1,8 +0,0 @@
-titans_pytorch/__init__.py,sha256=Q0MQA3RS8vqzs-KzSGZkvLR7afQ6ZW9uMOq1MeNuFoY,170
-titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/mac_transformer.py,sha256=kk8s8Q2WmbJxCVi8PcqSUyJBc8-CDAHrVjt6M0d_kFs,15323
-titans_pytorch/titans.py,sha256=ZKm-LnVKh1Cxs2tSxr4CcY37KroOOmYtTFM2F3Zb8Xg,16122
-titans_pytorch-0.0.58.dist-info/METADATA,sha256=a-Y6MV_89D44HlB7eKpurh-sw5DDiS-pIVei3Uw_uGE,4457
-titans_pytorch-0.0.58.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.0.58.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.0.58.dist-info/RECORD,,
{titans_pytorch-0.0.58.dist-info → titans_pytorch-0.0.62.dist-info}/WHEEL
File without changes

{titans_pytorch-0.0.58.dist-info → titans_pytorch-0.0.62.dist-info}/licenses/LICENSE
File without changes