titans-pytorch 0.0.25__py3-none-any.whl → 0.0.27__py3-none-any.whl

This diff shows the changes between publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.

titans_pytorch/mac_transformer.py (new file)
@@ -0,0 +1,191 @@
+ from __future__ import annotations
+ import math
+ from functools import partial
+
+ import torch
+ from torch import nn, cat
+ import torch.nn.functional as F
+ from torch.nn import Module, ModuleList, Linear
+
+ from einops import repeat
+ from einops.layers.torch import Rearrange
+
+ from hyper_connections import get_init_and_expand_reduce_stream_functions
+
+ # constants
+
+ LinearNoBias = partial(Linear, bias = False)
+
+ # helpers
+
+ def exists(v):
+     return v is not None
+
+ def default(v, d):
+     return v if exists(v) else d
+
+ def round_up_multiple(seq, mult):
+     return math.ceil(seq / mult) * mult
+
+ # feedforward and attention
+
+ class GEGLU(Module):
+     def forward(self, x):
+         x, gate = x.chunk(2, dim = -1)
+         return F.silu(gate) * x
+
+ def FeedForward(dim, mult = 4):
+     dim_inner = int(dim * mult * 2 / 3)
+
+     return nn.Sequential(
+         nn.RMSNorm(dim),
+         nn.Linear(dim, dim_inner * 2),
+         GEGLU(),
+         nn.Linear(dim_inner, dim)
+     )
+
+ class SegmentedAttention(Module):
+     def __init__(
+         self,
+         dim,
+         segment_len,
+         num_persist_mem_tokens,
+         dim_head = 64,
+         heads = 8,
+     ):
+         super().__init__()
+         self.norm = nn.RMSNorm(dim)
+
+         dim_inner = dim_head * heads
+
+         self.to_qkv = LinearNoBias(dim, dim_inner * 3)
+         self.to_out = LinearNoBias(dim_inner, dim)
+
+         self.segment_len = segment_len
+
+         self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
+         self.merge_heads = Rearrange('b h n d -> b n (h d)')
+
+         self.segment_seq = Rearrange('b (n w) d -> (b n) w d', n = segment_len)
+         self.merge_seq_back = Rearrange('(b n) w d -> b (n w) d', n = segment_len)
+
+         self.persistent_memory = nn.Parameter(torch.zeros(2, heads, num_persist_mem_tokens, dim_head))
+
+     def forward(self, seq):
+         batch, seq_len = seq.shape[:2]
+
+         # auto pad to multiple
+         # todo - get rid of logic with flex attention
+
+         need_segment = seq_len >= self.segment_len
+
+         if need_segment:
+             next_seq_len = round_up_multiple(seq_len, self.segment_len)
+             padding = next_seq_len - seq_len
+
+             if padding > 0:
+                 seq = F.pad(seq, (0, 0, 0, padding))
+
+             seq = self.segment_seq(seq)
+
+         # attention
+
+         seq = self.norm(seq)
+
+         q, k, v = self.to_qkv(seq).chunk(3, dim = -1)
+         q, k, v = map(self.split_heads, (q, k, v))
+
+         # take care of persistent memory key / values
+
+         pmk, pmv = tuple(repeat(t, 'h n d -> b h n d', b = seq.shape[0]) for t in self.persistent_memory)
+
+         k = cat((pmk, k), dim = -2)
+         v = cat((pmv, v), dim = -2)
+
+         # sdpa
+
+         out = F.scaled_dot_product_attention(q, k, v, is_causal = True)
+
+         out = self.merge_heads(out)
+
+         out = self.to_out(out)
+
+         if need_segment:
+             out = self.merge_seq_back(out)
+
+         return out[:, :seq_len]
+
+ # MAC transformer
+
+ class MemoryAsContextTransformer(Module):
+     def __init__(
+         self,
+         *,
+         num_tokens,
+         dim,
+         depth,
+         segment_len,
+         num_persist_mem_tokens,
+         dim_head = 64,
+         heads = 8,
+         ff_mult = 4,
+         num_residual_streams = 4
+     ):
+         super().__init__()
+
+         self.token_emb = nn.Embedding(num_tokens, dim)
+
+         init_hyper_conn, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(num_residual_streams, disable = num_residual_streams == 1)
+
+         self.layers = ModuleList([])
+
+         for _ in range(depth):
+             attn = SegmentedAttention(
+                 dim = dim,
+                 dim_head = dim_head,
+                 heads = heads,
+                 segment_len = segment_len,
+                 num_persist_mem_tokens = num_persist_mem_tokens
+             )
+
+             ff = FeedForward(dim = dim, mult = ff_mult)
+
+             self.layers.append(ModuleList([
+                 init_hyper_conn(dim = dim, branch = attn),
+                 init_hyper_conn(dim = dim, branch = ff)
+             ]))
+
+         self.norm = nn.RMSNorm(dim)
+
+         self.to_logits = LinearNoBias(dim, num_tokens)
+
+     def forward(self, x):
+
+         x = self.token_emb(x)
+
+         x = self.expand_streams(x)
+
+         for attn, ff in self.layers:
+             x = attn(x)
+             x = ff(x)
+
+         x = self.reduce_streams(x)
+
+         x = self.norm(x)
+
+         return self.to_logits(x)
+
+ # main
+
+ if __name__ == '__main__':
+     transformer = MemoryAsContextTransformer(
+         num_tokens = 256,
+         dim = 256,
+         depth = 2,
+         num_persist_mem_tokens = 16,
+         segment_len = 128,
+     )
+
+     x = torch.randint(0, 256, (1, 1023))
+
+     logits = transformer(x)
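
The module added above is importable as titans_pytorch.mac_transformer in 0.0.27; the package __init__.py is unchanged in this diff, so the new class is presumably not re-exported at the package root. Below is a minimal usage sketch, not shipped in the wheel: the module path and constructor arguments come from the file above, while the next-token cross-entropy training step is an assumed setup for illustration only.

import torch
import torch.nn.functional as F

from titans_pytorch.mac_transformer import MemoryAsContextTransformer

model = MemoryAsContextTransformer(
    num_tokens = 256,
    dim = 256,
    depth = 2,
    segment_len = 128,
    num_persist_mem_tokens = 16
)

seq = torch.randint(0, 256, (1, 1024))    # random token ids, just for the sketch

# shift by one position for next-token prediction (assumed training objective)
inp, target = seq[:, :-1], seq[:, 1:]

logits = model(inp)    # (batch, seq_len, num_tokens) -> (1, 1023, 256)

loss = F.cross_entropy(logits.transpose(1, 2), target)
loss.backward()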

titans_pytorch-0.0.25.dist-info/METADATA → titans_pytorch-0.0.27.dist-info/METADATA
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: titans-pytorch
- Version: 0.0.25
+ Version: 0.0.27
  Summary: Titans
  Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
  Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -37,6 +37,7 @@ Requires-Python: >=3.9
  Requires-Dist: accelerated-scan>=0.2.0
  Requires-Dist: einops>=0.8.0
  Requires-Dist: einx>=0.3.0
+ Requires-Dist: hyper-connections>=0.1.8
  Requires-Dist: ninja
  Requires-Dist: tensordict
  Requires-Dist: torch>=2.2

titans_pytorch-0.0.25.dist-info/RECORD → titans_pytorch-0.0.27.dist-info/RECORD
@@ -1,8 +1,9 @@
  titans_pytorch/__init__.py,sha256=nB0873FZ_OyCda3qFeWTdpO4LbbrXDEvtAefVLzh6g0,71
  titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+ titans_pytorch/mac_transformer.py,sha256=4xUSfGDdVsR-WmeXX7yXoFfybROvNCjOxL_EHDJ_Wlk,4681
  titans_pytorch/titans.py,sha256=Kx_tl_QkeDccvkMwPZ0xQ_saYjZfbKzDNPTTSHNWYcA,14304
  titans_pytorch/titans_attn_memory.py,sha256=Rwx_-riGeISBefZg5Kjic8jzmmRRys-u93D2Kgb7Mos,12691
- titans_pytorch-0.0.25.dist-info/METADATA,sha256=NeRAldZl9fN7bc3YzL44kDzH2rsq5SMBZ-7RjGs_B0g,3811
- titans_pytorch-0.0.25.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
- titans_pytorch-0.0.25.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
- titans_pytorch-0.0.25.dist-info/RECORD,,
+ titans_pytorch-0.0.27.dist-info/METADATA,sha256=qJp1IDbphEUfW7EyNvQ7RfmHuvB7SH5h_tlnCVwV4EY,3851
+ titans_pytorch-0.0.27.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+ titans_pytorch-0.0.27.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+ titans_pytorch-0.0.27.dist-info/RECORD,,