titans-pytorch 0.0.24.tar.gz → 0.0.26.tar.gz

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.0.24
+Version: 0.0.26
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -37,6 +37,7 @@ Requires-Python: >=3.9
 Requires-Dist: accelerated-scan>=0.2.0
 Requires-Dist: einops>=0.8.0
 Requires-Dist: einx>=0.3.0
+Requires-Dist: hyper-connections>=0.1.8
 Requires-Dist: ninja
 Requires-Dist: tensordict
 Requires-Dist: torch>=2.2
@@ -1,6 +1,6 @@
 [project]
 name = "titans-pytorch"
-version = "0.0.24"
+version = "0.0.26"
 description = "Titans"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -28,6 +28,7 @@ dependencies = [
     "accelerated-scan>=0.2.0",
     "einx>=0.3.0",
     "einops>=0.8.0",
+    "hyper-connections>=0.1.8",
     "Ninja",
     "tensordict",
     "torch>=2.2",
@@ -0,0 +1,170 @@
+from __future__ import annotations
+import math
+from functools import partial
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+from torch.nn import Module, ModuleList, Linear
+
+from einops.layers.torch import Rearrange
+
+from hyper_connections import get_init_and_expand_reduce_stream_functions
+
+# constants
+
+LinearNoBias = partial(Linear, bias = False)
+
+# helpers
+
+def exists(v):
+    return v is not None
+
+def default(v, d):
+    return v if exists(v) else d
+
+def round_up_multiple(seq, mult):
+    return math.ceil(seq / mult) * mult
+
+# feedforward and attention
+
+class GEGLU(Module):
+    def forward(self, x):
+        x, gate = x.chunk(2, dim = -1)
+        return F.silu(gate) * x
+
+def FeedForward(dim, mult = 4):
+    dim_inner = int(dim * mult * 2 / 3)
+
+    return nn.Sequential(
+        nn.RMSNorm(dim),
+        nn.Linear(dim, dim_inner * 2),
+        GEGLU(),
+        nn.Linear(dim_inner, dim)
+    )
+
+class SegmentedAttention(Module):
+    def __init__(
+        self,
+        dim,
+        segment_len,
+        dim_head = 64,
+        heads = 8,
+    ):
+        super().__init__()
+        self.norm = nn.RMSNorm(dim)
+
+        dim_inner = dim_head * heads
+
+        self.to_qkv = LinearNoBias(dim, dim_inner * 3)
+        self.to_out = LinearNoBias(dim_inner, dim)
+
+        self.segment_len = segment_len
+
+        self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
+        self.merge_heads = Rearrange('b h n d -> b n (h d)')
+
+        self.segment_seq = Rearrange('b (n w) d -> (b n) w d', n = segment_len)
+        self.merge_seq_back = Rearrange('(b n) w d -> b (n w) d', n = segment_len)
+
+
+    def forward(self, seq):
+        batch, seq_len = seq.shape[:2]
+
+        # auto pad to multiple
+        # todo - get rid of logic with flex attention
+
+        need_segment = seq_len >= self.segment_len
+
+        if need_segment:
+            next_seq_len = round_up_multiple(seq_len, self.segment_len)
+            padding = next_seq_len - seq_len
+
+            if padding > 0:
+                seq = F.pad(seq, (0, 0, 0, padding))
+
+            seq = self.segment_seq(seq)
+
+        # attention
+
+        seq = self.norm(seq)
+
+        q, k, v = self.to_qkv(seq).chunk(3, dim = -1)
+        q, k, v = map(self.split_heads, (q, k, v))
+
+        out = F.scaled_dot_product_attention(q, k, v, is_causal = True)
+
+        out = self.merge_heads(out)
+
+        out = self.to_out(out)
+
+        if need_segment:
+            out = self.merge_seq_back(out)
+
+        return out[:, :seq_len]
+
+# MAC transformer
+
+class MemoryAsContextTransformer(Module):
+    def __init__(
+        self,
+        *,
+        num_tokens,
+        dim,
+        depth,
+        segment_len,
+        dim_head = 64,
+        heads = 8,
+        ff_mult = 4,
+        num_residual_streams = 4
+    ):
+        super().__init__()
+
+        self.token_emb = nn.Embedding(num_tokens, dim)
+
+        init_hyper_conn, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(num_residual_streams, disable = num_residual_streams == 1)
+
+        self.layers = ModuleList([])
+
+        for _ in range(depth):
+            attn = SegmentedAttention(dim = dim, dim_head = dim_head, heads = heads, segment_len = segment_len)
+            ff = FeedForward(dim = dim, mult = ff_mult)
+
+            self.layers.append(ModuleList([
+                init_hyper_conn(dim = dim, branch = attn),
+                init_hyper_conn(dim = dim, branch = ff)
+            ]))
+
+        self.norm = nn.RMSNorm(dim)
+
+        self.to_logits = LinearNoBias(dim, num_tokens)
+
+    def forward(self, x):
+
+        x = self.token_emb(x)
+
+        x = self.expand_streams(x)
+
+        for attn, ff in self.layers:
+            x = attn(x)
+            x = ff(x)
+
+        x = self.reduce_streams(x)
+
+        x = self.norm(x)
+
+        return self.to_logits(x)
+
+# main
+
+if __name__ == '__main__':
+    transformer = MemoryAsContextTransformer(
+        num_tokens = 256,
+        dim = 256,
+        depth = 2,
+        segment_len = 128,
+    )
+
+    x = torch.randint(0, 256, (1, 1023))
+
+    logits = transformer(x)
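
The new file ends with the smoke test above; as a hedged usage sketch (not part of the package), the same class could be trained for next-token prediction along these lines, assuming MemoryAsContextTransformer is in scope (e.g. imported from the installed package):

    import torch
    import torch.nn.functional as F

    model = MemoryAsContextTransformer(
        num_tokens = 256,
        dim = 256,
        depth = 2,
        segment_len = 128,
        num_residual_streams = 4
    )

    seq = torch.randint(0, 256, (2, 1025))
    inp, target = seq[:, :-1], seq[:, 1:]   # shift by one token for targets

    logits = model(inp)                     # (batch, seq_len, num_tokens)
    loss = F.cross_entropy(
        logits.transpose(1, 2),             # cross_entropy expects (batch, classes, seq_len)
        target
    )
    loss.backward()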
@@ -136,6 +136,7 @@ class NeuralMemory(Module):
         )
     ):
         super().__init__()
+        dim_head = default(dim_head, dim)
 
         # norms
 
@@ -146,7 +147,6 @@ class NeuralMemory(Module):
 
        # maybe multi-headed
 
-        dim_head = default(dim_head, dim)
        dim_inner = dim_head * heads
 
        self.heads = heads
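
The last pair of hunks only relocates the dim_head fallback to the top of NeuralMemory.__init__, presumably so the resolved value is available to the setup code that runs between the two positions. The moved line's behavior is unchanged and follows the default helper also defined in the new file above; a tiny illustration, with hypothetical values:

    def default(v, d):
        return v if v is not None else d

    dim, heads, dim_head = 512, 8, None
    dim_head = default(dim_head, dim)   # falls back to dim when not supplied -> 512
    dim_inner = dim_head * heads        # matches the unchanged `dim_inner = dim_head * heads` context line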