titans-pytorch 0.0.25__py3-none-any.whl → 0.0.27__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- titans_pytorch/mac_transformer.py +191 -0
- {titans_pytorch-0.0.25.dist-info → titans_pytorch-0.0.27.dist-info}/METADATA +2 -1
- {titans_pytorch-0.0.25.dist-info → titans_pytorch-0.0.27.dist-info}/RECORD +5 -4
- {titans_pytorch-0.0.25.dist-info → titans_pytorch-0.0.27.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.0.25.dist-info → titans_pytorch-0.0.27.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,191 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
import math
|
3
|
+
from functools import partial
|
4
|
+
|
5
|
+
import torch
|
6
|
+
from torch import nn, cat
|
7
|
+
import torch.nn.functional as F
|
8
|
+
from torch.nn import Module, ModuleList, Linear
|
9
|
+
|
10
|
+
from einops import repeat
|
11
|
+
from einops.layers.torch import Rearrange
|
12
|
+
|
13
|
+
from hyper_connections import get_init_and_expand_reduce_stream_functions
|
14
|
+
|
15
|
+
# constants
|
16
|
+
|
17
|
+
# `Linear` pre-bound with `bias = False`; used below for the q/k/v, output and logit projections
LinearNoBias = partial(Linear, bias = False)
|
18
|
+
|
19
|
+
# helpers
|
20
|
+
|
21
|
+
def exists(v):
    """Return True when `v` is anything other than None (falsy values such as 0 or [] count as existing)."""
    is_missing = v is None
    return not is_missing
|
23
|
+
|
24
|
+
def default(v, d):
    """Return `v`, falling back to `d` only when `v` is None (falsy-but-present values are kept)."""
    if v is None:
        return d
    return v
|
26
|
+
|
27
|
+
def round_up_multiple(seq, mult):
    """Round `seq` up to the nearest multiple of `mult`.

    Uses exact integer ceil-division instead of `math.ceil(seq / mult)`: true
    division goes through float, which silently loses precision (and so rounds
    to the wrong multiple) once `seq` exceeds 2**53.

    Args:
        seq: length to round up (a Python int sequence length in this file).
        mult: positive multiple to round up to.

    Returns:
        The smallest multiple of `mult` that is >= `seq`.

    Raises:
        ZeroDivisionError: if `mult` is 0.
    """
    # -(-a // b) is ceil(a / b) computed purely in integer arithmetic
    return -(-seq // mult) * mult
|
29
|
+
|
30
|
+
# feedforward and attention
|
31
|
+
|
32
|
+
class GEGLU(Module):
    """Gated linear unit over the last dimension.

    The input's last dimension is split into two equal halves, (value, gate),
    and the result is `silu(gate) * value`, so the output's last dimension is
    half the input's. Note: despite the GEGLU name, the gate activation used
    here is SiLU (SwiGLU-style), not GELU.
    """

    def forward(self, x):
        value, gate = x.chunk(2, dim = -1)
        gated = F.silu(gate)
        return gated * value
|
36
|
+
|
37
|
+
def FeedForward(dim, mult = 4):
    """Build a pre-norm gated MLP block.

    Layout: RMSNorm(dim) -> Linear(dim -> 2 * hidden) -> GEGLU -> Linear(hidden -> dim),
    where hidden = int(dim * mult * 2 / 3). The 2/3 factor presumably keeps the
    parameter count near that of an ungated MLP of width dim * mult.

    Args:
        dim: input/output feature dimension.
        mult: width multiplier for the hidden layer.

    Returns:
        An `nn.Sequential` mapping (..., dim) -> (..., dim).
    """
    hidden = int(dim * mult * 2 / 3)

    layers = [
        nn.RMSNorm(dim),
        nn.Linear(dim, hidden * 2),   # doubled: GEGLU splits it into value and gate halves
        GEGLU(),
        nn.Linear(hidden, dim),
    ]

    return nn.Sequential(*layers)
|
46
|
+
|
47
|
+
class SegmentedAttention(Module):
    """Causal multi-head attention computed independently within fixed-length segments.

    The input (batch, seq_len, dim) is right-padded with zeros to a multiple of
    `segment_len` (when seq_len >= segment_len) and folded into windows of
    exactly `segment_len` tokens. Each window attends causally to itself plus a
    learned set of persistent-memory key/value tokens that every query may see.
    Padding is stripped from the output, which has the same shape as the input.

    Args:
        dim: model (input/output) feature dimension.
        segment_len: number of tokens per attention window.
        num_persist_mem_tokens: number of learned key/value tokens prepended to
            every window's keys/values.
        dim_head: per-head dimension.
        heads: number of attention heads.
    """

    def __init__(
        self,
        dim,
        segment_len,
        num_persist_mem_tokens,
        dim_head = 64,
        heads = 8,
    ):
        super().__init__()
        self.norm = nn.RMSNorm(dim)

        dim_inner = dim_head * heads

        self.to_qkv = LinearNoBias(dim, dim_inner * 3)
        self.to_out = LinearNoBias(dim_inner, dim)

        self.segment_len = segment_len

        self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
        self.merge_heads = Rearrange('b h n d -> b n (h d)')

        # `w` = number of windows, `n` = tokens per window (= segment_len).
        # Binding segment_len to the inner factor is what makes every window
        # exactly segment_len tokens long; binding it to the outer factor would
        # instead yield segment_len windows of seq_len / segment_len tokens.
        self.segment_seq = Rearrange('b (w n) d -> (b w) n d', n = segment_len)

        # persistent_memory[0] holds the keys, persistent_memory[1] the values
        self.persistent_memory = nn.Parameter(torch.zeros(2, heads, num_persist_mem_tokens, dim_head))

    def forward(self, seq):
        batch, seq_len = seq.shape[:2]

        # auto pad to multiple
        # todo - get rid of logic with flex attention

        need_segment = seq_len >= self.segment_len

        if need_segment:
            next_seq_len = round_up_multiple(seq_len, self.segment_len)
            padding = next_seq_len - seq_len

            if padding > 0:
                # zero-pad the sequence axis at the end; causal masking keeps the
                # padded positions from influencing any real position
                seq = F.pad(seq, (0, 0, 0, padding))

            seq = self.segment_seq(seq)

        # attention

        seq = self.norm(seq)

        q, k, v = self.to_qkv(seq).chunk(3, dim = -1)
        q, k, v = map(self.split_heads, (q, k, v))

        # take care of persistent memory key / values

        pmk, pmv = tuple(repeat(t, 'h n d -> b h n d', b = seq.shape[0]) for t in self.persistent_memory)

        k = cat((pmk, k), dim = -2)
        v = cat((pmv, v), dim = -2)

        # causal mask aligned to the bottom-right: with P persistent keys in front,
        # query i may see all P persistent keys plus sequence keys 0..i.
        # (`is_causal = True` aligns the triangle to the top-left when keys are
        # longer than queries, which would misalign it by P positions.)

        q_len, k_len = q.shape[-2], k.shape[-2]
        causal_mask = torch.ones((q_len, k_len), dtype = torch.bool, device = q.device).tril(k_len - q_len)

        # sdpa

        out = F.scaled_dot_product_attention(q, k, v, attn_mask = causal_mask)

        out = self.merge_heads(out)

        out = self.to_out(out)

        if need_segment:
            # (b w) n d -> b (w n) d ; batch is the outer factor, so a reshape suffices
            out = out.reshape(batch, -1, out.shape[-1])

        return out[:, :seq_len]
|
117
|
+
|
118
|
+
# MAC transformer
|
119
|
+
|
120
|
+
class MemoryAsContextTransformer(Module):
    """Token-level language model built from segmented-attention and feedforward blocks.

    Pipeline: token embedding -> expand residual streams -> `depth` x
    (attention block, feedforward block), each branch wrapped by the
    hyper-connection returned from `get_init_and_expand_reduce_stream_functions`
    -> reduce residual streams -> RMSNorm -> projection to `num_tokens` logits.
    In this version the only memory involved is the persistent key/value tokens
    inside `SegmentedAttention`.

    Args:
        num_tokens: vocabulary size (embedding rows and logit width).
        dim: model width.
        depth: number of (attention, feedforward) layer pairs.
        segment_len: attention window length passed to `SegmentedAttention`.
        num_persist_mem_tokens: persistent key/value tokens per attention layer.
        dim_head: per-head attention dimension.
        heads: number of attention heads.
        ff_mult: feedforward width multiplier.
        num_residual_streams: number of hyper-connection residual streams; a
            value of 1 passes `disable = True` to the hyper-connection factory.
    """

    def __init__(
        self,
        *,
        num_tokens,
        dim,
        depth,
        segment_len,
        num_persist_mem_tokens,
        dim_head = 64,
        heads = 8,
        ff_mult = 4,
        num_residual_streams = 4
    ):
        super().__init__()

        self.token_emb = nn.Embedding(num_tokens, dim)

        single_stream = num_residual_streams == 1
        init_hyper_conn, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(num_residual_streams, disable = single_stream)

        blocks = []

        for _ in range(depth):
            attn_branch = SegmentedAttention(
                dim = dim,
                dim_head = dim_head,
                heads = heads,
                segment_len = segment_len,
                num_persist_mem_tokens = num_persist_mem_tokens
            )

            ff_branch = FeedForward(dim = dim, mult = ff_mult)

            blocks.append(ModuleList([
                init_hyper_conn(dim = dim, branch = attn_branch),
                init_hyper_conn(dim = dim, branch = ff_branch),
            ]))

        # assigned before `norm` / `to_logits` so submodule registration order is unchanged
        self.layers = ModuleList(blocks)

        self.norm = nn.RMSNorm(dim)

        self.to_logits = LinearNoBias(dim, num_tokens)

    def forward(self, x):
        # x: integer token ids, presumably (batch, seq_len) — TODO confirm for other callers
        hidden = self.expand_streams(self.token_emb(x))

        for attn, ff in self.layers:
            hidden = ff(attn(hidden))

        hidden = self.norm(self.reduce_streams(hidden))

        return self.to_logits(hidden)
|
177
|
+
|
178
|
+
# main
|
179
|
+
|
180
|
+
if __name__ == '__main__':
    # smoke test: a small model run on a 1023-token sequence, which is not a
    # multiple of segment_len (128), so the padding path in SegmentedAttention runs
    num_tokens = 256
    seq_len = 1023

    transformer = MemoryAsContextTransformer(
        num_tokens = num_tokens,
        dim = 256,
        depth = 2,
        num_persist_mem_tokens = 16,
        segment_len = 128,
    )

    x = torch.randint(0, num_tokens, (1, seq_len))

    logits = transformer(x)

    # padding must be stripped again and every position must get a full vocab row
    assert logits.shape == (1, seq_len, num_tokens), f'unexpected logits shape {tuple(logits.shape)}'
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: titans-pytorch
|
3
|
-
Version: 0.0.25
|
3
|
+
Version: 0.0.27
|
4
4
|
Summary: Titans
|
5
5
|
Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
|
6
6
|
Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
|
@@ -37,6 +37,7 @@ Requires-Python: >=3.9
|
|
37
37
|
Requires-Dist: accelerated-scan>=0.2.0
|
38
38
|
Requires-Dist: einops>=0.8.0
|
39
39
|
Requires-Dist: einx>=0.3.0
|
40
|
+
Requires-Dist: hyper-connections>=0.1.8
|
40
41
|
Requires-Dist: ninja
|
41
42
|
Requires-Dist: tensordict
|
42
43
|
Requires-Dist: torch>=2.2
|
@@ -1,8 +1,9 @@
|
|
1
1
|
titans_pytorch/__init__.py,sha256=nB0873FZ_OyCda3qFeWTdpO4LbbrXDEvtAefVLzh6g0,71
|
2
2
|
titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
|
3
|
+
titans_pytorch/mac_transformer.py,sha256=4xUSfGDdVsR-WmeXX7yXoFfybROvNCjOxL_EHDJ_Wlk,4681
|
3
4
|
titans_pytorch/titans.py,sha256=Kx_tl_QkeDccvkMwPZ0xQ_saYjZfbKzDNPTTSHNWYcA,14304
|
4
5
|
titans_pytorch/titans_attn_memory.py,sha256=Rwx_-riGeISBefZg5Kjic8jzmmRRys-u93D2Kgb7Mos,12691
|
5
|
-
titans_pytorch-0.0.
|
6
|
-
titans_pytorch-0.0.
|
7
|
-
titans_pytorch-0.0.
|
8
|
-
titans_pytorch-0.0.
|
6
|
+
titans_pytorch-0.0.27.dist-info/METADATA,sha256=qJp1IDbphEUfW7EyNvQ7RfmHuvB7SH5h_tlnCVwV4EY,3851
|
7
|
+
titans_pytorch-0.0.27.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
8
|
+
titans_pytorch-0.0.27.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
9
|
+
titans_pytorch-0.0.27.dist-info/RECORD,,
|
File without changes
|
File without changes
|