titans-pytorch 0.1.14__py3-none-any.whl → 0.1.15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- titans_pytorch/titans.py +45 -34
- {titans_pytorch-0.1.14.dist-info → titans_pytorch-0.1.15.dist-info}/METADATA +2 -2
- titans_pytorch-0.1.15.dist-info/RECORD +8 -0
- titans_pytorch-0.1.14.dist-info/RECORD +0 -8
- {titans_pytorch-0.1.14.dist-info → titans_pytorch-0.1.15.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.1.14.dist-info → titans_pytorch-0.1.15.dist-info}/licenses/LICENSE +0 -0

titans_pytorch/titans.py CHANGED

@@ -301,6 +301,45 @@ class MemoryAttention(Module):
 
         return out
 
+# associative scan wrapper
+
+class AssocScan(Module):
+    def __init__(
+        self,
+        use_accelerated = False
+    ):
+        super().__init__()
+        self.use_accelerated = use_accelerated
+
+    def forward(self, gates, inputs):
+
+        if not self.use_accelerated:
+            _, outputs = associative_scan(binary_operator, (gates, inputs))
+            return outputs
+
+        from accelerated_scan.triton import scan as triton_scan
+        from accelerated_scan.warp import scan as warp_scan
+
+        scan = triton_scan if gates.is_cuda else warp_scan
+
+        def accelerate_scan_fn(gates, inputs):
+            gates = gates.expand_as(inputs)
+            gates, inputs = tuple(rearrange(t, 'b n d -> b d n') for t in (gates, inputs))
+
+            seq_len = gates.shape[-1]
+            next_power_two_seq_len = 2 ** max(5, int(math.ceil(math.log2(seq_len))))
+
+            gates = F.pad(gates, (0, next_power_two_seq_len - seq_len))
+            inputs = F.pad(inputs, (0, next_power_two_seq_len - seq_len))
+
+            outputs = scan(gates.contiguous(), inputs.contiguous())
+
+            outputs = outputs[..., :seq_len]
+            outputs = rearrange(outputs, 'b d n -> b n d')
+            return outputs
+
+        return accelerate_scan_fn(gates, inputs)
+
 # main neural memory
 
 def default_adaptive_step_transform(adaptive_step, max_lr = 1e-2):
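
The wrapper unifies what both backends compute: a first-order linear recurrence out[t] = gates[t] * out[t - 1] + inputs[t] along the sequence. A minimal usage sketch, for illustration only (the import path and the 'b n d' tensor layout are taken from this diff; the concrete shapes are made up):

import torch
from titans_pytorch.titans import AssocScan

scan = AssocScan(use_accelerated = False)  # flip to True to route through the accelerated-scan kernels

# gates act as data-dependent retention factors, inputs are the values to accumulate
gates  = torch.rand(2, 1024, 512)    # (batch, seq, dim), values in [0, 1)
inputs = torch.randn(2, 1024, 512)

out = scan(gates, inputs)            # out[:, t] = gates[:, t] * out[:, t - 1] + inputs[:, t]
assert out.shape == inputs.shape

The padding to 2 ** max(5, ceil(log2(seq_len))) in the accelerated branch exists because the accelerated-scan kernels expect power-of-two sequence lengths (minimum 32 here); the output is sliced back to the original length afterwards.
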
@@ -339,6 +378,10 @@ class NeuralMemory(Module):
 
         self.retrieve_chunk_size, self.store_chunk_size = pair(chunk_size)
 
+        # associative scan
+
+        self.assoc_scan = AssocScan(use_accelerated = use_accelerated_scan)
+
         # norms
 
         self.retrieve_norm = nn.RMSNorm(dim) if pre_rmsnorm else nn.Identity()
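
With the flag consumed once at construction, callers only pass use_accelerated_scan when building the memory. A hypothetical construction sketch (the keyword names are the ones visible in this diff; the dimension values are made up, and other constructor arguments are omitted):

from titans_pytorch import NeuralMemory

mem = NeuralMemory(
    dim = 384,
    chunk_size = 64,
    use_accelerated_scan = True   # forwarded into self.assoc_scan above
)
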
@@ -564,38 +607,6 @@ class NeuralMemory(Module):
 
         surprises = grads.apply(lambda t: -t)
 
-        # determine scan function
-
-        def default_associative_scan(gates, inputs):
-            _, outputs = associative_scan(binary_operator, (gates, inputs))
-            return outputs
-
-        if self.use_accelerated_scan:
-            from accelerated_scan.triton import scan as triton_scan
-            from accelerated_scan.warp import scan as warp_scan
-
-            scan = triton_scan if seq.is_cuda else warp_scan
-
-            def accelerate_scan_fn(gates, inputs):
-                gates = gates.expand_as(inputs)
-                gates, inputs = tuple(rearrange(t, 'b n d -> b d n') for t in (gates, inputs))
-
-                seq_len = gates.shape[-1]
-                next_power_two_seq_len = 2 ** max(5, int(math.ceil(math.log2(seq_len))))
-
-                gates = F.pad(gates, (0, next_power_two_seq_len - seq_len))
-                inputs = F.pad(inputs, (0, next_power_two_seq_len - seq_len))
-
-                outputs = scan(gates.contiguous(), inputs.contiguous())
-
-                outputs = outputs[..., :seq_len]
-                outputs = rearrange(outputs, 'b d n -> b n d')
-                return outputs
-
-            scan_fn = accelerate_scan_fn
-        else:
-            scan_fn = default_associative_scan
-
         # momentum + weight decay - momentum is the new contribution, as most linear RNNs have learned forgetting gates
 
         next_momentum = TensorDict() if has_momentum else None
@@ -610,12 +621,12 @@ class NeuralMemory(Module):
             # derive momentum with associative scan - eq (10)
 
             if has_momentum:
-                update = scan_fn(adaptive_momentum, surprise) # momentum is S / surprise in the paper
+                update = self.assoc_scan(adaptive_momentum, surprise) # momentum is S / surprise in the paper
                 momentum = update
 
             # use associative scan again for learned forgetting (weight decay) - eq (13)
 
-            update = scan_fn(1. - decay_factor, update)
+            update = self.assoc_scan(1. - decay_factor, update)
 
             updates[param_name] = inverse_pack(update)
 
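
Read together, the two calls realize the Titans recurrences the comments cite: eq (10) accumulates the (negated-gradient) surprise with a learned momentum gate, then eq (13) applies learned forgetting via a weight-decay gate on the running update. A sequential reference loop, for illustration only (names mirror the diff; the library computes the same result with the parallel scans, and the gates may be broadcast across the feature dimension):

import torch

def chained_scans_reference(adaptive_momentum, surprise, decay_factor):
    # Sequential equivalent of the two self.assoc_scan calls above.
    batch, seq_len, dim = surprise.shape
    momentum = torch.zeros(batch, dim)
    update   = torch.zeros(batch, dim)
    updates  = []

    for t in range(seq_len):
        # eq (10): momentum scan over the surprise (negated gradients)
        momentum = adaptive_momentum[:, t] * momentum + surprise[:, t]
        # eq (13): learned forgetting (weight decay) on the running update
        update = (1. - decay_factor[:, t]) * update + momentum
        updates.append(update)

    return torch.stack(updates, dim = 1)
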
{titans_pytorch-0.1.14.dist-info → titans_pytorch-0.1.15.dist-info}/METADATA CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.1.14
+Version: 0.1.15
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -35,7 +35,7 @@ Classifier: Programming Language :: Python :: 3.9
 Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
 Requires-Python: >=3.9
 Requires-Dist: accelerated-scan>=0.2.0
-Requires-Dist: axial-positional-embedding>=0.3.
+Requires-Dist: axial-positional-embedding>=0.3.9
 Requires-Dist: einops>=0.8.0
 Requires-Dist: einx>=0.3.0
 Requires-Dist: hyper-connections>=0.1.8

titans_pytorch-0.1.15.dist-info/RECORD ADDED

@@ -0,0 +1,8 @@
+titans_pytorch/__init__.py,sha256=u0tta_KqhOdfzCEDWT9P4_jejJEK2q1XxhsEzB5MnQU,223
+titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+titans_pytorch/mac_transformer.py,sha256=zxknstaI_Uz47Y8WvZ3S7geJ-TNdqKV5Rvj0Jlw8njs,19271
+titans_pytorch/titans.py,sha256=R0e25ly2uTHkHSZEb-9Eqb0DqtFq8wFBB8iH1T6bYVg,22240
+titans_pytorch-0.1.15.dist-info/METADATA,sha256=SnNsoK4obeOAWFPhQypYJfJWZ_abXKr7WCvLMqFdyg0,6340
+titans_pytorch-0.1.15.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.1.15.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.1.15.dist-info/RECORD,,

titans_pytorch-0.1.14.dist-info/RECORD DELETED

@@ -1,8 +0,0 @@
-titans_pytorch/__init__.py,sha256=u0tta_KqhOdfzCEDWT9P4_jejJEK2q1XxhsEzB5MnQU,223
-titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/mac_transformer.py,sha256=zxknstaI_Uz47Y8WvZ3S7geJ-TNdqKV5Rvj0Jlw8njs,19271
-titans_pytorch/titans.py,sha256=J7UbhhL0YKoQjrKPCgwUYwJd8-YfCZrHTSVjIFvvRRw,22077
-titans_pytorch-0.1.14.dist-info/METADATA,sha256=lM_x9obTibKKMXXdsW0U5GcQLTR4aXCtemEKPzYPqYo,6340
-titans_pytorch-0.1.14.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.1.14.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.1.14.dist-info/RECORD,,

{titans_pytorch-0.1.14.dist-info → titans_pytorch-0.1.15.dist-info}/WHEEL
File without changes

{titans_pytorch-0.1.14.dist-info → titans_pytorch-0.1.15.dist-info}/licenses/LICENSE
File without changes