titans-pytorch 0.0.27__py3-none-any.whl → 0.0.30__py3-none-any.whl

This diff shows the contents of two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only and reflects the changes between those versions.
--- titans_pytorch/mac_transformer.py (0.0.27)
+++ titans_pytorch/mac_transformer.py (0.0.30)
@@ -1,5 +1,5 @@
 from __future__ import annotations
-import math
+from math import ceil
 from functools import partial
 
 import torch
@@ -7,11 +7,16 @@ from torch import nn, cat
 import torch.nn.functional as F
 from torch.nn import Module, ModuleList, Linear
 
-from einops import repeat
+from einops import repeat, rearrange
 from einops.layers.torch import Rearrange
 
 from hyper_connections import get_init_and_expand_reduce_stream_functions
 
+# absolute and relative positions
+
+from axial_positional_embedding import ContinuousAxialPositionalEmbedding
+from rotary_embedding_torch import RotaryEmbedding
+
 # constants
 
 LinearNoBias = partial(Linear, bias = False)
@@ -25,7 +30,7 @@ def default(v, d):
     return v if exists(v) else d
 
 def round_up_multiple(seq, mult):
-    return math.ceil(seq / mult) * mult
+    return ceil(seq / mult) * mult
 
 # feedforward and attention
 
@@ -49,7 +54,8 @@ class SegmentedAttention(Module):
         self,
         dim,
         segment_len,
-        num_persist_mem_tokens,
+        num_persist_mem_tokens = 0,
+        num_longterm_mem_tokens = 0,
         dim_head = 64,
         heads = 8,
     ):
@@ -58,29 +64,37 @@ class SegmentedAttention(Module):
 
         dim_inner = dim_head * heads
 
+        self.rotary_emb = RotaryEmbedding(dim_head)
+
         self.to_qkv = LinearNoBias(dim, dim_inner * 3)
         self.to_out = LinearNoBias(dim_inner, dim)
 
         self.segment_len = segment_len
+        self.num_longterm_mem_tokens = num_longterm_mem_tokens
+
+        total_segment_len = segment_len + num_longterm_mem_tokens
 
         self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
         self.merge_heads = Rearrange('b h n d -> b n (h d)')
 
-        self.segment_seq = Rearrange('b (n w) d -> (b n) w d', n = segment_len)
-        self.merge_seq_back = Rearrange('(b n) w d -> b (n w) d', n = segment_len)
+        self.segment_seq = Rearrange('b (n w) d -> (b n) w d', n = total_segment_len)
+        self.merge_seq_back = Rearrange('(b n) w d -> b (n w) d', n = total_segment_len)
 
         self.persistent_memory = nn.Parameter(torch.zeros(2, heads, num_persist_mem_tokens, dim_head))
 
     def forward(self, seq):
+        segment_len, num_longterm_mem_tokens = self.segment_len, self.num_longterm_mem_tokens
+        total_segment_len = segment_len + num_longterm_mem_tokens
+
         batch, seq_len = seq.shape[:2]
 
         # auto pad to multiple
         # todo - get rid of logic with flex attention
 
-        need_segment = seq_len >= self.segment_len
+        need_segment = seq_len >= total_segment_len
 
         if need_segment:
-            next_seq_len = round_up_multiple(seq_len, self.segment_len)
+            next_seq_len = round_up_multiple(seq_len, total_segment_len)
             padding = next_seq_len - seq_len
 
             if padding > 0:
@@ -99,6 +113,12 @@ class SegmentedAttention(Module):
 
         pmk, pmv = tuple(repeat(t, 'h n d -> b h n d', b = seq.shape[0]) for t in self.persistent_memory)
 
+        # relative positions
+
+        q, k = self.rotary_emb.rotate_queries_with_cached_keys(q, k)
+
+        # persistent memory
+
         k = cat((pmk, k), dim = -2)
         v = cat((pmv, v), dim = -2)
 
@@ -125,7 +145,8 @@ class MemoryAsContextTransformer(Module):
         dim,
         depth,
         segment_len,
-        num_persist_mem_tokens,
+        num_longterm_mem_tokens = 0,
+        num_persist_mem_tokens = 0,
         dim_head = 64,
         heads = 8,
         ff_mult = 4,
@@ -135,6 +156,17 @@ class MemoryAsContextTransformer(Module):
 
         self.token_emb = nn.Embedding(num_tokens, dim)
 
+        self.axial_pos_emb = ContinuousAxialPositionalEmbedding(dim = dim, num_axial_dims = 2)
+
+        # long term mem tokens
+
+        self.segment_len = segment_len
+        self.num_longterm_mem_tokens = num_longterm_mem_tokens
+
+        self.longterm_mems = nn.Parameter(torch.randn(num_longterm_mem_tokens, dim) * 0.02)
+
+        # hyper conection
+
         init_hyper_conn, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(num_residual_streams, disable = num_residual_streams == 1)
 
         self.layers = ModuleList([])
@@ -145,6 +177,7 @@ class MemoryAsContextTransformer(Module):
                 dim_head = dim_head,
                 heads = heads,
                 segment_len = segment_len,
+                num_longterm_mem_tokens = num_longterm_mem_tokens,
                 num_persist_mem_tokens = num_persist_mem_tokens
             )
 
@@ -161,8 +194,45 @@ class MemoryAsContextTransformer(Module):
 
     def forward(self, x):
 
+        # math
+
+        batch, seq_len, segment_len, num_longterm_mem_tokens= *x.shape, self.segment_len, self.num_longterm_mem_tokens
+
+        windows = ceil(seq_len / segment_len)
+        total_segment_len = segment_len + num_longterm_mem_tokens
+
+        # token embedding
+
         x = self.token_emb(x)
 
+        # intersperse longterm memory
+
+        need_segment = seq_len >= segment_len
+
+        if need_segment:
+            next_seq_len = round_up_multiple(seq_len, segment_len)
+            padding = next_seq_len - seq_len
+
+            if padding > 0:
+                x = F.pad(x, (0, 0, 0, padding))
+
+            x = rearrange(x, 'b (w n) d -> (b w) n d', n = segment_len)
+
+        mems = repeat(self.longterm_mems, 'n d -> b n d', b = x.shape[0])
+        x = torch.cat((mems, x), dim = -2)
+
+        if need_segment:
+            x = rearrange(x, '(b w) n d -> b (w n) d', b = batch)
+            x = x[:, :seq_len]
+
+        # apply axial positional embedding
+        # so intra and inter segment can be more easily discerned by the network
+
+        pos_emb = self.axial_pos_emb((windows, total_segment_len), flatten = True)
+        x = x + pos_emb[:seq_len]
+
+        # expand and reduce streams for hyper connections
+
         x = self.expand_streams(x)
 
         for attn, ff in self.layers:
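To make the interspersing step in the hunk above concrete: every window of segment_len token embeddings gets the same block of learned long-term memory tokens prepended before the windows are flattened back into one sequence. A standalone sketch of just that reshaping, using random tensors and illustrative shapes (not code from the wheel):

    import torch
    from einops import rearrange, repeat

    batch, seq_len, dim = 2, 256, 16
    segment_len, num_longterm_mem_tokens = 64, 4       # 4 windows of 64 tokens each

    x = torch.randn(batch, seq_len, dim)               # token embeddings, already a multiple of segment_len
    longterm_mems = torch.randn(num_longterm_mem_tokens, dim)

    # split into windows, prepend the memory tokens to every window, then flatten back
    x = rearrange(x, 'b (w n) d -> (b w) n d', n = segment_len)
    mems = repeat(longterm_mems, 'n d -> bw n d', bw = x.shape[0])
    x = torch.cat((mems, x), dim = -2)
    x = rearrange(x, '(b w) n d -> b (w n) d', b = batch)

    print(x.shape)  # torch.Size([2, 272, 16]) -> 4 windows * (64 + 4) tokens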
@@ -171,21 +241,8 @@ class MemoryAsContextTransformer(Module):
 
         x = self.reduce_streams(x)
 
+        # to logits
+
         x = self.norm(x)
 
         return self.to_logits(x)
-
-# main
-
-if __name__ == '__main__':
-    transformer = MemoryAsContextTransformer(
-        num_tokens = 256,
-        dim = 256,
-        depth = 2,
-        num_persist_mem_tokens = 16,
-        segment_len = 128,
-    )
-
-    x = torch.randint(0, 256, (1, 1023))
-
-    logits = transformer(x)
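The `if __name__ == '__main__':` demo removed above was the module's only usage example. For reference, a minimal usage sketch against the 0.0.30 signature shown in this diff; the direct module import, the value 4 for num_longterm_mem_tokens, and the printed shape are illustrative assumptions, not taken from the wheel:

    import torch
    from titans_pytorch.mac_transformer import MemoryAsContextTransformer

    transformer = MemoryAsContextTransformer(
        num_tokens = 256,
        dim = 256,
        depth = 2,
        segment_len = 128,
        num_longterm_mem_tokens = 4,    # new in 0.0.30: learned memory tokens interspersed per segment
        num_persist_mem_tokens = 16,
    )

    x = torch.randint(0, 256, (1, 1023))

    logits = transformer(x)
    print(logits.shape)                 # expected: torch.Size([1, 1023, 256])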
--- titans_pytorch-0.0.27.dist-info/METADATA
+++ titans_pytorch-0.0.30.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.0.27
+Version: 0.0.30
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -35,10 +35,12 @@ Classifier: Programming Language :: Python :: 3.9
 Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
 Requires-Python: >=3.9
 Requires-Dist: accelerated-scan>=0.2.0
+Requires-Dist: axial-positional-embedding>=0.3.5
 Requires-Dist: einops>=0.8.0
 Requires-Dist: einx>=0.3.0
 Requires-Dist: hyper-connections>=0.1.8
 Requires-Dist: ninja
+Requires-Dist: rotary-embedding-torch
 Requires-Dist: tensordict
 Requires-Dist: torch>=2.2
 Provides-Extra: examples
--- titans_pytorch-0.0.27.dist-info/RECORD
+++ titans_pytorch-0.0.30.dist-info/RECORD
@@ -1,9 +1,9 @@
 titans_pytorch/__init__.py,sha256=nB0873FZ_OyCda3qFeWTdpO4LbbrXDEvtAefVLzh6g0,71
 titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/mac_transformer.py,sha256=4xUSfGDdVsR-WmeXX7yXoFfybROvNCjOxL_EHDJ_Wlk,4681
+titans_pytorch/mac_transformer.py,sha256=XDLc9NWXpVMza03XjU0lkw5lRvtJ25ReTPKNoGslOOk,6773
 titans_pytorch/titans.py,sha256=Kx_tl_QkeDccvkMwPZ0xQ_saYjZfbKzDNPTTSHNWYcA,14304
 titans_pytorch/titans_attn_memory.py,sha256=Rwx_-riGeISBefZg5Kjic8jzmmRRys-u93D2Kgb7Mos,12691
-titans_pytorch-0.0.27.dist-info/METADATA,sha256=qJp1IDbphEUfW7EyNvQ7RfmHuvB7SH5h_tlnCVwV4EY,3851
-titans_pytorch-0.0.27.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.0.27.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.0.27.dist-info/RECORD,,
+titans_pytorch-0.0.30.dist-info/METADATA,sha256=7H6WPsgfBE9ByEUP7r6C-cfBX6K13yNCaVhZMzwUvf8,3938
+titans_pytorch-0.0.30.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.0.30.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.0.30.dist-info/RECORD,,