titans-pytorch 0.1.12__tar.gz → 0.1.15__tar.gz
This diff shows the changes between the publicly released package versions 0.1.12 and 0.1.15 as they appear in their public registry, and is provided for informational purposes only.
- {titans_pytorch-0.1.12 → titans_pytorch-0.1.15}/PKG-INFO +2 -2
- {titans_pytorch-0.1.12 → titans_pytorch-0.1.15}/pyproject.toml +2 -2
- {titans_pytorch-0.1.12 → titans_pytorch-0.1.15}/tests/test_titans.py +3 -0
- {titans_pytorch-0.1.12 → titans_pytorch-0.1.15}/titans_pytorch/titans.py +66 -46
- {titans_pytorch-0.1.12 → titans_pytorch-0.1.15}/train_mac.py +2 -0
- {titans_pytorch-0.1.12 → titans_pytorch-0.1.15}/.github/workflows/python-publish.yml +0 -0
- {titans_pytorch-0.1.12 → titans_pytorch-0.1.15}/.github/workflows/test.yaml +0 -0
- {titans_pytorch-0.1.12 → titans_pytorch-0.1.15}/.gitignore +0 -0
- {titans_pytorch-0.1.12 → titans_pytorch-0.1.15}/LICENSE +0 -0
- {titans_pytorch-0.1.12 → titans_pytorch-0.1.15}/README.md +0 -0
- {titans_pytorch-0.1.12 → titans_pytorch-0.1.15}/data/README.md +0 -0
- {titans_pytorch-0.1.12 → titans_pytorch-0.1.15}/data/enwik8.gz +0 -0
- {titans_pytorch-0.1.12 → titans_pytorch-0.1.15}/fig1.png +0 -0
- {titans_pytorch-0.1.12 → titans_pytorch-0.1.15}/fig2.png +0 -0
- {titans_pytorch-0.1.12 → titans_pytorch-0.1.15}/titans_pytorch/__init__.py +0 -0
- {titans_pytorch-0.1.12 → titans_pytorch-0.1.15}/titans_pytorch/associative_scan.py +0 -0
- {titans_pytorch-0.1.12 → titans_pytorch-0.1.15}/titans_pytorch/mac_transformer.py +0 -0
PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.1.12
+Version: 0.1.15
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -35,7 +35,7 @@ Classifier: Programming Language :: Python :: 3.9
 Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
 Requires-Python: >=3.9
 Requires-Dist: accelerated-scan>=0.2.0
-Requires-Dist: axial-positional-embedding>=0.3.
+Requires-Dist: axial-positional-embedding>=0.3.9
 Requires-Dist: einops>=0.8.0
 Requires-Dist: einx>=0.3.0
 Requires-Dist: hyper-connections>=0.1.8
pyproject.toml

@@ -1,6 +1,6 @@
 [project]
 name = "titans-pytorch"
-version = "0.1.12"
+version = "0.1.15"
 description = "Titans"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -26,7 +26,7 @@ classifiers=[

 dependencies = [
     "accelerated-scan>=0.2.0",
-    "axial_positional_embedding>=0.3.
+    "axial_positional_embedding>=0.3.9",
     "einops>=0.8.0",
     "einx>=0.3.0",
     "hyper-connections>=0.1.8",
tests/test_titans.py

@@ -12,6 +12,7 @@ def exists(v):
 @pytest.mark.parametrize('silu', (False, True))
 @pytest.mark.parametrize('learned_mem_model_weights', (False, True))
 @pytest.mark.parametrize('attn_pool_chunks', (False, True))
+@pytest.mark.parametrize('momentum', (False, True))
 @pytest.mark.parametrize('max_grad_norm', (None, 2.))
 @pytest.mark.parametrize('per_parameter_lr_modulation', (False, True))
 def test_titans(

@@ -19,6 +20,7 @@ def test_titans(
     silu,
     learned_mem_model_weights,
     attn_pool_chunks,
+    momentum,
     max_grad_norm,
     per_parameter_lr_modulation
 ):

@@ -28,6 +30,7 @@ def test_titans(
         activation = nn.SiLU() if silu else None,
         attn_pool_chunks = attn_pool_chunks,
         max_grad_norm = max_grad_norm,
+        momentum = momentum,
         per_parameter_lr_modulation = per_parameter_lr_modulation,
         learned_mem_model_weights = learned_mem_model_weights
     )
titans_pytorch/titans.py

@@ -301,6 +301,45 @@ class MemoryAttention(Module):

         return out

+# associative scan wrapper
+
+class AssocScan(Module):
+    def __init__(
+        self,
+        use_accelerated = False
+    ):
+        super().__init__()
+        self.use_accelerated = use_accelerated
+
+    def forward(self, gates, inputs):
+
+        if not self.use_accelerated:
+            _, outputs = associative_scan(binary_operator, (gates, inputs))
+            return outputs
+
+        from accelerated_scan.triton import scan as triton_scan
+        from accelerated_scan.warp import scan as warp_scan
+
+        scan = triton_scan if gates.is_cuda else warp_scan
+
+        def accelerate_scan_fn(gates, inputs):
+            gates = gates.expand_as(inputs)
+            gates, inputs = tuple(rearrange(t, 'b n d -> b d n') for t in (gates, inputs))
+
+            seq_len = gates.shape[-1]
+            next_power_two_seq_len = 2 ** max(5, int(math.ceil(math.log2(seq_len))))
+
+            gates = F.pad(gates, (0, next_power_two_seq_len - seq_len))
+            inputs = F.pad(inputs, (0, next_power_two_seq_len - seq_len))
+
+            outputs = scan(gates.contiguous(), inputs.contiguous())
+
+            outputs = outputs[..., :seq_len]
+            outputs = rearrange(outputs, 'b d n -> b n d')
+            return outputs
+
+        return accelerate_scan_fn(gates, inputs)
+
 # main neural memory

 def default_adaptive_step_transform(adaptive_step, max_lr = 1e-2):
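The new `AssocScan` module factors out the first-order gated recurrence that was previously inlined in `NeuralMemory`. As a rough mental model (not code from the package), the scan computes `out[t] = gates[t] * out[t-1] + inputs[t]` along the sequence dimension; the sketch below is a naive reference loop, assuming tensors of shape `(batch, seq, dim)`:

```python
import torch

def naive_assoc_scan(gates: torch.Tensor, inputs: torch.Tensor) -> torch.Tensor:
    # reference semantics for the wrapper: out[t] = gates[t] * out[t-1] + inputs[t]
    # gates, inputs: (batch, seq, dim) - illustrative shapes, not enforced here
    out = torch.zeros_like(inputs[:, 0])
    outputs = []
    for t in range(inputs.shape[1]):
        out = gates[:, t] * out + inputs[:, t]
        outputs.append(out)
    return torch.stack(outputs, dim = 1)
```

The wrapper picks between the pure-PyTorch `associative_scan` and the `accelerated_scan` kernels (Triton on CUDA, warp otherwise), padding the sequence to a power of two as those kernels expect.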
@@ -323,6 +362,7 @@ class NeuralMemory(Module):
         per_parameter_lr_modulation = False, # allow outer network to control learning rate per weight matrix of memory network
         max_mem_layer_modulation = 1e1, # max of 10.
         attn_pool_chunks = False,
+        momentum = True,
         pre_rmsnorm = True,
         post_rmsnorm = True,
         learned_mem_model_weights = True,
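`NeuralMemory` now exposes `momentum = True` as a constructor flag, so the momentum pathway can be switched off entirely. A minimal construction sketch, assuming the `dim` and `chunk_size` keywords and the forward call shown in the project README (values are illustrative):

```python
import torch
from titans_pytorch import NeuralMemory

# hypothetical hyperparameters, for illustration only
mem = NeuralMemory(
    dim = 384,
    chunk_size = 64,
    momentum = False   # to_momentum stays None and the momentum scan is skipped
)

seq = torch.randn(2, 1024, 384)
retrieved = mem(seq)
```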
@@ -338,6 +378,10 @@ class NeuralMemory(Module):

         self.retrieve_chunk_size, self.store_chunk_size = pair(chunk_size)

+        # associative scan
+
+        self.assoc_scan = AssocScan(use_accelerated = use_accelerated_scan)
+
         # norms

         self.retrieve_norm = nn.RMSNorm(dim) if pre_rmsnorm else nn.Identity()
@@ -423,7 +467,7 @@ class NeuralMemory(Module):
         self.to_momentum = Sequential(
             LinearNoBias(dim, heads),
             Rearrange('b n h -> (b h) n 1')
-        )
+        ) if momentum else None

         self.to_adaptive_step = Sequential(
             LinearNoBias(dim, heads),
@@ -462,12 +506,15 @@ class NeuralMemory(Module):

         self.register_buffer('zero', torch.tensor(0.), persistent = False)

-    def init_weights_and_momentum(self):
+    def init_weights_and_momentum(self, zero_weights = False):
         params = TensorDict(dict(self.memory_model.named_parameters()))

-        init_weights = params
+        init_weights = params
         init_momentum = params.clone().zero_()

+        if zero_weights:
+            init_weights = params.clone().zero_()
+
         return init_weights, init_momentum

     def init_empty_memory_embed(self, batch, seq_len):
@@ -497,14 +544,10 @@ class NeuralMemory(Module):

         seq = seq[:, :round_down_seq_len]

-        #
-
-        curr_weights = TensorDict(dict(self.memory_model.named_parameters()))
+        # get the weights of the memory network

         past_state = tuple(TensorDict(d) for d in past_state)
-
-
-        curr_weights = curr_weights + past_weights
+        curr_weights, past_momentum = past_state

         # derive learned hparams for optimization of memory network

@@ -513,10 +556,13 @@ class NeuralMemory(Module):

         chunked_seq = self.reduce_to_chunk_rep(seq, chunk_size = chunk_size)

-        adaptive_momentum = self.to_momentum(chunked_seq).sigmoid()
         decay_factor = self.to_decay_factor(chunked_seq).sigmoid()

         need_layer_lr_mod = exists(self.to_layer_modulation)
+        has_momentum = exists(self.to_momentum)
+
+        if has_momentum:
+            adaptive_momentum = self.to_momentum(chunked_seq).sigmoid()

         if need_layer_lr_mod:
             layer_lr_mod = self.to_layer_modulation(chunked_seq) * self.max_mem_layer_modulation
@@ -561,57 +607,31 @@ class NeuralMemory(Module):

         surprises = grads.apply(lambda t: -t)

-        # determine scan function
-
-        def default_associative_scan(gates, inputs):
-            _, outputs = associative_scan(binary_operator, (gates, inputs))
-            return outputs
-
-        if self.use_accelerated_scan:
-            from accelerated_scan.triton import scan as triton_scan
-            from accelerated_scan.warp import scan as warp_scan
-
-            scan = triton_scan if seq.is_cuda else warp_scan
-
-            def accelerate_scan_fn(gates, inputs):
-                gates = gates.expand_as(inputs)
-                gates, inputs = tuple(rearrange(t, 'b n d -> b d n') for t in (gates, inputs))
-
-                seq_len = gates.shape[-1]
-                next_power_two_seq_len = 2 ** max(5, int(math.ceil(math.log2(seq_len))))
-
-                gates = F.pad(gates, (0, next_power_two_seq_len - seq_len))
-                inputs = F.pad(inputs, (0, next_power_two_seq_len - seq_len))
-
-                outputs = scan(gates.contiguous(), inputs.contiguous())
-
-                outputs = outputs[..., :seq_len]
-                outputs = rearrange(outputs, 'b d n -> b n d')
-                return outputs
-
-            scan_fn = accelerate_scan_fn
-        else:
-            scan_fn = default_associative_scan
-
         # momentum + weight decay - momentum is the new contribution, as most linear RNNs have learned forgetting gates

-        next_momentum = TensorDict()
+        next_momentum = TensorDict() if has_momentum else None
         updates = TensorDict()

         for param_name, surprise in surprises.items():

             surprise, inverse_pack = pack_one_with_inverse(surprise, 'b n *')

+            update = surprise
+
             # derive momentum with associative scan - eq (10)

-
+            if has_momentum:
+                update = self.assoc_scan(adaptive_momentum, surprise) # momentum is S / surprise in the paper
+                momentum = update

             # use associative scan again for learned forgetting (weight decay) - eq (13)

-            update =
+            update = self.assoc_scan(1. - decay_factor, update)

             updates[param_name] = inverse_pack(update)
-
+
+            if has_momentum:
+                next_momentum[param_name] = inverse_pack(momentum)

             # compute the next weight per batch

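With momentum now optional, the per-parameter update applies at most two scans: the momentum recurrence of eq (10) over the surprises (skipped when `momentum = False`), followed by the learned weight-decay recurrence of eq (13). A schematic of the two recurrences, where `eta` stands for the adaptive momentum and `alpha` for the decay factor (these names are mine, not the library's):

```python
# schematic only - chunking, head reshaping and TensorDict bookkeeping omitted
#   S_t = eta_t * S_{t-1} + surprise_t     # eq (10), momentum, skipped if momentum = False
#   U_t = (1 - alpha_t) * U_{t-1} + S_t    # eq (13), learned forgetting / weight decay

def per_param_update(assoc_scan, surprise, adaptive_momentum, decay_factor, has_momentum):
    update = surprise
    momentum = None

    if has_momentum:
        update = assoc_scan(adaptive_momentum, surprise)  # momentum over the surprises
        momentum = update

    update = assoc_scan(1. - decay_factor, update)        # apply learned decay

    return update, momentum
```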
train_mac.py

@@ -31,6 +31,7 @@ NUM_PERSIST_MEM = 4
 NUM_LONGTERM_MEM = 4
 NEURAL_MEM_LAYERS = (2, 4)
 NEURAL_MEM_GATE_ATTN_OUTPUT = True
+NEURAL_MEM_MOMENTUM = True
 WINDOW_SIZE = 32
 NEURAL_MEM_SEGMENT_LEN = WINDOW_SIZE // 2 # set smaller for more granularity for learning rate / momentum etc
 SLIDING_WINDOWS = True

@@ -88,6 +89,7 @@ model = MemoryAsContextTransformer(
     dim_head = 64,
     heads = 4,
     attn_pool_chunks = STORE_ATTN_POOL_CHUNKS,
+    momentum = NEURAL_MEM_MOMENTUM,
     use_accelerated_scan = USE_ACCELERATED_SCAN,
     learned_mem_model_weights = LEARNED_MEM_MODEL_WEIGHTS,
     default_model_kwargs = dict(