titans-pytorch 0.4.6__py3-none-any.whl → 0.4.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- titans_pytorch/neural_memory.py +1 -1
- {titans_pytorch-0.4.6.dist-info → titans_pytorch-0.4.7.dist-info}/METADATA +2 -2
- titans_pytorch-0.4.7.dist-info/RECORD +8 -0
- titans_pytorch/associative_scan.py +0 -178
- titans_pytorch-0.4.6.dist-info/RECORD +0 -9
- {titans_pytorch-0.4.6.dist-info → titans_pytorch-0.4.7.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.4.6.dist-info → titans_pytorch-0.4.7.dist-info}/licenses/LICENSE +0 -0
titans_pytorch/neural_memory.py
CHANGED
@@ -15,7 +15,7 @@ from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten
 
 from tensordict import TensorDict
 
-from titans_pytorch.associative_scan import AssocScan
+from assoc_scan import AssocScan
 
 from titans_pytorch.memory_models import(
     MemoryMLP,
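The only source change in this release is the import swap above: the vendored titans_pytorch/associative_scan.py is dropped (see the deleted file further down) and AssocScan is pulled from the standalone assoc-scan package, which is added to the dependencies in METADATA below. A minimal usage sketch follows, assuming the external AssocScan keeps the constructor flag and forward signature of the removed in-package class (gates and inputs of shape (batch, seq, dim), scanned along the sequence axis):

# Sketch only: assumes assoc_scan.AssocScan mirrors the removed class,
# i.e. out[:, t] = gates[:, t] * out[:, t-1] + inputs[:, t] along dim 1.
import torch
from assoc_scan import AssocScan

scan = AssocScan(use_accelerated = False)  # constructor flag as in the removed class

gates  = torch.rand(2, 8, 16)   # per-step decay gates (batch, seq, dim)
inputs = torch.randn(2, 8, 16)  # per-step updates     (batch, seq, dim)

out = scan(gates, inputs)       # gated prefix scan, same shape as inputs
assert out.shape == (2, 8, 16)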
{titans_pytorch-0.4.6.dist-info → titans_pytorch-0.4.7.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.4.6
+Version: 0.4.7
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -34,7 +34,7 @@ Classifier: License :: OSI Approved :: MIT License
 Classifier: Programming Language :: Python :: 3.9
 Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
 Requires-Python: >=3.9
-Requires-Dist:
+Requires-Dist: assoc-scan
 Requires-Dist: axial-positional-embedding>=0.3.10
 Requires-Dist: einops>=0.8.0
 Requires-Dist: einx>=0.3.0
titans_pytorch-0.4.7.dist-info/RECORD
ADDED
@@ -0,0 +1,8 @@
+titans_pytorch/__init__.py,sha256=sVTOuRUkaIYabFExdLY6s1qXm1UwHHz_J19H8ZV-X74,338
+titans_pytorch/mac_transformer.py,sha256=tz72141G5t3AOnxSVsOLtLptGtl8T7zROUvaTw2_XCY,26960
+titans_pytorch/memory_models.py,sha256=wnH9i9kUSoVZhEWUlj8LpBSbB400L9kLt1zP8CO45QQ,5835
+titans_pytorch/neural_memory.py,sha256=EhHptv-9q3PUTJwX9kKAdYMfWueM-JB_kZ3SmRoAdjM,33356
+titans_pytorch-0.4.7.dist-info/METADATA,sha256=MP0qHzoAM0AZuWg0gL2VOnmpx9HXdHwo5xx2CL0ugso,6797
+titans_pytorch-0.4.7.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.4.7.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.4.7.dist-info/RECORD,,
titans_pytorch/associative_scan.py
DELETED
@@ -1,178 +0,0 @@
-from __future__ import annotations
-import math
-from typing import Callable
-
-import torch
-from torch import Tensor
-from torch.nn import Module
-import torch.nn.functional as F
-
-from einops import rearrange, repeat, reduce, pack, unpack
-
-# taken from S5-pytorch repository
-# https://github.com/i404788/s5-pytorch/blob/74e2fdae00b915a62c914bf3615c0b8a4279eb84/s5/jax_compat.py#L51-L134
-
-# helper functions
-
-def exists(v):
-    return v is not None
-
-def default(*args):
-    for arg in args:
-        if exists(arg):
-            return arg
-    return None
-
-def pad_at_dim(t, pad, dim = -1, value = 0.):
-    dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
-    zeros = ((0, 0) * dims_from_right)
-    return F.pad(t, (*zeros, *pad), value = value)
-
-def pack_one_with_inverse(t, pattern):
-    packed, packed_shape = pack([t], pattern)
-
-    def inverse(out, inv_pattern = None):
-        inv_pattern = default(inv_pattern, pattern)
-        return unpack(out, packed_shape, inv_pattern)[0]
-
-    return packed, inverse
-
-# the operator that is needed
-
-@torch.jit.script
-def binary_operator(
-    a: tuple[Tensor, Tensor],
-    b: tuple[Tensor, Tensor]
-):
-    a_i, kv_i = a
-    a_j, kv_j = b
-    return a_j * a_i, torch.addcmul(kv_j, a_j, kv_i)
-
-# Pytorch impl. of jax.lax.associative_scan
-# made specifically for axis of 1 (sequence of tokens for autoregressive modeling)
-
-def associative_scan(
-    operator: Callable,
-    elems: tuple[Tensor, Tensor]
-):
-    num_elems = int(elems[0].shape[1])
-
-    if not all(int(elem.shape[1]) == num_elems for elem in elems[1:]):
-        raise ValueError('Array inputs to associative_scan must have the same '
-                         'first dimension. (saw: {})'
-                         .format([elem.shape for elem in elems]))
-
-    def _scan(elems):
-        """Perform scan on `elems`."""
-        num_elems = elems[0].shape[1]
-
-        if num_elems < 2:
-            return elems
-
-        # Combine adjacent pairs of elements.
-
-        reduced_elems = operator(
-            [elem[:, :-1:2] for elem in elems],
-            [elem[:, 1::2] for elem in elems])
-
-        # Recursively compute scan for partially reduced tensors.
-
-        odd_elems = _scan(reduced_elems)
-
-        if num_elems % 2 == 0:
-            even_elems = operator(
-                [e[:, :-1] for e in odd_elems],
-                [e[:, 2::2] for e in elems])
-        else:
-            even_elems = operator(
-                odd_elems,
-                [e[:, 2::2] for e in elems])
-
-        # The first element of a scan is the same as the first element
-        # of the original `elems`.
-
-        even_elems = [
-            torch.cat([elem[:, :1], result], dim=1)
-            for (elem, result) in zip(elems, even_elems)]
-
-        return list(map(_interleave, even_elems, odd_elems))
-
-    return _scan(elems)
-
-def _interleave(a, b):
-    a_axis_len, b_axis_len = a.shape[1], b.shape[1]
-    output_axis_len = a_axis_len + b_axis_len
-
-    if (a_axis_len == (b_axis_len + 1)):
-        b = pad_at_dim(b, (0, 1), dim = 1)
-
-    stacked = torch.stack([a, b], dim=2)
-    interleaved = torch.flatten(stacked, start_dim=1, end_dim=2)
-
-    return interleaved[:, :output_axis_len]
-
-# associative scan wrapper around naive and accelerated version
-
-class AssocScan(Module):
-    def __init__(
-        self,
-        use_accelerated = False
-    ):
-        super().__init__()
-        self.use_accelerated = use_accelerated
-
-    def forward(
-        self,
-        gates,
-        inputs,
-        prev = None,
-        remove_prev = None
-    ):
-        remove_prev = default(remove_prev, exists(prev))
-
-        inputs, inverse_pack_weight_shape = pack_one_with_inverse(inputs, 'b n *')
-        gates, _ = pack_one_with_inverse(gates, 'b n *')
-
-        if exists(prev):
-            prev, _ = pack_one_with_inverse(prev, 'b *')
-
-        if exists(prev):
-            inputs, _ = pack([prev, inputs], 'b * d')
-            gates = pad_at_dim(gates, (1, 0), value = 1., dim = -2)
-
-        if not self.use_accelerated:
-            _, out = associative_scan(binary_operator, (gates, inputs))
-
-            if remove_prev:
-                out = out[:, 1:]
-
-            return inverse_pack_weight_shape(out)
-
-        from accelerated_scan.triton import scan as triton_scan
-        from accelerated_scan.warp import scan as warp_scan
-
-        scan = triton_scan if gates.is_cuda else warp_scan
-
-        def accelerate_scan_fn(gates, inputs):
-            gates = gates.expand_as(inputs)
-            gates, inputs = tuple(rearrange(t, 'b n d -> b d n') for t in (gates, inputs))
-
-            seq_len = gates.shape[-1]
-            next_power_two_seq_len = 2 ** max(5, int(math.ceil(math.log2(seq_len))))
-
-            gates = F.pad(gates, (0, next_power_two_seq_len - seq_len))
-            inputs = F.pad(inputs, (0, next_power_two_seq_len - seq_len))
-
-            outputs = scan(gates.contiguous(), inputs.contiguous())
-
-            outputs = outputs[..., :seq_len]
-            outputs = rearrange(outputs, 'b d n -> b n d')
-
-            return outputs
-
-        out = accelerate_scan_fn(gates, inputs)
-
-        if remove_prev:
-            out = out[:, 1:]
-
-        return inverse_pack_weight_shape(out)
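For reference, the scan the deleted module implemented: binary_operator combines (a_i, kv_i) and (a_j, kv_j) into (a_j * a_i, kv_j + a_j * kv_i), so an inclusive scan over the sequence axis computes the gated recurrence h_t = a_t * h_{t-1} + kv_t (with h_0 = 0) in logarithmic depth. The naive loop below is shown only to make those semantics explicit; the assoc-scan dependency now supplies the actual implementation.

# Reference-only sketch of the recurrence the removed scan computed.
import torch

def naive_gated_scan(gates, inputs):
    # gates, inputs: (batch, seq, dim); h_t = gates_t * h_{t-1} + inputs_t
    h = torch.zeros_like(inputs[:, 0])
    outs = []
    for t in range(inputs.shape[1]):
        h = gates[:, t] * h + inputs[:, t]
        outs.append(h)
    return torch.stack(outs, dim = 1)

gates  = torch.rand(2, 8, 4)
inputs = torch.randn(2, 8, 4)
reference = naive_gated_scan(gates, inputs)  # same result the scan produces for these gates/inputs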
titans_pytorch-0.4.6.dist-info/RECORD
DELETED
@@ -1,9 +0,0 @@
-titans_pytorch/__init__.py,sha256=sVTOuRUkaIYabFExdLY6s1qXm1UwHHz_J19H8ZV-X74,338
-titans_pytorch/associative_scan.py,sha256=esaLbukFlgvy2aqopsqBy6KEcZ64B3rsNhG8moKdPSc,5159
-titans_pytorch/mac_transformer.py,sha256=tz72141G5t3AOnxSVsOLtLptGtl8T7zROUvaTw2_XCY,26960
-titans_pytorch/memory_models.py,sha256=wnH9i9kUSoVZhEWUlj8LpBSbB400L9kLt1zP8CO45QQ,5835
-titans_pytorch/neural_memory.py,sha256=CsB-Wd3T_N_HluRLjjwB2gPdlLymJmmhuxiRJhQuXkA,33377
-titans_pytorch-0.4.6.dist-info/METADATA,sha256=aWhQMQrjBLzUPmvtH-LY47r4ayz_ts70gGyDMMyJ6Sc,6810
-titans_pytorch-0.4.6.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.4.6.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.4.6.dist-info/RECORD,,
{titans_pytorch-0.4.6.dist-info → titans_pytorch-0.4.7.dist-info}/WHEEL
File without changes

{titans_pytorch-0.4.6.dist-info → titans_pytorch-0.4.7.dist-info}/licenses/LICENSE
File without changes