titans-pytorch 0.0.53__tar.gz → 0.0.55__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {titans_pytorch-0.0.53 → titans_pytorch-0.0.55}/PKG-INFO +1 -1
- {titans_pytorch-0.0.53 → titans_pytorch-0.0.55}/pyproject.toml +1 -1
- {titans_pytorch-0.0.53 → titans_pytorch-0.0.55}/tests/test_titans.py +4 -1
- {titans_pytorch-0.0.53 → titans_pytorch-0.0.55}/titans_pytorch/__init__.py +1 -0
- {titans_pytorch-0.0.53 → titans_pytorch-0.0.55}/titans_pytorch/mac_transformer.py +30 -5
- {titans_pytorch-0.0.53 → titans_pytorch-0.0.55}/titans_pytorch/titans.py +9 -1
- {titans_pytorch-0.0.53 → titans_pytorch-0.0.55}/train_mac.py +5 -2
- {titans_pytorch-0.0.53 → titans_pytorch-0.0.55}/.github/workflows/python-publish.yml +0 -0
- {titans_pytorch-0.0.53 → titans_pytorch-0.0.55}/.github/workflows/test.yaml +0 -0
- {titans_pytorch-0.0.53 → titans_pytorch-0.0.55}/.gitignore +0 -0
- {titans_pytorch-0.0.53 → titans_pytorch-0.0.55}/LICENSE +0 -0
- {titans_pytorch-0.0.53 → titans_pytorch-0.0.55}/README.md +0 -0
- {titans_pytorch-0.0.53 → titans_pytorch-0.0.55}/data/README.md +0 -0
- {titans_pytorch-0.0.53 → titans_pytorch-0.0.55}/data/enwik8.gz +0 -0
- {titans_pytorch-0.0.53 → titans_pytorch-0.0.55}/fig1.png +0 -0
- {titans_pytorch-0.0.53 → titans_pytorch-0.0.55}/fig2.png +0 -0
- {titans_pytorch-0.0.53 → titans_pytorch-0.0.55}/titans_pytorch/associative_scan.py +0 -0
tests/test_titans.py

@@ -6,17 +6,20 @@ from titans_pytorch import NeuralMemory
 
 @pytest.mark.parametrize('seq_len', (32, 1024, 77))
 @pytest.mark.parametrize('silu', (False, True))
+@pytest.mark.parametrize('learned_mem_model_weights', (False, True))
 @pytest.mark.parametrize('max_grad_norm', (None, 2.))
 def test_titans(
     seq_len,
     silu,
+    learned_mem_model_weights,
     max_grad_norm,
 ):
     mem = NeuralMemory(
         dim = 384,
         chunk_size = 64,
         activation = nn.SiLU() if silu else None,
-        max_grad_norm = max_grad_norm
+        max_grad_norm = max_grad_norm,
+        learned_mem_model_weights = learned_mem_model_weights
     )
 
     seq = torch.randn(2, seq_len, 384)
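For reference, a minimal sketch of how the new `learned_mem_model_weights` flag is exercised, mirroring the test above; the forward call and its return value are assumptions based on the module's usual usage rather than something shown in this diff:

```python
import torch
from torch import nn
from titans_pytorch import NeuralMemory

# kwargs mirror the updated test; learned_mem_model_weights = False
# keeps the inner memory model's initial weights frozen
mem = NeuralMemory(
    dim = 384,
    chunk_size = 64,
    activation = nn.SiLU(),
    max_grad_norm = 2.,
    learned_mem_model_weights = False
)

seq = torch.randn(2, 1024, 384)
retrieved = mem(seq)   # assumed call signature; retrieves memories for the sequence
```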
titans_pytorch/mac_transformer.py

@@ -227,9 +227,10 @@ class SegmentedAttention(Module):
         self,
         seq,
         value_residual = None,
-        flex_attn_fn: Callable | None = None
+        flex_attn_fn: Callable | None = None,
+        disable_flex_attn = False
     ):
-        if seq.is_cuda and self.use_flex_attn:
+        if seq.is_cuda and self.use_flex_attn and not disable_flex_attn:
             return self.forward_flex(seq, value_residual, flex_attn_fn)
 
         assert not (exists(value_residual) ^ exists(self.to_learned_v_mix))
@@ -303,7 +304,8 @@ class MemoryAsContextTransformer(Module):
         num_residual_streams = 4,
         neural_memory_kwargs: dict = dict(),
         neural_memory_layers: tuple[int, ...] | None = None,
-        aux_kv_recon_loss_weight = 0
+        aux_kv_recon_loss_weight = 0.,
+        use_flex_attn = False
     ):
         super().__init__()
 
@@ -336,6 +338,8 @@ class MemoryAsContextTransformer(Module):
 
         assert not (num_longterm_mem_tokens > 0 and len(neural_memory_layers) == 0), 'empty `neural_memory_layers` when longterm memory tokens are present'
 
+        # mem, attn, and feedforward layers
+
         for layer in layers:
             is_first = layer == 1
 
@@ -363,6 +367,7 @@ class MemoryAsContextTransformer(Module):
                 dim_head = dim_head,
                 heads = heads,
                 segment_len = segment_len,
+                use_flex_attn = use_flex_attn,
                 accept_value_residual = not is_first,
                 num_longterm_mem_tokens = num_longterm_mem_tokens,
                 num_persist_mem_tokens = num_persist_mem_tokens
@@ -386,11 +391,20 @@ class MemoryAsContextTransformer(Module):
 
         self.register_buffer('zero', torch.tensor(0.), persistent = False)
 
+        # flex attn related
+
+        assert not (use_flex_attn and not exists(flex_attention)), 'you need to be on the latest pytorch with a cuda device available'
+        self.use_flex_attn = use_flex_attn
+
+        self.segment_len = segment_len
+        self.num_persist_mem_tokens = num_persist_mem_tokens
+
     def forward(
         self,
         x,
         return_loss = False,
-        return_loss_breakdown = False
+        return_loss_breakdown = False,
+        disable_flex_attn = False
     ):
 
         if return_loss:
@@ -424,6 +438,17 @@ class MemoryAsContextTransformer(Module):
 
         x = x + pos_emb[:seq_len_with_mem]
 
+        # prep flex attention
+
+        use_flex_attn = x.is_cuda and self.use_flex_attn and not disable_flex_attn
+
+        flex_attn_fn = None
+
+        if use_flex_attn:
+            block_mask = create_mac_block_mask(seq_len, self.segment_len, self.num_persist_mem_tokens)
+
+            flex_attn_fn = partial(flex_attention, block_mask = block_mask)
+
         # value residual
 
         value_residual = None
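For orientation: the block above builds a block mask once per forward pass and binds it into `flex_attention` with `functools.partial` before handing it down to the attention layers. Below is a generic, hedged sketch of that pattern using PyTorch's public flex-attention API with a plain causal mask; the package's own `create_mac_block_mask` builds a MAC-specific segmented mask instead, and the tensor shapes here are illustrative only:

```python
from functools import partial

import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

def causal_mask(b, h, q_idx, kv_idx):
    # stand-in mask_mod; create_mac_block_mask would encode the segmented MAC pattern
    return q_idx >= kv_idx

seq_len = 512
block_mask = create_block_mask(causal_mask, B = None, H = None, Q_LEN = seq_len, KV_LEN = seq_len)

# bind the mask once, then pass the partially applied function to each attention layer
flex_attn_fn = partial(flex_attention, block_mask = block_mask)

q = k = v = torch.randn(1, 8, seq_len, 64, device = 'cuda')   # flex attention targets CUDA tensors
out = flex_attn_fn(q, k, v)                                   # (1, 8, 512, 64)
```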
@@ -442,7 +467,7 @@ class MemoryAsContextTransformer(Module):
                 x, aux_kv_loss = maybe_neural_mem(x, return_aux_kv_loss = True)
                 kv_recon_losses = kv_recon_losses + aux_kv_loss
 
-            x, values = attn(x, value_residual = value_residual)
+            x, values = attn(x, value_residual = value_residual, disable_flex_attn = disable_flex_attn, flex_attn_fn = flex_attn_fn)
 
             value_residual = default(value_residual, values)
 
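Putting the mac_transformer changes together: `use_flex_attn` is a constructor flag, while `disable_flex_attn` is a per-call override that is also threaded into each `SegmentedAttention`. A hedged usage sketch; the constructor arguments other than `use_flex_attn`, and the token/shape choices, are assumptions for illustration rather than values taken from this diff:

```python
import torch
from titans_pytorch import MemoryAsContextTransformer

transformer = MemoryAsContextTransformer(
    num_tokens = 256,              # assumed vocabulary size for the sketch
    dim = 384,
    depth = 2,
    segment_len = 32,
    num_persist_mem_tokens = 4,
    num_longterm_mem_tokens = 4,
    use_flex_attn = True           # new flag; requires a recent PyTorch with flex_attention available
)

ids = torch.randint(0, 256, (1, 512))

# flex attention is used only when the input is on CUDA, the model was built with
# use_flex_attn = True, and the caller has not opted out for this particular call
logits = transformer(ids, disable_flex_attn = True)
```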
titans_pytorch/titans.py

@@ -123,9 +123,12 @@ class MemoryMLP(Module):
 class MemoryAttention(Module):
     def __init__(
         self,
-        dim
+        dim,
+        scale = 8.
     ):
         super().__init__()
+        self.scale = scale
+
         self.weights = nn.ParameterList([
             nn.Parameter(torch.randn(dim, dim)), # queries
             nn.Parameter(torch.randn(dim, dim)), # keys
@@ -143,6 +146,7 @@ class MemoryMLP(Module):
 
         attn_out = F.scaled_dot_product_attention(
             q, k, v,
+            scale = self.scale,
             is_causal = True
         )
 
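The `scale` argument of `F.scaled_dot_product_attention` (available in recent PyTorch releases) overrides the default `1 / sqrt(dim_head)` softmax scaling, so `MemoryAttention` now applies a fixed scale of `8.` by default. A small self-contained comparison, with arbitrary shapes chosen for illustration:

```python
import math
import torch
import torch.nn.functional as F

q, k, v = (torch.randn(1, 4, 16, 64) for _ in range(3))   # (batch, heads, seq, dim_head)

# default behaviour: softmax((q @ k.transpose(-2, -1)) / sqrt(dim_head)) @ v
default_out = F.scaled_dot_product_attention(q, k, v, is_causal = True)

# the new MemoryAttention default replaces 1 / sqrt(dim_head) with a fixed value
scaled_out = F.scaled_dot_product_attention(q, k, v, scale = 8., is_causal = True)

print(1 / math.sqrt(64))   # 0.125, the implicit default scale being overridden
```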
@@ -174,6 +178,7 @@ class NeuralMemory(Module):
         default_step_transform_max_lr = 1e-2,
         pre_rmsnorm = True,
         post_rmsnorm = True,
+        learned_mem_model_weights = True,
         max_grad_norm: float | None = None,
         use_accelerated_scan = False,
         activation: Module | None = None,
@@ -212,6 +217,9 @@ class NeuralMemory(Module):
         if not exists(model):
             model = MemoryMLP(dim_head, **default_model_kwargs)
 
+        if not learned_mem_model_weights:
+            model.requires_grad_(False)
+
         assert not exists(next(model.buffers(), None)), 'model cannot have buffers for now'
 
         # the memory is the weights of the model
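When `learned_mem_model_weights = False`, the inner memory model (a `MemoryMLP` by default) is taken out of gradient-based training via `requires_grad_(False)`; its weights still act as the memory itself, per the comment above, but their initial values are no longer trainable parameters. A minimal stand-alone illustration of that freezing behaviour on a stand-in module:

```python
import torch
from torch import nn

# stand-in for the inner memory model
memory_model = nn.Sequential(nn.Linear(64, 64), nn.SiLU(), nn.Linear(64, 64))

learned_mem_model_weights = False

if not learned_mem_model_weights:
    memory_model.requires_grad_(False)   # same call the diff adds above

# no parameter of the frozen model will receive gradients from the outer optimizer
assert all(not p.requires_grad for p in memory_model.parameters())
```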
train_mac.py

@@ -9,7 +9,7 @@ from torch.optim import Adam
 from torch.nn import functional as F
 from torch.utils.data import DataLoader, Dataset
 
-from titans_pytorch
+from titans_pytorch import MemoryAsContextTransformer
 
 # constants
 
@@ -32,6 +32,7 @@ NUM_LONGTERM_MEM = 4
 NEURAL_MEM_LAYERS = (2, 4)
 WINDOW_SIZE = 32
 KV_RECON_LOSS_WEIGHT = 0.
+LEARNED_MEM_MODEL_WEIGHTS = True
 RUN_NAME = f'mac - {NUM_LONGTERM_MEM} longterm mems, layers {NEURAL_MEM_LAYERS}'
 
 # wandb experiment tracker
@@ -90,7 +91,7 @@ def base_decoding(
     sample_num_times = max(0, seq_len - prompt_seq_len)
 
     for _ in tqdm.tqdm(range(sample_num_times)):
-        logits = net(out)
+        logits = net(out, disable_flex_attn = True)
         logits = logits[:, -1]
 
         logits = min_p_filter(logits, min_p = min_p)
@@ -112,9 +113,11 @@ model = MemoryAsContextTransformer(
     neural_memory_layers = NEURAL_MEM_LAYERS,
     neural_memory_segment_len = WINDOW_SIZE // 2,
     aux_kv_recon_loss_weight = KV_RECON_LOSS_WEIGHT,
+    use_flex_attn = True,
     neural_memory_kwargs = dict(
         dim_head = 64,
         heads = 4,
+        learned_mem_model_weights = LEARNED_MEM_MODEL_WEIGHTS,
         default_model_kwargs = dict(
             depth = NEURAL_MEMORY_DEPTH,
         )