titans-pytorch 0.0.25__py3-none-any.whl → 0.0.27__py3-none-any.whl

@@ -0,0 +1,191 @@
+ from __future__ import annotations
+ import math
+ from functools import partial
+
+ import torch
+ from torch import nn, cat
+ import torch.nn.functional as F
+ from torch.nn import Module, ModuleList, Linear
+
+ from einops import repeat
+ from einops.layers.torch import Rearrange
+
+ from hyper_connections import get_init_and_expand_reduce_stream_functions
+
+ # constants
+
+ LinearNoBias = partial(Linear, bias = False)
+
+ # helpers
+
+ def exists(v):
+     return v is not None
+
+ def default(v, d):
+     return v if exists(v) else d
+
+ def round_up_multiple(seq, mult):
+     return math.ceil(seq / mult) * mult
+
+ # feedforward and attention
+
+ class GEGLU(Module):
+     def forward(self, x):
+         x, gate = x.chunk(2, dim = -1)
+         return F.silu(gate) * x
+
+ def FeedForward(dim, mult = 4):
+     dim_inner = int(dim * mult * 2 / 3)
+
+     return nn.Sequential(
+         nn.RMSNorm(dim),
+         nn.Linear(dim, dim_inner * 2),
+         GEGLU(),
+         nn.Linear(dim_inner, dim)
+     )
+
+ class SegmentedAttention(Module):
+     def __init__(
+         self,
+         dim,
+         segment_len,
+         num_persist_mem_tokens,
+         dim_head = 64,
+         heads = 8,
+     ):
+         super().__init__()
+         self.norm = nn.RMSNorm(dim)
+
+         dim_inner = dim_head * heads
+
+         self.to_qkv = LinearNoBias(dim, dim_inner * 3)
+         self.to_out = LinearNoBias(dim_inner, dim)
+
+         self.segment_len = segment_len
+
+         self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
+         self.merge_heads = Rearrange('b h n d -> b n (h d)')
+
+         self.segment_seq = Rearrange('b (n w) d -> (b n) w d', w = segment_len)
+         self.merge_seq_back = Rearrange('(b n) w d -> b (n w) d', w = segment_len)
+
+         self.persistent_memory = nn.Parameter(torch.zeros(2, heads, num_persist_mem_tokens, dim_head))
+
+     def forward(self, seq):
+         batch, seq_len = seq.shape[:2]
+
+         # auto pad to multiple
+         # todo - get rid of logic with flex attention
+
+         need_segment = seq_len >= self.segment_len
+
+         if need_segment:
+             next_seq_len = round_up_multiple(seq_len, self.segment_len)
+             padding = next_seq_len - seq_len
+
+             if padding > 0:
+                 seq = F.pad(seq, (0, 0, 0, padding))
+
+             seq = self.segment_seq(seq)
+
+         # attention
+
+         seq = self.norm(seq)
+
+         q, k, v = self.to_qkv(seq).chunk(3, dim = -1)
+         q, k, v = map(self.split_heads, (q, k, v))
+
+         # take care of persistent memory key / values
+
+         pmk, pmv = tuple(repeat(t, 'h n d -> b h n d', b = seq.shape[0]) for t in self.persistent_memory)
+
+         k = cat((pmk, k), dim = -2)
+         v = cat((pmv, v), dim = -2)
+
+         # sdpa
+
+         out = F.scaled_dot_product_attention(q, k, v, is_causal = True)
+
+         out = self.merge_heads(out)
+
+         out = self.to_out(out)
+
+         if need_segment:
+             out = self.merge_seq_back(out)
+
+         return out[:, :seq_len]
+
+ # MAC transformer
+
+ class MemoryAsContextTransformer(Module):
+     def __init__(
+         self,
+         *,
+         num_tokens,
+         dim,
+         depth,
+         segment_len,
+         num_persist_mem_tokens,
+         dim_head = 64,
+         heads = 8,
+         ff_mult = 4,
+         num_residual_streams = 4
+     ):
+         super().__init__()
+
+         self.token_emb = nn.Embedding(num_tokens, dim)
+
+         init_hyper_conn, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(num_residual_streams, disable = num_residual_streams == 1)
+
+         self.layers = ModuleList([])
+
+         for _ in range(depth):
+             attn = SegmentedAttention(
+                 dim = dim,
+                 dim_head = dim_head,
+                 heads = heads,
+                 segment_len = segment_len,
+                 num_persist_mem_tokens = num_persist_mem_tokens
+             )
+
+             ff = FeedForward(dim = dim, mult = ff_mult)
+
+             self.layers.append(ModuleList([
+                 init_hyper_conn(dim = dim, branch = attn),
+                 init_hyper_conn(dim = dim, branch = ff)
+             ]))
+
+         self.norm = nn.RMSNorm(dim)
+
+         self.to_logits = LinearNoBias(dim, num_tokens)
+
+     def forward(self, x):
+
+         x = self.token_emb(x)
+
+         x = self.expand_streams(x)
+
+         for attn, ff in self.layers:
+             x = attn(x)
+             x = ff(x)
+
+         x = self.reduce_streams(x)
+
+         x = self.norm(x)
+
+         return self.to_logits(x)
+
+ # main
+
+ if __name__ == '__main__':
+     transformer = MemoryAsContextTransformer(
+         num_tokens = 256,
+         dim = 256,
+         depth = 2,
+         num_persist_mem_tokens = 16,
+         segment_len = 128,
+     )
+
+     x = torch.randint(0, 256, (1, 1023))
+
+     logits = transformer(x)
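
For orientation, here is a minimal next-token training sketch against the MemoryAsContextTransformer added above. The cross-entropy wiring is an illustrative assumption and not part of the release; the import path follows the mac_transformer.py entry in the RECORD further down.

# minimal sketch (not from the package): next-token cross entropy on the new module
import torch
import torch.nn.functional as F

from titans_pytorch.mac_transformer import MemoryAsContextTransformer

model = MemoryAsContextTransformer(
    num_tokens = 256,
    dim = 256,
    depth = 2,
    num_persist_mem_tokens = 16,
    segment_len = 128
)

seq = torch.randint(0, 256, (1, 1024))

logits = model(seq[:, :-1])            # (batch, 1023, num_tokens)

loss = F.cross_entropy(
    logits.transpose(1, 2),            # cross_entropy expects (batch, classes, seq)
    seq[:, 1:]                         # targets are the inputs shifted by one
)
loss.backward()
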
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: titans-pytorch
- Version: 0.0.25
+ Version: 0.0.27
  Summary: Titans
  Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
  Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -37,6 +37,7 @@ Requires-Python: >=3.9
  Requires-Dist: accelerated-scan>=0.2.0
  Requires-Dist: einops>=0.8.0
  Requires-Dist: einx>=0.3.0
+ Requires-Dist: hyper-connections>=0.1.8
  Requires-Dist: ninja
  Requires-Dist: tensordict
  Requires-Dist: torch>=2.2
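
The new hyper-connections requirement backs the get_init_and_expand_reduce_stream_functions import in mac_transformer.py above. A small sketch of that pattern, assuming the API behaves as it is used there; the toy shapes and the Linear branch are illustrative.

import torch
from torch import nn
from hyper_connections import get_init_and_expand_reduce_stream_functions

# four residual streams, matching MemoryAsContextTransformer's default
init_hyper_conn, expand_streams, reduce_streams = get_init_and_expand_reduce_stream_functions(4)

block = init_hyper_conn(dim = 64, branch = nn.Linear(64, 64))

x = torch.randn(2, 16, 64)
x = expand_streams(x)     # replicate the input across the residual streams
x = block(x)              # run the branch, mixing its output back into the streams
x = reduce_streams(x)     # back to (2, 16, 64)
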
@@ -1,8 +1,9 @@
  titans_pytorch/__init__.py,sha256=nB0873FZ_OyCda3qFeWTdpO4LbbrXDEvtAefVLzh6g0,71
  titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+ titans_pytorch/mac_transformer.py,sha256=4xUSfGDdVsR-WmeXX7yXoFfybROvNCjOxL_EHDJ_Wlk,4681
  titans_pytorch/titans.py,sha256=Kx_tl_QkeDccvkMwPZ0xQ_saYjZfbKzDNPTTSHNWYcA,14304
  titans_pytorch/titans_attn_memory.py,sha256=Rwx_-riGeISBefZg5Kjic8jzmmRRys-u93D2Kgb7Mos,12691
- titans_pytorch-0.0.25.dist-info/METADATA,sha256=NeRAldZl9fN7bc3YzL44kDzH2rsq5SMBZ-7RjGs_B0g,3811
- titans_pytorch-0.0.25.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
- titans_pytorch-0.0.25.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
- titans_pytorch-0.0.25.dist-info/RECORD,,
+ titans_pytorch-0.0.27.dist-info/METADATA,sha256=qJp1IDbphEUfW7EyNvQ7RfmHuvB7SH5h_tlnCVwV4EY,3851
+ titans_pytorch-0.0.27.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+ titans_pytorch-0.0.27.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+ titans_pytorch-0.0.27.dist-info/RECORD,,