titans-pytorch 0.1.14__tar.gz → 0.1.17__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {titans_pytorch-0.1.14 → titans_pytorch-0.1.17}/PKG-INFO +2 -2
- {titans_pytorch-0.1.14 → titans_pytorch-0.1.17}/pyproject.toml +2 -2
- {titans_pytorch-0.1.14 → titans_pytorch-0.1.17}/tests/test_titans.py +20 -3
- {titans_pytorch-0.1.14 → titans_pytorch-0.1.17}/titans_pytorch/mac_transformer.py +41 -5
- {titans_pytorch-0.1.14 → titans_pytorch-0.1.17}/titans_pytorch/titans.py +45 -34
- {titans_pytorch-0.1.14 → titans_pytorch-0.1.17}/.github/workflows/python-publish.yml +0 -0
- {titans_pytorch-0.1.14 → titans_pytorch-0.1.17}/.github/workflows/test.yaml +0 -0
- {titans_pytorch-0.1.14 → titans_pytorch-0.1.17}/.gitignore +0 -0
- {titans_pytorch-0.1.14 → titans_pytorch-0.1.17}/LICENSE +0 -0
- {titans_pytorch-0.1.14 → titans_pytorch-0.1.17}/README.md +0 -0
- {titans_pytorch-0.1.14 → titans_pytorch-0.1.17}/data/README.md +0 -0
- {titans_pytorch-0.1.14 → titans_pytorch-0.1.17}/data/enwik8.gz +0 -0
- {titans_pytorch-0.1.14 → titans_pytorch-0.1.17}/fig1.png +0 -0
- {titans_pytorch-0.1.14 → titans_pytorch-0.1.17}/fig2.png +0 -0
- {titans_pytorch-0.1.14 → titans_pytorch-0.1.17}/titans_pytorch/__init__.py +0 -0
- {titans_pytorch-0.1.14 → titans_pytorch-0.1.17}/titans_pytorch/associative_scan.py +0 -0
- {titans_pytorch-0.1.14 → titans_pytorch-0.1.17}/train_mac.py +0 -0

{titans_pytorch-0.1.14 → titans_pytorch-0.1.17}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.1.14
+Version: 0.1.17
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch

@@ -35,7 +35,7 @@ Classifier: Programming Language :: Python :: 3.9
 Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
 Requires-Python: >=3.9
 Requires-Dist: accelerated-scan>=0.2.0
-Requires-Dist: axial-positional-embedding>=0.3.
+Requires-Dist: axial-positional-embedding>=0.3.9
 Requires-Dist: einops>=0.8.0
 Requires-Dist: einx>=0.3.0
 Requires-Dist: hyper-connections>=0.1.8

{titans_pytorch-0.1.14 → titans_pytorch-0.1.17}/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "titans-pytorch"
-version = "0.1.14"
+version = "0.1.17"
 description = "Titans"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }

@@ -26,7 +26,7 @@ classifiers=[

 dependencies = [
     "accelerated-scan>=0.2.0",
-    "axial_positional_embedding>=0.3.
+    "axial_positional_embedding>=0.3.9",
     "einops>=0.8.0",
     "einx>=0.3.0",
     "hyper-connections>=0.1.8",

{titans_pytorch-0.1.14 → titans_pytorch-0.1.17}/tests/test_titans.py
@@ -3,7 +3,7 @@ from torch import nn

 import pytest
 from titans_pytorch import NeuralMemory
-from titans_pytorch.mac_transformer import flex_attention, SegmentedAttention
+from titans_pytorch.mac_transformer import flex_attention, SegmentedAttention, MemoryAsContextTransformer

 def exists(v):
     return v is not None

@@ -92,8 +92,6 @@ def test_mac(
     num_longterm_mem_tokens,
     neural_mem_gate_attn_output
 ):
-    from titans_pytorch.mac_transformer import MemoryAsContextTransformer
-
     transformer = MemoryAsContextTransformer(
         num_tokens = 256,
         dim = 256,

@@ -109,6 +107,25 @@ def test_mac(
     logits = transformer(x)
     assert logits.shape == (1, seq_len, 256)

+def test_mac_sampling():
+    transformer = MemoryAsContextTransformer(
+        num_tokens = 256,
+        dim = 256,
+        depth = 2,
+        segment_len = 32,
+        num_persist_mem_tokens = 4,
+        num_longterm_mem_tokens = 16,
+    )
+
+    ids = torch.randint(0, 256, (1, 1023))
+
+    # after much training
+
+    sampled = transformer.sample(ids[:, :4], 53, use_cache = False, temperature = 0.)
+    sampled_with_cache = transformer.sample(ids[:, :4], 53, use_cache = True, temperature = 0.)
+
+    assert torch.allclose(sampled, sampled_with_cache)
+
 @pytest.mark.parametrize('seq_len', (1023, 17))
 @pytest.mark.parametrize('sliding', (True, False))
 def test_flex(
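
The new test pins down the contract of the added use_cache flag: greedy decoding (temperature 0.) with and without the cache must produce identical tokens. A minimal usage sketch along the same lines, borrowing the test's hyperparameters (the weights are untrained here, so the sampled tokens are arbitrary but still deterministic):

import torch
from titans_pytorch.mac_transformer import MemoryAsContextTransformer

# same configuration as test_mac_sampling above
transformer = MemoryAsContextTransformer(
    num_tokens = 256,
    dim = 256,
    depth = 2,
    segment_len = 32,
    num_persist_mem_tokens = 4,
    num_longterm_mem_tokens = 16,
)

prompt = torch.randint(0, 256, (1, 4))

# temperature 0. makes sampling greedy, so cached and uncached decoding agree
sampled = transformer.sample(prompt, 53, use_cache = True, temperature = 0.)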

{titans_pytorch-0.1.14 → titans_pytorch-0.1.17}/titans_pytorch/mac_transformer.py
@@ -537,7 +537,8 @@ class MemoryAsContextTransformer(Module):
         filter_kwargs: dict = dict(
             min_p = 0.1,
         ),
-        show_progress = True
+        show_progress = True,
+        use_cache = False
     ):
         was_training = self.training
         self.eval()

@@ -547,8 +548,37 @@ class MemoryAsContextTransformer(Module):

         iter_wrap = tqdm.tqdm if show_progress else identity

+        # cache for axial pos, attention, and neural memory
+
+        cache = None
+        factorized_pos_emb = None
+
+        # precompute factorized pos emb
+
+        if use_cache:
+            round_up_seq_len = round_up_multiple(seq_len, self.segment_len)
+            longterm_mem_lens = (round_up_seq_len // self.segment_len) * self.num_longterm_mem_tokens
+            seq_len_with_mem = round_up_seq_len + longterm_mem_lens
+
+            axial_dims = self.axial_pos_emb.maybe_derive_outer_dim(seq_len_with_mem, (self.neural_memory_segment_len,))
+
+            factorized_pos_emb = self.axial_pos_emb(axial_dims, return_factorized = True)
+
+        # sample
+
         for _ in iter_wrap(range(sample_num_times)):
-            logits = self.forward(out, disable_flex_attn = True)
+
+            logits, next_cache = self.forward(
+                out,
+                disable_flex_attn = True,
+                cache = cache,
+                return_cache = True,
+                factorized_pos_emb = factorized_pos_emb
+            )
+
+            if use_cache:
+                cache = next_cache
+
             logits = logits[:, -1]

             logits = filter_fn(logits, **filter_kwargs)
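
To make the precomputation above concrete, here is the sequence-length bookkeeping it performs, worked through with illustrative numbers (the inline ceiling division stands in for the package's round_up_multiple helper, and none of the concrete values below come from the diff itself):

segment_len = 32
num_longterm_mem_tokens = 16
seq_len = 57    # assumed total length the sampler will reach, for illustration only

round_up_seq_len = -(-seq_len // segment_len) * segment_len                        # 64, next multiple of 32
longterm_mem_lens = (round_up_seq_len // segment_len) * num_longterm_mem_tokens    # 2 segments * 16 = 32
seq_len_with_mem = round_up_seq_len + longterm_mem_lens                            # 96

The factorized positional embedding is then built once for seq_len_with_mem and handed to every forward call of the decode loop, rather than being recomputed from scratch at each step.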

@@ -565,7 +595,10 @@ class MemoryAsContextTransformer(Module):
         x,
         return_loss = False,
         return_loss_breakdown = False,
-        disable_flex_attn = False
+        disable_flex_attn = False,
+        cache = None,
+        return_cache = False,
+        factorized_pos_emb = None
     ):

         if return_loss:

@@ -593,7 +626,7 @@ class MemoryAsContextTransformer(Module):
         # apply axial positional embedding
         # so intra and inter segment can be more easily discerned by the network

-        pos_emb = self.axial_pos_emb.forward_with_seq_len(seq_len_with_mem, (neural_mem_segment_len,))
+        pos_emb = self.axial_pos_emb.forward_with_seq_len(seq_len_with_mem, (neural_mem_segment_len,), factorized = factorized_pos_emb)

         x = x + pos_emb

@@ -651,7 +684,10 @@ class MemoryAsContextTransformer(Module):
         logits = self.to_logits(x)

         if not return_loss:
-            return logits
+            if not return_cache:
+                return logits
+
+            return logits, cache

         ar_loss = F.cross_entropy(rearrange(logits, 'b n l -> b l n'), labels)

{titans_pytorch-0.1.14 → titans_pytorch-0.1.17}/titans_pytorch/titans.py
@@ -301,6 +301,45 @@ class MemoryAttention(Module):

         return out

+# associative scan wrapper
+
+class AssocScan(Module):
+    def __init__(
+        self,
+        use_accelerated = False
+    ):
+        super().__init__()
+        self.use_accelerated = use_accelerated
+
+    def forward(self, gates, inputs):
+
+        if not self.use_accelerated:
+            _, outputs = associative_scan(binary_operator, (gates, inputs))
+            return outputs
+
+        from accelerated_scan.triton import scan as triton_scan
+        from accelerated_scan.warp import scan as warp_scan
+
+        scan = triton_scan if gates.is_cuda else warp_scan
+
+        def accelerate_scan_fn(gates, inputs):
+            gates = gates.expand_as(inputs)
+            gates, inputs = tuple(rearrange(t, 'b n d -> b d n') for t in (gates, inputs))
+
+            seq_len = gates.shape[-1]
+            next_power_two_seq_len = 2 ** max(5, int(math.ceil(math.log2(seq_len))))
+
+            gates = F.pad(gates, (0, next_power_two_seq_len - seq_len))
+            inputs = F.pad(inputs, (0, next_power_two_seq_len - seq_len))
+
+            outputs = scan(gates.contiguous(), inputs.contiguous())
+
+            outputs = outputs[..., :seq_len]
+            outputs = rearrange(outputs, 'b d n -> b n d')
+            return outputs
+
+        return accelerate_scan_fn(gates, inputs)
+
 # main neural memory

 def default_adaptive_step_transform(adaptive_step, max_lr = 1e-2):
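
For orientation: both branches of AssocScan.forward compute the same first-order linear recurrence, out[t] = gates[t] * out[t-1] + inputs[t], which is what binary_operator in associative_scan.py encodes. The accelerated branch additionally pads the sequence to the next power of two (minimum 32), presumably because the accelerated_scan kernels expect power-of-two lengths. A slow reference loop of my own, not part of the package, that is handy for sanity-checking either branch on small tensors:

import torch

def reference_scan(gates, inputs):
    # gates, inputs: (batch, seq, dim)
    # computes out[t] = gates[t] * out[t-1] + inputs[t], with the state initialized to zero
    out = torch.zeros_like(inputs[:, 0])
    outputs = []
    for t in range(inputs.shape[1]):
        out = gates[:, t] * out + inputs[:, t]
        outputs.append(out)
    return torch.stack(outputs, dim = 1)

On small random tensors its output should match AssocScan(use_accelerated = False)(gates, inputs) up to floating point error.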

@@ -339,6 +378,10 @@ class NeuralMemory(Module):

         self.retrieve_chunk_size, self.store_chunk_size = pair(chunk_size)

+        # associative scan
+
+        self.assoc_scan = AssocScan(use_accelerated = use_accelerated_scan)
+
         # norms

         self.retrieve_norm = nn.RMSNorm(dim) if pre_rmsnorm else nn.Identity()

@@ -564,38 +607,6 @@ class NeuralMemory(Module):

         surprises = grads.apply(lambda t: -t)

-        # determine scan function
-
-        def default_associative_scan(gates, inputs):
-            _, outputs = associative_scan(binary_operator, (gates, inputs))
-            return outputs
-
-        if self.use_accelerated_scan:
-            from accelerated_scan.triton import scan as triton_scan
-            from accelerated_scan.warp import scan as warp_scan
-
-            scan = triton_scan if seq.is_cuda else warp_scan
-
-            def accelerate_scan_fn(gates, inputs):
-                gates = gates.expand_as(inputs)
-                gates, inputs = tuple(rearrange(t, 'b n d -> b d n') for t in (gates, inputs))
-
-                seq_len = gates.shape[-1]
-                next_power_two_seq_len = 2 ** max(5, int(math.ceil(math.log2(seq_len))))
-
-                gates = F.pad(gates, (0, next_power_two_seq_len - seq_len))
-                inputs = F.pad(inputs, (0, next_power_two_seq_len - seq_len))
-
-                outputs = scan(gates.contiguous(), inputs.contiguous())
-
-                outputs = outputs[..., :seq_len]
-                outputs = rearrange(outputs, 'b d n -> b n d')
-                return outputs
-
-            scan_fn = accelerate_scan_fn
-        else:
-            scan_fn = default_associative_scan
-
         # momentum + weight decay - momentum is the new contribution, as most linear RNNs have learned forgetting gates

         next_momentum = TensorDict() if has_momentum else None

@@ -610,12 +621,12 @@ class NeuralMemory(Module):
             # derive momentum with associative scan - eq (10)

             if has_momentum:
-                update = scan_fn(adaptive_momentum, surprise) # momentum is S / surprise in the paper
+                update = self.assoc_scan(adaptive_momentum, surprise) # momentum is S / surprise in the paper
                 momentum = update

             # use associative scan again for learned forgetting (weight decay) - eq (13)

-            update = scan_fn(1. - decay_factor, update)
+            update = self.assoc_scan(1. - decay_factor, update)

             updates[param_name] = inverse_pack(update)
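
For reference, the two comments above point at the memory update rule of the Titans paper. Reconstructing it from memory (the symbols are my own notation and should be treated as an assumption; only the equation numbers come from the code comments), the two chained scans unroll

S_t = \eta_t S_{t-1} - \theta_t \nabla \ell(M_{t-1}; x_t)    % eq (10): surprise with momentum
M_t = (1 - \alpha_t) M_{t-1} + S_t                           % eq (13): learned forgetting via weight decay

so adaptive_momentum plays the role of \eta_t, surprise supplies -\theta_t \nabla \ell, and decay_factor corresponds to \alpha_t, with self.assoc_scan evaluating each recurrence across the chunk in parallel rather than step by step.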