titans-pytorch 0.3.21__tar.gz → 0.3.22__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {titans_pytorch-0.3.21 → titans_pytorch-0.3.22}/PKG-INFO +1 -1
- {titans_pytorch-0.3.21 → titans_pytorch-0.3.22}/pyproject.toml +1 -1
- {titans_pytorch-0.3.21 → titans_pytorch-0.3.22}/tests/test_titans.py +3 -0
- {titans_pytorch-0.3.21 → titans_pytorch-0.3.22}/titans_pytorch/neural_memory.py +24 -9
- {titans_pytorch-0.3.21 → titans_pytorch-0.3.22}/.github/workflows/python-publish.yml +0 -0
- {titans_pytorch-0.3.21 → titans_pytorch-0.3.22}/.github/workflows/test.yaml +0 -0
- {titans_pytorch-0.3.21 → titans_pytorch-0.3.22}/.gitignore +0 -0
- {titans_pytorch-0.3.21 → titans_pytorch-0.3.22}/LICENSE +0 -0
- {titans_pytorch-0.3.21 → titans_pytorch-0.3.22}/README.md +0 -0
- {titans_pytorch-0.3.21 → titans_pytorch-0.3.22}/data/README.md +0 -0
- {titans_pytorch-0.3.21 → titans_pytorch-0.3.22}/data/enwik8.gz +0 -0
- {titans_pytorch-0.3.21 → titans_pytorch-0.3.22}/fig1.png +0 -0
- {titans_pytorch-0.3.21 → titans_pytorch-0.3.22}/fig2.png +0 -0
- {titans_pytorch-0.3.21 → titans_pytorch-0.3.22}/titans_pytorch/__init__.py +0 -0
- {titans_pytorch-0.3.21 → titans_pytorch-0.3.22}/titans_pytorch/associative_scan.py +0 -0
- {titans_pytorch-0.3.21 → titans_pytorch-0.3.22}/titans_pytorch/mac_transformer.py +0 -0
- {titans_pytorch-0.3.21 → titans_pytorch-0.3.22}/titans_pytorch/memory_models.py +0 -0
- {titans_pytorch-0.3.21 → titans_pytorch-0.3.22}/train_mac.py +0 -0
tests/test_titans.py

@@ -30,6 +30,7 @@ def torch_default_dtype(dtype):
 @pytest.mark.parametrize('momentum', (False, True))
 @pytest.mark.parametrize('qk_rmsnorm', (False, True))
 @pytest.mark.parametrize('max_grad_norm', (None, 2.))
+@pytest.mark.parametrize('num_kv_per_token', (1, 2))
 @pytest.mark.parametrize('per_parameter_lr_modulation', (False, True))
 @pytest.mark.parametrize('per_head_learned_parameters', (False, True))
 def test_titans(
@@ -40,6 +41,7 @@ def test_titans(
     momentum,
     qk_rmsnorm,
     max_grad_norm,
+    num_kv_per_token,
     per_parameter_lr_modulation,
     per_head_learned_parameters
 ):
@@ -49,6 +51,7 @@ def test_titans(
         activation = nn.SiLU() if silu else None,
         attn_pool_chunks = attn_pool_chunks,
         max_grad_norm = max_grad_norm,
+        num_kv_per_token = num_kv_per_token,
         momentum = momentum,
         qk_rmsnorm = qk_rmsnorm,
         per_parameter_lr_modulation = per_parameter_lr_modulation,
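For context on what the new test parameter exercises: `num_kv_per_token` lets a single token write more than one key / value pair into the neural memory. A minimal usage sketch follows, assuming the README-style `NeuralMemory` constructor (`dim`, `chunk_size`) and a forward call returning `(retrieved, state)`; sizes are illustrative only, not taken from this diff.

```python
# minimal sketch, assuming the README-style NeuralMemory API; values are illustrative
import torch
from titans_pytorch import NeuralMemory

mem = NeuralMemory(
    dim = 384,
    chunk_size = 64,
    num_kv_per_token = 2   # new in 0.3.22: each token emits two key / value updates
)

seq = torch.randn(2, 1024, 384)
retrieved, mem_state = mem(seq)

assert retrieved.shape == seq.shape
```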
titans_pytorch/neural_memory.py

@@ -35,6 +35,7 @@ d - feature dimension
 c - intra-chunk
 w - num memory network weight parameters
 o - momentum orders
+u - key / value updates - allowing a token to emit multiple key / values
 """
 
 LinearNoBias = partial(Linear, bias = False)
@@ -231,6 +232,7 @@ class NeuralMemory(Module):
         momentum_order = 1,
         learned_momentum_combine = False,
         learned_combine_include_zeroth = False,
+        num_kv_per_token = 1, # whether a single token can do multiple updates to the memory model
         qkv_receives_diff_views = False, # to address an issue raised by a phd student (who will be credited if experiments are green). basically the issue raised is that the memory MLP is only learning Wk @ Wv linear mapping and that may not be expressive enough. we will use hyper connections to allow the network to choose different previous layer inputs as keys / values and see if that does anything
         pre_rmsnorm = True,
         post_rmsnorm = False,
@@ -363,11 +365,22 @@ class NeuralMemory(Module):
 
         # keys and values for storing to the model
 
-        self.to_keys = Sequential(LinearNoBias(dim, dim_inner), activation)
-        self.to_values = Sequential(LinearNoBias(dim, dim_inner), activation)
+        assert num_kv_per_token > 0
+
+        self.to_keys = Sequential(
+            LinearNoBias(dim, dim_inner * num_kv_per_token),
+            activation,
+        )
+
+        self.to_values = Sequential(
+            LinearNoBias(dim, dim_inner * num_kv_per_token),
+            activation,
+        )
 
         self.store_memory_loss_fn = store_memory_loss_fn
 
+        self.num_kv_per_token = num_kv_per_token
+
         # `chunk_size` refers to chunk size used for storing to memory model weights
 
         chunk_size = self.store_chunk_size
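The key and value projections now widen by a factor of `num_kv_per_token`, so each token produces several candidate updates to the memory model. A small shape sketch, assuming `dim_inner` is the usual `dim_head * heads` inner dimension of the module; the numbers are made up for illustration.

```python
# sketch: widened key projection emitting u keys per token (illustrative sizes)
import torch
from torch import nn

dim, dim_inner, u = 384, 256, 2                 # u = num_kv_per_token
to_keys = nn.Linear(dim, dim_inner * u, bias = False)

tokens = torch.randn(1, 1024, dim)
keys = to_keys(tokens)                          # (1, 1024, dim_inner * u)
keys = keys.reshape(1, 1024, u, dim_inner)      # each token now carries u = 2 key updates
print(keys.shape)                               # torch.Size([1, 1024, 2, 256])
```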
@@ -384,8 +397,8 @@ class NeuralMemory(Module):
         # learned adaptive learning rate
 
         self.to_adaptive_step = Sequential(
-            nn.Linear(dim, heads),
-            Rearrange('b n h -> (b h) n')
+            nn.Linear(dim, heads * num_kv_per_token),
+            Rearrange('b n (h u) -> (b h) (n u)', u = num_kv_per_token)
        )
 
         if not exists(adaptive_step_transform):
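Correspondingly, the adaptive step projection now emits one learning rate per (token, update) pair instead of one per token. A quick shape check of the new `Rearrange` pattern, with hypothetical sizes:

```python
# sketch: one adaptive step per (token, update) pair (illustrative sizes)
import torch
from torch import nn
from einops.layers.torch import Rearrange

dim, heads, u = 384, 4, 2                       # u = num_kv_per_token
to_adaptive_step = nn.Sequential(
    nn.Linear(dim, heads * u),
    Rearrange('b n (h u) -> (b h) (n u)', u = u)
)

out = to_adaptive_step(torch.randn(1, 1024, dim))
print(out.shape)                                # torch.Size([4, 2048]) - per head, per token-update
```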
@@ -518,7 +531,7 @@ class NeuralMemory(Module):
 
         # shapes and variables
 
-        heads, chunk_size = self.heads, self.store_chunk_size
+        heads, chunk_size, num_updates = self.heads, self.store_chunk_size, self.num_kv_per_token
 
         # curtail sequence by multiple of the chunk size
         # only a complete chunk of the sequence provides the memory for the next chunk
@@ -587,15 +600,17 @@ class NeuralMemory(Module):
 
         batch = keys.shape[0]
 
+        # take care of chunking
+
+        keys, values = tuple(rearrange(t, 'b h (n c) (u d) -> (b h n) (c u) d', c = chunk_size, u = num_updates) for t in (keys, values))
+
         # maybe qk rmsnorm
 
         keys = self.k_norm(keys)
 
-        #
-
-        keys, values = tuple(rearrange(t, 'b h (n c) d -> (b h n) c d', c = chunk_size) for t in (keys, values))
+        # adaptive lr
 
-        adaptive_lr = rearrange(adaptive_lr, 'b (n c) -> (b n) c', c = chunk_size)
+        adaptive_lr = rearrange(adaptive_lr, 'b (n c u) -> (b n) (c u)', c = chunk_size, u = num_updates)
 
         # maybe add previous layer weight
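The storage path then folds the extra updates into the chunk axis, so each chunk contributes `chunk_size * num_kv_per_token` key / value pairs with matching per-update learning rates. A standalone shape walk-through of the two new `rearrange` patterns, using made-up sizes:

```python
# sketch: chunking with the new 'u' (updates per token) axis, using made-up sizes
import torch
from einops import rearrange

b, h, n, c, d, u = 1, 2, 4, 64, 32, 2           # batch, heads, chunks, chunk size, head dim, kv per token

keys = torch.randn(b, h, n * c, u * d)
keys = rearrange(keys, 'b h (n c) (u d) -> (b h n) (c u) d', c = c, u = u)
print(keys.shape)                                # torch.Size([8, 128, 32]) - c * u updates per chunk

adaptive_lr = torch.randn(b * h, n * c * u)
adaptive_lr = rearrange(adaptive_lr, 'b (n c u) -> (b n) (c u)', c = c, u = u)
print(adaptive_lr.shape)                         # torch.Size([8, 128]) - one learning rate per update
```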