titans-pytorch 0.0.27__py3-none-any.whl → 0.0.30__py3-none-any.whl



titans_pytorch/mac_transformer.py

@@ -1,5 +1,5 @@
  from __future__ import annotations
- import math
+ from math import ceil
  from functools import partial

  import torch
@@ -7,11 +7,16 @@ from torch import nn, cat
  import torch.nn.functional as F
  from torch.nn import Module, ModuleList, Linear

- from einops import repeat
+ from einops import repeat, rearrange
  from einops.layers.torch import Rearrange

  from hyper_connections import get_init_and_expand_reduce_stream_functions

+ # absolute and relative positions
+
+ from axial_positional_embedding import ContinuousAxialPositionalEmbedding
+ from rotary_embedding_torch import RotaryEmbedding
+
  # constants

  LinearNoBias = partial(Linear, bias = False)
@@ -25,7 +30,7 @@ def default(v, d):
  return v if exists(v) else d

  def round_up_multiple(seq, mult):
- return math.ceil(seq / mult) * mult
+ return ceil(seq / mult) * mult

  # feedforward and attention

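For reference, a quick standalone check of the padding helper touched above, using assumed sizes (a 1023-token input and 128-token segments):

    from math import ceil

    def round_up_multiple(seq, mult):
        return ceil(seq / mult) * mult

    # a 1023-token sequence rounds up to the next multiple of the 128-token segment
    round_up_multiple(1023, 128)   # ceil(1023 / 128) * 128 = 8 * 128 = 1024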
@@ -49,7 +54,8 @@ class SegmentedAttention(Module):
  self,
  dim,
  segment_len,
- num_persist_mem_tokens,
+ num_persist_mem_tokens = 0,
+ num_longterm_mem_tokens = 0,
  dim_head = 64,
  heads = 8,
  ):
@@ -58,29 +64,37 @@ class SegmentedAttention(Module):

  dim_inner = dim_head * heads

+ self.rotary_emb = RotaryEmbedding(dim_head)
+
  self.to_qkv = LinearNoBias(dim, dim_inner * 3)
  self.to_out = LinearNoBias(dim_inner, dim)

  self.segment_len = segment_len
+ self.num_longterm_mem_tokens = num_longterm_mem_tokens
+
+ total_segment_len = segment_len + num_longterm_mem_tokens

  self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
  self.merge_heads = Rearrange('b h n d -> b n (h d)')

- self.segment_seq = Rearrange('b (n w) d -> (b n) w d', n = segment_len)
- self.merge_seq_back = Rearrange('(b n) w d -> b (n w) d', n = segment_len)
+ self.segment_seq = Rearrange('b (n w) d -> (b n) w d', n = total_segment_len)
+ self.merge_seq_back = Rearrange('(b n) w d -> b (n w) d', n = total_segment_len)

  self.persistent_memory = nn.Parameter(torch.zeros(2, heads, num_persist_mem_tokens, dim_head))

  def forward(self, seq):
+ segment_len, num_longterm_mem_tokens = self.segment_len, self.num_longterm_mem_tokens
+ total_segment_len = segment_len + num_longterm_mem_tokens
+
  batch, seq_len = seq.shape[:2]

  # auto pad to multiple
  # todo - get rid of logic with flex attention

- need_segment = seq_len >= self.segment_len
+ need_segment = seq_len >= total_segment_len

  if need_segment:
- next_seq_len = round_up_multiple(seq_len, self.segment_len)
+ next_seq_len = round_up_multiple(seq_len, total_segment_len)
  padding = next_seq_len - seq_len

  if padding > 0:
@@ -99,6 +113,12 @@ class SegmentedAttention(Module):

  pmk, pmv = tuple(repeat(t, 'h n d -> b h n d', b = seq.shape[0]) for t in self.persistent_memory)

+ # relative positions
+
+ q, k = self.rotary_emb.rotate_queries_with_cached_keys(q, k)
+
+ # persistent memory
+
  k = cat((pmk, k), dim = -2)
  v = cat((pmv, v), dim = -2)

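A minimal sketch of the ordering introduced here, with assumed tensor sizes: rotary positions are applied to the queries and keys first, and the persistent memory keys/values are concatenated only afterwards, so they stay position-free.

    import torch
    from rotary_embedding_torch import RotaryEmbedding

    # illustrative shapes: (batch, heads, seq, dim_head)
    q = torch.randn(2, 8, 64, 32)
    k = torch.randn(2, 8, 64, 32)
    pmk = torch.randn(2, 8, 16, 32)   # stand-in for the learned persistent memory keys

    rotary = RotaryEmbedding(32)

    # rotate q and k into relative positions first ...
    q, k = rotary.rotate_queries_with_cached_keys(q, k)

    # ... then prepend persistent memory, which carries no positional signal
    k = torch.cat((pmk, k), dim = -2)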
@@ -125,7 +145,8 @@ class MemoryAsContextTransformer(Module):
  dim,
  depth,
  segment_len,
- num_persist_mem_tokens,
+ num_longterm_mem_tokens = 0,
+ num_persist_mem_tokens = 0,
  dim_head = 64,
  heads = 8,
  ff_mult = 4,
@@ -135,6 +156,17 @@ class MemoryAsContextTransformer(Module):

  self.token_emb = nn.Embedding(num_tokens, dim)

+ self.axial_pos_emb = ContinuousAxialPositionalEmbedding(dim = dim, num_axial_dims = 2)
+
+ # long term mem tokens
+
+ self.segment_len = segment_len
+ self.num_longterm_mem_tokens = num_longterm_mem_tokens
+
+ self.longterm_mems = nn.Parameter(torch.randn(num_longterm_mem_tokens, dim) * 0.02)
+
+ # hyper conection
+
  init_hyper_conn, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(num_residual_streams, disable = num_residual_streams == 1)

  self.layers = ModuleList([])
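For context, a sketch of how the two-axis positional embedding added here is consumed later in the forward pass (sizes are assumed; the flatten = True call mirrors the usage further down). One axis indexes the window, the other the position inside the window, so tokens get both inter- and intra-segment position information.

    import torch
    from axial_positional_embedding import ContinuousAxialPositionalEmbedding

    axial_pos_emb = ContinuousAxialPositionalEmbedding(dim = 256, num_axial_dims = 2)

    # 8 windows, each segment_len + num_longterm_mem_tokens = 128 + 16 = 144 slots long
    windows, total_segment_len = 8, 144

    # flattened positions, one per token slot across all windows
    pos_emb = axial_pos_emb((windows, total_segment_len), flatten = True)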
@@ -145,6 +177,7 @@ class MemoryAsContextTransformer(Module):
  dim_head = dim_head,
  heads = heads,
  segment_len = segment_len,
+ num_longterm_mem_tokens = num_longterm_mem_tokens,
  num_persist_mem_tokens = num_persist_mem_tokens
  )

@@ -161,8 +194,45 @@ class MemoryAsContextTransformer(Module):

  def forward(self, x):

+ # math
+
+ batch, seq_len, segment_len, num_longterm_mem_tokens= *x.shape, self.segment_len, self.num_longterm_mem_tokens
+
+ windows = ceil(seq_len / segment_len)
+ total_segment_len = segment_len + num_longterm_mem_tokens
+
+ # token embedding
+
  x = self.token_emb(x)

+ # intersperse longterm memory
+
+ need_segment = seq_len >= segment_len
+
+ if need_segment:
+ next_seq_len = round_up_multiple(seq_len, segment_len)
+ padding = next_seq_len - seq_len
+
+ if padding > 0:
+ x = F.pad(x, (0, 0, 0, padding))
+
+ x = rearrange(x, 'b (w n) d -> (b w) n d', n = segment_len)
+
+ mems = repeat(self.longterm_mems, 'n d -> b n d', b = x.shape[0])
+ x = torch.cat((mems, x), dim = -2)
+
+ if need_segment:
+ x = rearrange(x, '(b w) n d -> b (w n) d', b = batch)
+ x = x[:, :seq_len]
+
+ # apply axial positional embedding
+ # so intra and inter segment can be more easily discerned by the network
+
+ pos_emb = self.axial_pos_emb((windows, total_segment_len), flatten = True)
+ x = x + pos_emb[:seq_len]
+
+ # expand and reduce streams for hyper connections
+
  x = self.expand_streams(x)

  for attn, ff in self.layers:
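A standalone sketch of the intersperse step above with assumed sizes, showing how the padded sequence is windowed, the same learned long-term memory tokens are prepended to every window, and the windows are merged back:

    import torch
    from einops import rearrange, repeat

    batch, seq_len, dim = 1, 1024, 256
    segment_len, num_longterm_mem_tokens = 128, 16

    x = torch.randn(batch, seq_len, dim)
    longterm_mems = torch.randn(num_longterm_mem_tokens, dim)

    # split the (padded) sequence into windows of segment_len tokens
    x = rearrange(x, 'b (w n) d -> (b w) n d', n = segment_len)   # (8, 128, 256)

    # prepend the same learned memory tokens to every window
    mems = repeat(longterm_mems, 'n d -> b n d', b = x.shape[0])
    x = torch.cat((mems, x), dim = -2)                            # (8, 144, 256)

    # merge the windows back into one long sequence per batch element
    x = rearrange(x, '(b w) n d -> b (w n) d', b = batch)         # (1, 1152, 256)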
@@ -171,21 +241,8 @@ class MemoryAsContextTransformer(Module):

  x = self.reduce_streams(x)

+ # to logits
+
  x = self.norm(x)

  return self.to_logits(x)
-
- # main
-
- if __name__ == '__main__':
- transformer = MemoryAsContextTransformer(
- num_tokens = 256,
- dim = 256,
- depth = 2,
- num_persist_mem_tokens = 16,
- segment_len = 128,
- )
-
- x = torch.randint(0, 256, (1, 1023))
-
- logits = transformer(x)
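For reference, the inline demo removed above, updated for the new constructor signature, might look like the following standalone sketch; the value 16 for num_longterm_mem_tokens is an assumption, not taken from the release.

    import torch
    from titans_pytorch.mac_transformer import MemoryAsContextTransformer

    transformer = MemoryAsContextTransformer(
        num_tokens = 256,
        dim = 256,
        depth = 2,
        segment_len = 128,
        num_persist_mem_tokens = 16,
        num_longterm_mem_tokens = 16,   # new in 0.0.30; 16 is an arbitrary choice
    )

    x = torch.randint(0, 256, (1, 1023))
    logits = transformer(x)   # token logits, one row per input position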
{titans_pytorch-0.0.27.dist-info → titans_pytorch-0.0.30.dist-info}/METADATA

@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: titans-pytorch
- Version: 0.0.27
+ Version: 0.0.30
  Summary: Titans
  Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
  Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -35,10 +35,12 @@ Classifier: Programming Language :: Python :: 3.9
  Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
  Requires-Python: >=3.9
  Requires-Dist: accelerated-scan>=0.2.0
+ Requires-Dist: axial-positional-embedding>=0.3.5
  Requires-Dist: einops>=0.8.0
  Requires-Dist: einx>=0.3.0
  Requires-Dist: hyper-connections>=0.1.8
  Requires-Dist: ninja
+ Requires-Dist: rotary-embedding-torch
  Requires-Dist: tensordict
  Requires-Dist: torch>=2.2
  Provides-Extra: examples
{titans_pytorch-0.0.27.dist-info → titans_pytorch-0.0.30.dist-info}/RECORD

@@ -1,9 +1,9 @@
  titans_pytorch/__init__.py,sha256=nB0873FZ_OyCda3qFeWTdpO4LbbrXDEvtAefVLzh6g0,71
  titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
- titans_pytorch/mac_transformer.py,sha256=4xUSfGDdVsR-WmeXX7yXoFfybROvNCjOxL_EHDJ_Wlk,4681
+ titans_pytorch/mac_transformer.py,sha256=XDLc9NWXpVMza03XjU0lkw5lRvtJ25ReTPKNoGslOOk,6773
  titans_pytorch/titans.py,sha256=Kx_tl_QkeDccvkMwPZ0xQ_saYjZfbKzDNPTTSHNWYcA,14304
  titans_pytorch/titans_attn_memory.py,sha256=Rwx_-riGeISBefZg5Kjic8jzmmRRys-u93D2Kgb7Mos,12691
- titans_pytorch-0.0.27.dist-info/METADATA,sha256=qJp1IDbphEUfW7EyNvQ7RfmHuvB7SH5h_tlnCVwV4EY,3851
- titans_pytorch-0.0.27.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
- titans_pytorch-0.0.27.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
- titans_pytorch-0.0.27.dist-info/RECORD,,
+ titans_pytorch-0.0.30.dist-info/METADATA,sha256=7H6WPsgfBE9ByEUP7r6C-cfBX6K13yNCaVhZMzwUvf8,3938
+ titans_pytorch-0.0.30.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+ titans_pytorch-0.0.30.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+ titans_pytorch-0.0.30.dist-info/RECORD,,