titans-pytorch 0.0.25__py3-none-any.whl → 0.0.26__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,170 @@
1
+ from __future__ import annotations
2
+ import math
3
+ from functools import partial
4
+
5
+ import torch
6
+ from torch import nn
7
+ import torch.nn.functional as F
8
+ from torch.nn import Module, ModuleList, Linear
9
+
10
+ from einops.layers.torch import Rearrange
11
+
12
+ from hyper_connections import get_init_and_expand_reduce_stream_functions
13
+
14
+ # constants
15
+
16
# factory for bias-free linear projections: LinearNoBias(in_features, out_features)
LinearNoBias = partial(nn.Linear, bias = False)
17
+
18
+ # helpers
19
+
20
def exists(v):
    """Return True when `v` is anything other than None (falsy values still count)."""
    if v is None:
        return False
    return True

def default(v, d):
    """Return `v` unless it is None, in which case return the fallback `d`."""
    if not exists(v):
        return d
    return v
25
+
26
def round_up_multiple(seq, mult):
    """
    Round `seq` up to the nearest multiple of `mult`.

    Uses exact integer ceiling division instead of `math.ceil(seq / mult)`.
    True division goes through float and rounds incorrectly once the values
    exceed 2**53.

    Args:
        seq: non-negative length to round (an int is expected)
        mult: positive multiple to round up to

    Returns:
        The smallest multiple of `mult` that is >= `seq`.
    """
    # -(-a // b) is ceil(a / b) computed entirely in integer arithmetic
    return -(-seq // mult) * mult
28
+
29
+ # feedforward and attention
30
+
31
class GEGLU(Module):
    """
    Gated linear unit over the last dimension.

    The input's last dimension is split into two equal halves (value, gate).
    The output is value * silu(gate), so the last dimension is halved.

    NOTE(review): despite the name, the gate uses SiLU (SwiGLU-style), not GELU.
    It is kept as-is so model outputs are unchanged. Switch to F.gelu deliberately
    if GELU gating was intended.
    """

    def forward(self, x):
        value, gate = x.chunk(2, dim = -1)
        gated = F.silu(gate)
        return gated * value
35
+
36
def FeedForward(dim, mult = 4):
    """
    Build a pre-norm gated MLP: RMSNorm -> Linear -> GEGLU gate -> Linear.

    Args:
        dim: input/output feature dimension
        mult: hidden-width multiplier. The hidden width is int(dim * mult * 2 / 3).
            The 2/3 factor is the common adjustment for gated MLPs (presumably to
            keep the parameter count near an ungated MLP of width dim * mult).

    Returns:
        An nn.Sequential mapping (..., dim) -> (..., dim).
    """
    hidden = int(dim * mult * 2 / 3)

    layers = [
        nn.RMSNorm(dim),
        nn.Linear(dim, hidden * 2),  # produces value and gate halves for GEGLU
        GEGLU(),
        nn.Linear(hidden, dim),
    ]

    return nn.Sequential(*layers)
45
+
46
class SegmentedAttention(Module):
    """
    Causal multi-head self-attention restricted to fixed-size segments.

    If the input has at least `segment_len` tokens, it is zero-padded at the end
    to a multiple of `segment_len`. It is then split into consecutive,
    non-overlapping windows of exactly `segment_len` tokens. Each window attends
    causally only to itself. Shorter inputs are treated as a single causal window.

    The input is RMS-normalized first. Only the projected attention output is
    returned; no residual is added here.

    Args:
        dim: input/output feature dimension
        segment_len: number of tokens per attention window
        dim_head: feature dimension per head
        heads: number of attention heads

    NOTE(review): nn.RMSNorm only exists in newer PyTorch releases (2.4+, to be
    confirmed), while the package metadata declares torch>=2.2.
    """

    def __init__(
        self,
        dim,
        segment_len,
        dim_head = 64,
        heads = 8,
    ):
        super().__init__()
        self.norm = nn.RMSNorm(dim)

        dim_inner = dim_head * heads

        self.to_qkv = LinearNoBias(dim, dim_inner * 3)
        self.to_out = LinearNoBias(dim_inner, dim)

        self.segment_len = segment_len

        self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
        self.merge_heads = Rearrange('b h n d -> b n (h d)')

        # pin w (tokens per window) to segment_len and let einops infer n (number of windows).
        # Pinning n = segment_len instead would yield windows of seq_len / segment_len tokens.
        self.segment_seq = Rearrange('b (n w) d -> (b n) w d', w = segment_len)

    def forward(self, seq):
        """
        Args:
            seq: input features; indexed as (batch, seq_len, dim)

        Returns:
            Attention output with the same (batch, seq_len, dim) shape as `seq`.
        """
        batch, seq_len = seq.shape[:2]

        # auto pad to multiple
        # todo - get rid of logic with flex attention

        need_segment = seq_len >= self.segment_len

        if need_segment:
            next_seq_len = round_up_multiple(seq_len, self.segment_len)
            padding = next_seq_len - seq_len

            if padding > 0:
                # zero-pad the sequence dim at the end; under causal masking the padded
                # positions come after every real token and so never influence it
                seq = F.pad(seq, (0, 0, 0, padding))

            # (batch, windows * segment_len, dim) -> (batch * windows, segment_len, dim)
            seq = self.segment_seq(seq)

        # attention

        seq = self.norm(seq)

        q, k, v = self.to_qkv(seq).chunk(3, dim = -1)
        q, k, v = map(self.split_heads, (q, k, v))

        out = F.scaled_dot_product_attention(q, k, v, is_causal = True)

        out = self.merge_heads(out)

        out = self.to_out(out)

        if need_segment:
            # (batch * windows, segment_len, dim) -> (batch, windows * segment_len, dim);
            # batch is the outer factor of the folded dim, so a plain reshape restores order
            out = out.reshape(batch, -1, out.shape[-1])

        # drop any padding added above
        return out[:, :seq_len]
105
+
106
+ # MAC transformer
107
+
108
class MemoryAsContextTransformer(Module):
    """
    Token-in, logits-out transformer built from segmented attention and feedforward layers.

    Pipeline:
        token embedding
        -> expand residual streams
        -> `depth` x (attention, feedforward), each wrapped by a hyper-connection
        -> reduce residual streams
        -> RMSNorm
        -> bias-free projection to `num_tokens` logits

    NOTE(review): despite the name, this version wires in no memory module yet.
    Each layer is only SegmentedAttention followed by FeedForward.

    Args (keyword-only):
        num_tokens: vocabulary size (embedding rows and logit width)
        dim: model dimension
        depth: number of (attention, feedforward) layer pairs
        segment_len: attention window length passed to SegmentedAttention
        dim_head: per-head dimension for attention
        heads: number of attention heads
        ff_mult: hidden-width multiplier for FeedForward
        num_residual_streams: number of hyper-connection residual streams. A value
            of 1 passes `disable = True` to the hyper-connections factory (presumably
            a plain residual; confirm in hyper_connections).
    """

    def __init__(
        self,
        *,
        num_tokens,
        dim,
        depth,
        segment_len,
        dim_head = 64,
        heads = 8,
        ff_mult = 4,
        num_residual_streams = 4
    ):
        super().__init__()

        self.token_emb = nn.Embedding(num_tokens, dim)

        single_stream = num_residual_streams == 1
        init_hyper_conn, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(num_residual_streams, disable = single_stream)

        # construction order (attn, ff, then their wrappers) is kept per layer so
        # parameter initialization consumes the RNG in the same sequence
        blocks = []
        for _ in range(depth):
            attn = SegmentedAttention(dim = dim, dim_head = dim_head, heads = heads, segment_len = segment_len)
            ff = FeedForward(dim = dim, mult = ff_mult)

            wrapped_attn = init_hyper_conn(dim = dim, branch = attn)
            wrapped_ff = init_hyper_conn(dim = dim, branch = ff)

            blocks.append(ModuleList([wrapped_attn, wrapped_ff]))

        self.layers = ModuleList(blocks)

        self.norm = nn.RMSNorm(dim)

        self.to_logits = LinearNoBias(dim, num_tokens)

    def forward(self, x):
        """
        Args:
            x: integer token ids (presumably shaped (batch, seq_len); confirm)

        Returns:
            Logits whose last dimension is `num_tokens`.
        """
        hidden = self.expand_streams(self.token_emb(x))

        for attn_block, ff_block in self.layers:
            hidden = ff_block(attn_block(hidden))

        hidden = self.reduce_streams(hidden)

        return self.to_logits(self.norm(hidden))
157
+
158
+ # main
159
+
160
if __name__ == '__main__':
    # quick smoke run of the model on random tokens
    num_tokens = 256

    transformer = MemoryAsContextTransformer(
        num_tokens = num_tokens,
        dim = 256,
        depth = 2,
        segment_len = 128,
    )

    # 1023 is deliberately not a multiple of segment_len, exercising the padding path
    x = torch.randint(0, num_tokens, (1, 1023))

    logits = transformer(x)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: titans-pytorch
3
- Version: 0.0.25
3
+ Version: 0.0.26
4
4
  Summary: Titans
5
5
  Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
6
6
  Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -37,6 +37,7 @@ Requires-Python: >=3.9
37
37
  Requires-Dist: accelerated-scan>=0.2.0
38
38
  Requires-Dist: einops>=0.8.0
39
39
  Requires-Dist: einx>=0.3.0
40
+ Requires-Dist: hyper-connections>=0.1.8
40
41
  Requires-Dist: ninja
41
42
  Requires-Dist: tensordict
42
43
  Requires-Dist: torch>=2.2
@@ -1,8 +1,9 @@
1
1
  titans_pytorch/__init__.py,sha256=nB0873FZ_OyCda3qFeWTdpO4LbbrXDEvtAefVLzh6g0,71
2
2
  titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
3
+ titans_pytorch/mac_transformer.py,sha256=63g4IEXbAClN6_FwcZuBPau0ndQ5yRSr0x7TIZDlxxw,4051
3
4
  titans_pytorch/titans.py,sha256=Kx_tl_QkeDccvkMwPZ0xQ_saYjZfbKzDNPTTSHNWYcA,14304
4
5
  titans_pytorch/titans_attn_memory.py,sha256=Rwx_-riGeISBefZg5Kjic8jzmmRRys-u93D2Kgb7Mos,12691
5
- titans_pytorch-0.0.25.dist-info/METADATA,sha256=NeRAldZl9fN7bc3YzL44kDzH2rsq5SMBZ-7RjGs_B0g,3811
6
- titans_pytorch-0.0.25.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
7
- titans_pytorch-0.0.25.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
8
- titans_pytorch-0.0.25.dist-info/RECORD,,
6
+ titans_pytorch-0.0.26.dist-info/METADATA,sha256=Dh6ymVAQ-EkWJBt0Z_eWV_UNBQfainhsus6ShywOBkk,3851
7
+ titans_pytorch-0.0.26.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
8
+ titans_pytorch-0.0.26.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
9
+ titans_pytorch-0.0.26.dist-info/RECORD,,