titans-pytorch 0.1.38__tar.gz → 0.2.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {titans_pytorch-0.1.38 → titans_pytorch-0.2.1}/PKG-INFO +2 -2
- {titans_pytorch-0.1.38 → titans_pytorch-0.2.1}/README.md +1 -1
- {titans_pytorch-0.1.38 → titans_pytorch-0.2.1}/pyproject.toml +1 -1
- {titans_pytorch-0.1.38 → titans_pytorch-0.2.1}/tests/test_titans.py +3 -2
- {titans_pytorch-0.1.38 → titans_pytorch-0.2.1}/titans_pytorch/mac_transformer.py +12 -10
- {titans_pytorch-0.1.38 → titans_pytorch-0.2.1}/train_mac.py +2 -1
- {titans_pytorch-0.1.38 → titans_pytorch-0.2.1}/.github/workflows/python-publish.yml +0 -0
- {titans_pytorch-0.1.38 → titans_pytorch-0.2.1}/.github/workflows/test.yaml +0 -0
- {titans_pytorch-0.1.38 → titans_pytorch-0.2.1}/.gitignore +0 -0
- {titans_pytorch-0.1.38 → titans_pytorch-0.2.1}/LICENSE +0 -0
- {titans_pytorch-0.1.38 → titans_pytorch-0.2.1}/data/README.md +0 -0
- {titans_pytorch-0.1.38 → titans_pytorch-0.2.1}/data/enwik8.gz +0 -0
- {titans_pytorch-0.1.38 → titans_pytorch-0.2.1}/fig1.png +0 -0
- {titans_pytorch-0.1.38 → titans_pytorch-0.2.1}/fig2.png +0 -0
- {titans_pytorch-0.1.38 → titans_pytorch-0.2.1}/titans_pytorch/__init__.py +0 -0
- {titans_pytorch-0.1.38 → titans_pytorch-0.2.1}/titans_pytorch/associative_scan.py +0 -0
- {titans_pytorch-0.1.38 → titans_pytorch-0.2.1}/titans_pytorch/memory_models.py +0 -0
- {titans_pytorch-0.1.38 → titans_pytorch-0.2.1}/titans_pytorch/neural_memory.py +0 -0
PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.1.38
+Version: 0.2.1
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -56,7 +56,7 @@ Description-Content-Type: text/markdown
 
 <img src="./fig1.png" width="400px"></img>
 
-## Titans - Pytorch
+## Titans - Pytorch
 
 Unofficial implementation of [Titans](https://arxiv.org/abs/2501.00663) in Pytorch. Will also contain some explorations into architectures beyond their simple 1-4 layer MLP for the neural memory module, if it works well to any degree.
 
README.md

@@ -2,7 +2,7 @@
 
 <img src="./fig1.png" width="400px"></img>
 
-## Titans - Pytorch
+## Titans - Pytorch
 
 Unofficial implementation of [Titans](https://arxiv.org/abs/2501.00663) in Pytorch. Will also contain some explorations into architectures beyond their simple 1-4 layer MLP for the neural memory module, if it works well to any degree.
 
tests/test_titans.py

@@ -184,9 +184,10 @@ def test_mac(
     assert logits.shape == (1, seq_len, 256)
 
 @pytest.mark.parametrize('sliding', (False, True))
-@pytest.mark.parametrize('mem_layers', (()))
+@pytest.mark.parametrize('mem_layers', ((), None))
 @pytest.mark.parametrize('longterm_mems', (0, 4, 16))
-@pytest.mark.parametrize('prompt_len', (
+@pytest.mark.parametrize('prompt_len', (4, 16))
+@torch_default_dtype(torch.float64)
 def test_mac_sampling(
     sliding,
     mem_layers,
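The new `@torch_default_dtype(torch.float64)` decorator runs the sampling test in double precision. The helper itself is not part of this diff; the sketch below is a hypothetical implementation of such a decorator, shown only to illustrate what it does when stacked onto a parametrized test (the package's actual helper may differ).

```python
import torch
from functools import wraps

# hypothetical implementation of a torch_default_dtype test decorator
# (assumption: the real helper behaves like this, temporarily swapping
# torch's global default dtype for the duration of the decorated test)

def torch_default_dtype(dtype):
    def decorator(fn):
        @wraps(fn)
        def wrapper(*args, **kwargs):
            prev_dtype = torch.get_default_dtype()
            torch.set_default_dtype(dtype)
            try:
                return fn(*args, **kwargs)
            finally:
                # restore the previous default even if the test fails
                torch.set_default_dtype(prev_dtype)
        return wrapper
    return decorator
```

Presumably the switch to float64 makes numerical comparisons inside the sampling test less sensitive to floating-point accumulation error.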
titans_pytorch/mac_transformer.py

@@ -111,6 +111,7 @@ def pad_and_segment_with_inverse(
     seq,
     segment_len,
     fold_into_batch = True,
+    inverse_remove_pad = True
 ):
     batch, seq_len = seq.shape[:2]
     next_seq_len_mult = round_up_multiple(seq_len, segment_len)
@@ -124,15 +125,12 @@ def pad_and_segment_with_inverse(
     if fold_into_batch:
         seq = rearrange(seq, 'b (w n) d -> (b w) n d', n = segment_len)
 
-    shape = seq.shape
-
     def inverse(out):
-        unchanged_shape = out.shape == shape
 
         if fold_into_batch:
             out = rearrange(out, '(b w) ... n d -> b ... (w n) d', b = batch)
 
-        if needs_pad and unchanged_shape:
+        if needs_pad and inverse_remove_pad:
             out = out[..., :-padding, :]
 
         return out
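Taken together, the two hunks above add an `inverse_remove_pad` flag: `inverse` can now be asked to keep the right padding so the caller gets the segment-aligned length back, and the old `shape` / `unchanged_shape` bookkeeping goes away. Below is a minimal, self-contained sketch of such a helper, assuming `F.pad`-style right padding and the usual `round_up_multiple` helper; it mirrors the lines visible in the diff but is not the package's exact implementation.

```python
import torch
import torch.nn.functional as F
from einops import rearrange

def round_up_multiple(n, mult):
    # smallest multiple of `mult` that is >= n
    return ((n + mult - 1) // mult) * mult

def pad_and_segment_with_inverse(
    seq,
    segment_len,
    fold_into_batch = True,
    inverse_remove_pad = True
):
    batch, seq_len = seq.shape[:2]
    next_seq_len_mult = round_up_multiple(seq_len, segment_len)

    padding = next_seq_len_mult - seq_len
    needs_pad = padding > 0

    if needs_pad:
        seq = F.pad(seq, (0, 0, 0, padding))  # right-pad the sequence dimension

    if fold_into_batch:
        seq = rearrange(seq, 'b (w n) d -> (b w) n d', n = segment_len)

    def inverse(out):
        if fold_into_batch:
            out = rearrange(out, '(b w) ... n d -> b ... (w n) d', b = batch)

        if needs_pad and inverse_remove_pad:
            out = out[..., :-padding, :]  # slice the padding back off

        return out

    return seq, inverse

# with inverse_remove_pad = False the inverse returns the padded, segment-aligned length
x = torch.randn(2, 10, 8)
segments, inverse = pad_and_segment_with_inverse(x, segment_len = 4, inverse_remove_pad = False)
assert segments.shape == (6, 4, 8)
assert inverse(segments).shape == (2, 12, 8)
```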
@@ -493,7 +491,8 @@ class MemoryAsContextTransformer(Module):
         aux_kv_recon_loss_weight = 0.,
         use_flex_attn = False,
         sliding_window_attn = False,
-        weight_tie_memory_model = False
+        weight_tie_memory_model = False,
+        prev_neural_mem_update_for_weights = None
     ):
         super().__init__()
 
@@ -535,6 +534,7 @@ class MemoryAsContextTransformer(Module):
         assert exists(neural_memory_model), '`neural_memory_model` must be explicitly set'
 
         self.weight_tie_memory_model = weight_tie_memory_model
+        self.prev_neural_mem_update_for_weights = default(prev_neural_mem_update_for_weights, weight_tie_memory_model)
 
         # value residual learning for neural memory
 
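The new argument preserves the old behavior by default: `default(prev_neural_mem_update_for_weights, weight_tie_memory_model)` means the previous neural-memory update is reused for the weights only when the memory model is weight-tied, unless the caller overrides it explicitly. For reference, a sketch of the `exists` / `default` helpers this line relies on, assuming the usual definitions used throughout the codebase:

```python
# assumed definitions of the small helpers referenced above

def exists(v):
    return v is not None

def default(v, fallback):
    return v if exists(v) else fallback

# so prev_neural_mem_update_for_weights = None resolves to weight_tie_memory_model,
# and an explicit True/False from the caller wins over the tie-weights setting
```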
@@ -704,7 +704,7 @@ class MemoryAsContextTransformer(Module):
 
         # math
 
-        batch, seq_len, neural_mem_segment_len, segment_len, num_longterm_mem_tokens, attn_window_size = *x.shape, self.neural_memory_segment_len, self.segment_len, self.num_longterm_mem_tokens, self.attn_window_size
+        batch, seq_len, neural_mem_segment_len, segment_len, num_longterm_mem_tokens, attn_window_size, prev_neural_mem_update_for_weights = *x.shape, self.neural_memory_segment_len, self.segment_len, self.num_longterm_mem_tokens, self.attn_window_size, self.prev_neural_mem_update_for_weights
 
         seq_len_with_mem = self.seq_len_with_longterm_mem(seq_len)
 
@@ -714,7 +714,7 @@ class MemoryAsContextTransformer(Module):
 
         # intersperse longterm memory
 
-        x, inverse_segment = pad_and_segment_with_inverse(x, segment_len)
+        x, inverse_segment = pad_and_segment_with_inverse(x, segment_len, inverse_remove_pad = False)
 
         mems = repeat(self.longterm_mems, 'n d -> b n d', b = x.shape[0])
         x, inverse_pack_mems = pack_with_inverse((x, mems), 'b * d')
@@ -816,7 +816,7 @@ class MemoryAsContextTransformer(Module):
             if self.mem_add_value_residual:
                 mem_value_residual = next_mem_value_residual
 
-            if self.weight_tie_memory_model:
+            if prev_neural_mem_update_for_weights:
                 neural_memory_updates = next_neural_mem_cache.updates
 
             if self.gate_attn_output:
@@ -856,7 +856,9 @@ class MemoryAsContextTransformer(Module):
 
         next_kv_caches = next_kv_caches[..., -attn_window_size:, :]
 
-
+        kv_cache_length = next_kv_caches.shape[-2]
+
+        if not self.sliding_window_attn and divisible_by(kv_cache_length, attn_window_size):
             next_kv_caches = next_kv_caches[..., 0:0, :]
 
         next_cache = (
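The kv-cache handling now measures the kept cache directly (`kv_cache_length`) before deciding whether to reset it. A toy illustration of the new condition, assuming the usual `divisible_by` helper and treating the second-to-last axis as the sequence axis (the real cache tensor carries more leading dimensions, and the reset also requires sliding-window attention to be off):

```python
import torch

def divisible_by(num, den):
    return (num % den) == 0  # assumed helper, in the style used elsewhere in the codebase

attn_window_size = 4
next_kv_caches = torch.randn(2, 9, 16)  # toy (heads, seq, dim) cache

# keep at most one attention window of cache
next_kv_caches = next_kv_caches[..., -attn_window_size:, :]
kv_cache_length = next_kv_caches.shape[-2]

# once the kept cache exactly fills a window, empty it along the sequence axis,
# presumably because the next decoding step starts a fresh attention window
if divisible_by(kv_cache_length, attn_window_size):
    next_kv_caches = next_kv_caches[..., 0:0, :]

print(next_kv_caches.shape)  # torch.Size([2, 0, 16])
```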
@@ -878,7 +880,7 @@ class MemoryAsContextTransformer(Module):
 
         if not is_inferencing:
 
-            x, inverse_segment = pad_and_segment_with_inverse(x, attn_window_size)
+            x, inverse_segment = pad_and_segment_with_inverse(x, attn_window_size, inverse_remove_pad = False)
 
             x, _ = inverse_pack_mems(x)
 
train_mac.py

@@ -53,6 +53,7 @@ WANDB_ONLINE = False # turn this on to pipe experiment to cloud
 
 USE_ACCELERATED_SCAN = True
 USE_FLEX_ATTN = True
+USE_FAST_INFERENCE = False
 
 # wandb experiment tracker
 
@@ -163,6 +164,6 @@ for i in tqdm.tqdm(range(NUM_BATCHES), mininterval = 10., desc = 'training'):
        prime = decode_tokens(inp)
        print(f'%s \n\n %s', (prime, '*' * 100))
 
-       sample = model.sample(inp[None, ...], GENERATE_LENGTH)
+       sample = model.sample(inp[None, ...], GENERATE_LENGTH, use_cache = USE_FAST_INFERENCE)
        output_str = decode_tokens(sample[0])
        print(output_str)
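train_mac.py thus picks up a `USE_FAST_INFERENCE` toggle, off by default, that routes periodic sampling through the kv-cache path via `use_cache`. A fragment showing the script with the flag flipped (names such as `model`, `inp`, `GENERATE_LENGTH` and `decode_tokens` come from the script itself, so this is not standalone):

```python
# near the top of train_mac.py: opt in to cached (fast) inference
USE_FAST_INFERENCE = True

# later, in the periodic sampling branch of the training loop
sample = model.sample(inp[None, ...], GENERATE_LENGTH, use_cache = USE_FAST_INFERENCE)
output_str = decode_tokens(sample[0])
print(output_str)
```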