titans-pytorch 0.1.38__tar.gz → 0.2.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {titans_pytorch-0.1.38 → titans_pytorch-0.2.0}/PKG-INFO +2 -2
- {titans_pytorch-0.1.38 → titans_pytorch-0.2.0}/README.md +1 -1
- {titans_pytorch-0.1.38 → titans_pytorch-0.2.0}/pyproject.toml +1 -1
- {titans_pytorch-0.1.38 → titans_pytorch-0.2.0}/tests/test_titans.py +3 -2
- {titans_pytorch-0.1.38 → titans_pytorch-0.2.0}/titans_pytorch/mac_transformer.py +7 -7
- {titans_pytorch-0.1.38 → titans_pytorch-0.2.0}/train_mac.py +2 -1
- {titans_pytorch-0.1.38 → titans_pytorch-0.2.0}/.github/workflows/python-publish.yml +0 -0
- {titans_pytorch-0.1.38 → titans_pytorch-0.2.0}/.github/workflows/test.yaml +0 -0
- {titans_pytorch-0.1.38 → titans_pytorch-0.2.0}/.gitignore +0 -0
- {titans_pytorch-0.1.38 → titans_pytorch-0.2.0}/LICENSE +0 -0
- {titans_pytorch-0.1.38 → titans_pytorch-0.2.0}/data/README.md +0 -0
- {titans_pytorch-0.1.38 → titans_pytorch-0.2.0}/data/enwik8.gz +0 -0
- {titans_pytorch-0.1.38 → titans_pytorch-0.2.0}/fig1.png +0 -0
- {titans_pytorch-0.1.38 → titans_pytorch-0.2.0}/fig2.png +0 -0
- {titans_pytorch-0.1.38 → titans_pytorch-0.2.0}/titans_pytorch/__init__.py +0 -0
- {titans_pytorch-0.1.38 → titans_pytorch-0.2.0}/titans_pytorch/associative_scan.py +0 -0
- {titans_pytorch-0.1.38 → titans_pytorch-0.2.0}/titans_pytorch/memory_models.py +0 -0
- {titans_pytorch-0.1.38 → titans_pytorch-0.2.0}/titans_pytorch/neural_memory.py +0 -0
--- titans_pytorch-0.1.38/PKG-INFO
+++ titans_pytorch-0.2.0/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.1.38
+Version: 0.2.0
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -56,7 +56,7 @@ Description-Content-Type: text/markdown
 
 <img src="./fig1.png" width="400px"></img>
 
-## Titans - Pytorch
+## Titans - Pytorch
 
 Unofficial implementation of [Titans](https://arxiv.org/abs/2501.00663) in Pytorch. Will also contain some explorations into architectures beyond their simple 1-4 layer MLP for the neural memory module, if it works well to any degree.
 
--- titans_pytorch-0.1.38/README.md
+++ titans_pytorch-0.2.0/README.md
@@ -2,7 +2,7 @@
 
 <img src="./fig1.png" width="400px"></img>
 
-## Titans - Pytorch
+## Titans - Pytorch
 
 Unofficial implementation of [Titans](https://arxiv.org/abs/2501.00663) in Pytorch. Will also contain some explorations into architectures beyond their simple 1-4 layer MLP for the neural memory module, if it works well to any degree.
 
--- titans_pytorch-0.1.38/tests/test_titans.py
+++ titans_pytorch-0.2.0/tests/test_titans.py
@@ -184,9 +184,10 @@ def test_mac(
     assert logits.shape == (1, seq_len, 256)
 
 @pytest.mark.parametrize('sliding', (False, True))
-@pytest.mark.parametrize('mem_layers', (()))
+@pytest.mark.parametrize('mem_layers', ((), None))
 @pytest.mark.parametrize('longterm_mems', (0, 4, 16))
-@pytest.mark.parametrize('prompt_len', (
+@pytest.mark.parametrize('prompt_len', (4, 16))
+@torch_default_dtype(torch.float64)
 def test_mac_sampling(
     sliding,
     mem_layers,
--- titans_pytorch-0.1.38/titans_pytorch/mac_transformer.py
+++ titans_pytorch-0.2.0/titans_pytorch/mac_transformer.py
@@ -111,6 +111,7 @@ def pad_and_segment_with_inverse(
     seq,
     segment_len,
     fold_into_batch = True,
+    inverse_remove_pad = True
 ):
     batch, seq_len = seq.shape[:2]
     next_seq_len_mult = round_up_multiple(seq_len, segment_len)
@@ -124,15 +125,12 @@ def pad_and_segment_with_inverse(
     if fold_into_batch:
         seq = rearrange(seq, 'b (w n) d -> (b w) n d', n = segment_len)
 
-    shape = seq.shape
-
     def inverse(out):
-        unchanged_shape = out.shape == shape
 
         if fold_into_batch:
             out = rearrange(out, '(b w) ... n d -> b ... (w n) d', b = batch)
 
-        if needs_pad and unchanged_shape:
+        if needs_pad and inverse_remove_pad:
             out = out[..., :-padding, :]
 
         return out
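Taken together, the two hunks above simplify `pad_and_segment_with_inverse`: the old `unchanged_shape` bookkeeping is dropped in favor of an explicit `inverse_remove_pad` flag, so the caller decides whether the returned `inverse` strips the right-padding. A self-contained sketch of the helper as it stands after this change (the padding lines are reconstructed assumptions; only the lines in the hunks are from the source):

```python
import torch.nn.functional as F
from einops import rearrange

def round_up_multiple(n, mult):
    # smallest multiple of `mult` that is >= n
    return ((n + mult - 1) // mult) * mult

def pad_and_segment_with_inverse(
    seq,
    segment_len,
    fold_into_batch = True,
    inverse_remove_pad = True
):
    batch, seq_len = seq.shape[:2]
    next_seq_len_mult = round_up_multiple(seq_len, segment_len)

    padding = next_seq_len_mult - seq_len
    needs_pad = padding > 0

    if needs_pad:
        seq = F.pad(seq, (0, 0, 0, padding))  # right-pad the sequence dim

    if fold_into_batch:
        seq = rearrange(seq, 'b (w n) d -> (b w) n d', n = segment_len)

    def inverse(out):
        if fold_into_batch:
            out = rearrange(out, '(b w) ... n d -> b ... (w n) d', b = batch)

        # with inverse_remove_pad = False the padding survives the round trip,
        # keeping the length a multiple of segment_len for later windowed ops
        if needs_pad and inverse_remove_pad:
            out = out[..., :-padding, :]

        return out

    return seq, inverse
```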
@@ -714,7 +712,7 @@ class MemoryAsContextTransformer(Module):
 
         # intersperse longterm memory
 
-        x, inverse_segment = pad_and_segment_with_inverse(x, segment_len)
+        x, inverse_segment = pad_and_segment_with_inverse(x, segment_len, inverse_remove_pad = False)
 
         mems = repeat(self.longterm_mems, 'n d -> b n d', b = x.shape[0])
         x, inverse_pack_mems = pack_with_inverse((x, mems), 'b * d')
@@ -856,7 +854,9 @@ class MemoryAsContextTransformer(Module):
 
         next_kv_caches = next_kv_caches[..., -attn_window_size:, :]
 
-        if not self.sliding_window_attn and divisible_by(next_kv_caches.shape[-2], attn_window_size):
+        kv_cache_length = next_kv_caches.shape[-2]
+
+        if not self.sliding_window_attn and divisible_by(kv_cache_length, attn_window_size):
             next_kv_caches = next_kv_caches[..., 0:0, :]
 
         next_cache = (
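The cache-trimming change names the post-truncation length before testing it. The intent: with fixed block-local attention (not sliding-window), a cache holding an exact multiple of `attn_window_size` means the next token opens a fresh window and can see none of the cached keys and values, so the cache is emptied rather than carried forward. A toy illustration of that boundary condition (numbers are hypothetical):

```python
# hypothetical illustration of the reset condition for block-local attention
attn_window_size = 4

for kv_cache_length in range(1, 9):
    # at an exact multiple of the window, the next token starts a new block
    # and none of the cached keys/values are visible to it
    starts_new_block = kv_cache_length % attn_window_size == 0
    print(kv_cache_length, 'reset cache' if starts_new_block else 'keep cache')
```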
@@ -878,7 +878,7 @@ class MemoryAsContextTransformer(Module):
 
         if not is_inferencing:
 
-            x, inverse_segment = pad_and_segment_with_inverse(x, attn_window_size)
+            x, inverse_segment = pad_and_segment_with_inverse(x, attn_window_size, inverse_remove_pad = False)
 
             x, _ = inverse_pack_mems(x)
 
--- titans_pytorch-0.1.38/train_mac.py
+++ titans_pytorch-0.2.0/train_mac.py
@@ -53,6 +53,7 @@ WANDB_ONLINE = False # turn this on to pipe experiment to cloud
 
 USE_ACCELERATED_SCAN = True
 USE_FLEX_ATTN = True
+USE_FAST_INFERENCE = False
 
 # wandb experiment tracker
 
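The new `USE_FAST_INFERENCE` toggle sits alongside the script's other feature flags; this diff never shows it being read (the sampling hunk below hardcodes `use_cache = True`), but one plausible wiring, offered purely as an assumption, would be:

```python
# assumed wiring, not shown in this diff: let the flag pick the decoding path
sample = model.sample(inp[None, ...], GENERATE_LENGTH, use_cache = USE_FAST_INFERENCE)
```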
@@ -163,6 +164,6 @@ for i in tqdm.tqdm(range(NUM_BATCHES), mininterval = 10., desc = 'training'):
     prime = decode_tokens(inp)
     print(f'%s \n\n %s', (prime, '*' * 100))
 
-    sample = model.sample(inp[None, ...], GENERATE_LENGTH)
+    sample = model.sample(inp[None, ...], GENERATE_LENGTH, use_cache = True)
     output_str = decode_tokens(sample[0])
     print(output_str)