titans-pytorch 0.0.25__py3-none-any.whl → 0.0.26__py3-none-any.whl
- titans_pytorch/mac_transformer.py +170 -0
- {titans_pytorch-0.0.25.dist-info → titans_pytorch-0.0.26.dist-info}/METADATA +2 -1
- {titans_pytorch-0.0.25.dist-info → titans_pytorch-0.0.26.dist-info}/RECORD +5 -4
- {titans_pytorch-0.0.25.dist-info → titans_pytorch-0.0.26.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.0.25.dist-info → titans_pytorch-0.0.26.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,170 @@
+from __future__ import annotations
+import math
+from functools import partial
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+from torch.nn import Module, ModuleList, Linear
+
+from einops.layers.torch import Rearrange
+
+from hyper_connections import get_init_and_expand_reduce_stream_functions
+
+# constants
+
+LinearNoBias = partial(Linear, bias = False)
+
+# helpers
+
+def exists(v):
+    return v is not None
+
+def default(v, d):
+    return v if exists(v) else d
+
+def round_up_multiple(seq, mult):
+    return math.ceil(seq / mult) * mult
+
+# feedforward and attention
+
+class GEGLU(Module):
+    def forward(self, x):
+        x, gate = x.chunk(2, dim = -1)
+        return F.silu(gate) * x
+
+def FeedForward(dim, mult = 4):
+    dim_inner = int(dim * mult * 2 / 3)
+
+    return nn.Sequential(
+        nn.RMSNorm(dim),
+        nn.Linear(dim, dim_inner * 2),
+        GEGLU(),
+        nn.Linear(dim_inner, dim)
+    )
+
+class SegmentedAttention(Module):
+    def __init__(
+        self,
+        dim,
+        segment_len,
+        dim_head = 64,
+        heads = 8,
+    ):
+        super().__init__()
+        self.norm = nn.RMSNorm(dim)
+
+        dim_inner = dim_head * heads
+
+        self.to_qkv = LinearNoBias(dim, dim_inner * 3)
+        self.to_out = LinearNoBias(dim_inner, dim)
+
+        self.segment_len = segment_len
+
+        self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
+        self.merge_heads = Rearrange('b h n d -> b n (h d)')
+
+        self.segment_seq = Rearrange('b (n w) d -> (b n) w d', n = segment_len)
+        self.merge_seq_back = Rearrange('(b n) w d -> b (n w) d', n = segment_len)
+
+
+    def forward(self, seq):
+        batch, seq_len = seq.shape[:2]
+
+        # auto pad to multiple
+        # todo - get rid of logic with flex attention
+
+        need_segment = seq_len >= self.segment_len
+
+        if need_segment:
+            next_seq_len = round_up_multiple(seq_len, self.segment_len)
+            padding = next_seq_len - seq_len
+
+            if padding > 0:
+                seq = F.pad(seq, (0, 0, 0, padding))
+
+            seq = self.segment_seq(seq)
+
+        # attention
+
+        seq = self.norm(seq)
+
+        q, k, v = self.to_qkv(seq).chunk(3, dim = -1)
+        q, k, v = map(self.split_heads, (q, k, v))
+
+        out = F.scaled_dot_product_attention(q, k, v, is_causal = True)
+
+        out = self.merge_heads(out)
+
+        out = self.to_out(out)
+
+        if need_segment:
+            out = self.merge_seq_back(out)
+
+        return out[:, :seq_len]
+
+# MAC transformer
+
+class MemoryAsContextTransformer(Module):
+    def __init__(
+        self,
+        *,
+        num_tokens,
+        dim,
+        depth,
+        segment_len,
+        dim_head = 64,
+        heads = 8,
+        ff_mult = 4,
+        num_residual_streams = 4
+    ):
+        super().__init__()
+
+        self.token_emb = nn.Embedding(num_tokens, dim)
+
+        init_hyper_conn, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(num_residual_streams, disable = num_residual_streams == 1)
+
+        self.layers = ModuleList([])
+
+        for _ in range(depth):
+            attn = SegmentedAttention(dim = dim, dim_head = dim_head, heads = heads, segment_len = segment_len)
+            ff = FeedForward(dim = dim, mult = ff_mult)
+
+            self.layers.append(ModuleList([
+                init_hyper_conn(dim = dim, branch = attn),
+                init_hyper_conn(dim = dim, branch = ff)
+            ]))
+
+        self.norm = nn.RMSNorm(dim)
+
+        self.to_logits = LinearNoBias(dim, num_tokens)
+
+    def forward(self, x):
+
+        x = self.token_emb(x)
+
+        x = self.expand_streams(x)
+
+        for attn, ff in self.layers:
+            x = attn(x)
+            x = ff(x)
+
+        x = self.reduce_streams(x)
+
+        x = self.norm(x)
+
+        return self.to_logits(x)
+
+# main
+
+if __name__ == '__main__':
+    transformer = MemoryAsContextTransformer(
+        num_tokens = 256,
+        dim = 256,
+        depth = 2,
+        segment_len = 128,
+    )
+
+    x = torch.randint(0, 256, (1, 1023))
+
+    logits = transformer(x)
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.0.25
+Version: 0.0.26
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -37,6 +37,7 @@ Requires-Python: >=3.9
 Requires-Dist: accelerated-scan>=0.2.0
 Requires-Dist: einops>=0.8.0
 Requires-Dist: einx>=0.3.0
+Requires-Dist: hyper-connections>=0.1.8
 Requires-Dist: ninja
 Requires-Dist: tensordict
 Requires-Dist: torch>=2.2
@@ -1,8 +1,9 @@
 titans_pytorch/__init__.py,sha256=nB0873FZ_OyCda3qFeWTdpO4LbbrXDEvtAefVLzh6g0,71
 titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+titans_pytorch/mac_transformer.py,sha256=63g4IEXbAClN6_FwcZuBPau0ndQ5yRSr0x7TIZDlxxw,4051
 titans_pytorch/titans.py,sha256=Kx_tl_QkeDccvkMwPZ0xQ_saYjZfbKzDNPTTSHNWYcA,14304
 titans_pytorch/titans_attn_memory.py,sha256=Rwx_-riGeISBefZg5Kjic8jzmmRRys-u93D2Kgb7Mos,12691
-titans_pytorch-0.0.
-titans_pytorch-0.0.
-titans_pytorch-0.0.
-titans_pytorch-0.0.
+titans_pytorch-0.0.26.dist-info/METADATA,sha256=Dh6ymVAQ-EkWJBt0Z_eWV_UNBQfainhsus6ShywOBkk,3851
+titans_pytorch-0.0.26.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.0.26.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.0.26.dist-info/RECORD,,
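Each RECORD entry pins a file to a sha256 digest (urlsafe base64 with the trailing padding stripped, per the wheel spec) plus its size in bytes. The sketch below spot-checks the newly added module against an unpacked copy of the wheel; the local path is an assumption about where the archive was extracted.

```python
# Sketch (assumes the wheel has been unpacked into the current directory):
# recompute a RECORD-style hash and size for the new module and compare
# against the entry recorded above.
import base64
import hashlib
from pathlib import Path

def record_entry(path):
    data = Path(path).read_bytes()
    digest = base64.urlsafe_b64encode(hashlib.sha256(data).digest()).rstrip(b'=').decode()
    return f'{path},sha256={digest},{len(data)}'

print(record_entry('titans_pytorch/mac_transformer.py'))
# expected: titans_pytorch/mac_transformer.py,sha256=63g4IEXbAClN6_FwcZuBPau0ndQ5yRSr0x7TIZDlxxw,4051
```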
{titans_pytorch-0.0.25.dist-info → titans_pytorch-0.0.26.dist-info}/WHEEL: file without changes
{titans_pytorch-0.0.25.dist-info → titans_pytorch-0.0.26.dist-info}/licenses/LICENSE: file without changes