titans-pytorch 0.0.31__py3-none-any.whl → 0.0.34__py3-none-any.whl


--- a/titans_pytorch/mac_transformer.py
+++ b/titans_pytorch/mac_transformer.py
@@ -17,6 +17,10 @@ from hyper_connections import get_init_and_expand_reduce_stream_functions
 from axial_positional_embedding import ContinuousAxialPositionalEmbedding
 from rotary_embedding_torch import RotaryEmbedding
 
+# proposed neural memory
+
+from titans_pytorch.titans import NeuralMemory
+
 # constants
 
 LinearNoBias = partial(Linear, bias = False)
@@ -46,13 +50,20 @@ def pad_and_segment_with_inverse(seq, segment_len):
     next_seq_len_mult = round_up_multiple(seq_len, segment_len)
 
     padding = next_seq_len_mult - seq_len
-    seq = F.pad(seq, (0, 0, 0, padding))
+    needs_pad = padding > 0
+
+    if needs_pad:
+        seq = F.pad(seq, (0, 0, 0, padding))
 
     seq = rearrange(seq, 'b (w n) d -> (b w) n d', n = segment_len)
 
     def inverse(out):
         out = rearrange(out, '(b w) n d -> b (w n) d', b = batch)
-        return out[:, :-padding]
+
+        if needs_pad:
+            out = out[:, :-padding]
+
+        return out
 
     return seq, inverse
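The `needs_pad` guard above fixes a real slicing bug: in Python, `-0 == 0`, so when `padding` is zero the old `out[:, :-padding]` becomes `out[:, :0]` and silently returns an empty sequence. A minimal sketch of the pitfall (shapes are illustrative):

    import torch

    t = torch.randn(2, 8, 4)              # (batch, seq, dim)

    padding = 0
    print(t[:, :-padding].shape)          # torch.Size([2, 0, 4]) -- ':-0' is ':0', every token dropped

    padding = 3
    print(t[:, :-padding].shape)          # torch.Size([2, 5, 4]) -- correct once padding > 0

On the input side, `F.pad` with zero padding is effectively a no-op, so skipping that branch mainly avoids an unnecessary copy; the guard in `inverse` is the actual correctness fix.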
@@ -161,7 +172,9 @@ class MemoryAsContextTransformer(Module):
         dim_head = 64,
         heads = 8,
         ff_mult = 4,
-        num_residual_streams = 4
+        num_residual_streams = 4,
+        neural_memory_kwargs: dict = dict(),
+        neural_memory_layers: tuple[int, ...] | None = None,
     ):
         super().__init__()
@@ -181,8 +194,25 @@ class MemoryAsContextTransformer(Module):
         init_hyper_conn, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(num_residual_streams, disable = num_residual_streams == 1)
 
         self.layers = ModuleList([])
+        self.neural_mem_layers = ModuleList([])
+
+        layers = tuple(range(1, depth + 1))
+        neural_memory_layers = set(default(neural_memory_layers, layers))
+
+        for layer in layers:
+
+            # neural memory
+
+            mem = None
+
+            if num_longterm_mem_tokens > 0 and layer in neural_memory_layers:
+                mem = NeuralMemory(dim = dim, chunk_size = num_longterm_mem_tokens)
+                mem = init_hyper_conn(dim = dim, branch = mem)
+
+            self.neural_mem_layers.append(mem)
+
+            # attention and feedforward
 
-        for _ in range(depth):
             attn = SegmentedAttention(
                 dim = dim,
                 dim_head = dim_head,
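In the construction loop above, a `NeuralMemory` branch is created only for layers listed in `neural_memory_layers` (1-indexed, defaulting to all `depth` layers) and only when `num_longterm_mem_tokens > 0`; note that `neural_memory_kwargs` is accepted by the new signature but not yet forwarded to `NeuralMemory` in this version. A hedged instantiation sketch; the constructor arguments not shown in this diff (`num_tokens`, `depth`, `segment_len`, `num_longterm_mem_tokens`) are assumed from their usage elsewhere in the file:

    from titans_pytorch.mac_transformer import MemoryAsContextTransformer

    model = MemoryAsContextTransformer(
        num_tokens = 256,                 # vocabulary size, feeds to_logits
        dim = 64,
        depth = 4,
        segment_len = 32,
        num_longterm_mem_tokens = 4,      # also used as the NeuralMemory chunk size
        neural_memory_layers = (2, 4),    # memory branches only on layers 2 and 4
    )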
@@ -203,7 +233,14 @@ class MemoryAsContextTransformer(Module):
 
         self.to_logits = LinearNoBias(dim, num_tokens)
 
-    def forward(self, x):
+    def forward(
+        self,
+        x,
+        return_loss = False
+    ):
+
+        if return_loss:
+            x, labels = x[:, :-1], x[:, 1:]
 
         # math
@@ -221,7 +258,7 @@ class MemoryAsContextTransformer(Module):
         x, inverse_segment = pad_and_segment_with_inverse(x, segment_len)
 
         mems = repeat(self.longterm_mems, 'n d -> b n d', b = x.shape[0])
-        x = torch.cat((mems, x), dim = -2)
+        x = cat((mems, x), dim = -2)
 
         x = inverse_segment(x)
@@ -235,8 +272,27 @@ class MemoryAsContextTransformer(Module):
 
         x = self.expand_streams(x)
 
-        for attn, ff in self.layers:
+        for (attn, ff), maybe_neural_mem in zip(self.layers, self.neural_mem_layers):
+
+            if exists(maybe_neural_mem):
+                batch_streams = x.shape[0]
+
+                x, inverse_segment = pad_and_segment_with_inverse(x, total_segment_len)
+
+                longterm_mems, x = x[:, :num_longterm_mem_tokens], x[:, num_longterm_mem_tokens:]
+
+                longterm_mems = rearrange(longterm_mems, '(b w) n d -> b (w n) d', b = batch_streams)
+
+                longterm_mems = maybe_neural_mem(longterm_mems)
+
+                longterm_mems = rearrange(longterm_mems, 'b (w n) d -> (b w) n d', n = num_longterm_mem_tokens)
+
+                x = cat((longterm_mems, x), dim = -2)
+
+                x = inverse_segment(x)
+
             x = attn(x)
+
             x = ff(x)
 
         x = self.reduce_streams(x)
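The memory pass above runs on the stream-expanded batch: tokens arrive windowed as `(b w) n d`, the first `num_longterm_mem_tokens` of every window are peeled off, gathered into one contiguous per-batch sequence for the `NeuralMemory` branch, then scattered back into their windows. The two `rearrange` calls are exact inverses, as this small sketch checks:

    import torch
    from einops import rearrange

    b, w, n, d = 2, 3, 4, 8               # batch, windows, mem tokens per window, dim
    mems = torch.randn(b * w, n, d)       # '(b w) n d', as sliced off each segment

    flat = rearrange(mems, '(b w) n d -> b (w n) d', b = b)
    assert flat.shape == (b, w * n, d)    # one sequence of w * n memory tokens

    back = rearrange(flat, 'b (w n) d -> (b w) n d', n = n)
    assert torch.equal(back, mems)        # exact round trip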
@@ -245,7 +301,7 @@ class MemoryAsContextTransformer(Module):
 
         x, inverse_segment = pad_and_segment_with_inverse(x, total_segment_len)
 
-        x = x[:, self.num_longterm_mem_tokens:]
+        x = x[:, num_longterm_mem_tokens:]
 
         x = inverse_segment(x)
@@ -253,4 +309,9 @@ class MemoryAsContextTransformer(Module):
 
         x = self.norm(x)
 
-        return self.to_logits(x)
+        logits = self.to_logits(x)
+
+        if not return_loss:
+            return logits
+
+        return F.cross_entropy(rearrange(logits, 'b n l -> b l n'), labels)
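With these changes `forward` doubles as a training objective: `return_loss = True` shifts the input ids by one position for next-token prediction, and the final `rearrange` moves the logits into the `(batch, classes, seq)` layout that `F.cross_entropy` expects. A hedged end-to-end sketch, reusing the hypothetical `model` from the instantiation example above:

    import torch

    ids = torch.randint(0, 256, (1, 65))    # one extra token to supply the shifted labels

    loss = model(ids, return_loss = True)   # scalar cross entropy over next-token targets
    loss.backward()

    logits = model(ids[:, :-1])             # default path: logits of shape (batch, seq, num_tokens)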
--- a/titans_pytorch-0.0.31.dist-info/METADATA
+++ b/titans_pytorch-0.0.34.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.0.31
+Version: 0.0.34
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
--- a/titans_pytorch-0.0.31.dist-info/RECORD
+++ b/titans_pytorch-0.0.34.dist-info/RECORD
@@ -1,9 +1,9 @@
 titans_pytorch/__init__.py,sha256=nB0873FZ_OyCda3qFeWTdpO4LbbrXDEvtAefVLzh6g0,71
 titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/mac_transformer.py,sha256=bq5RbCgA0GWLFHTrDTIKUSQhkkuCkdjEykOwjfHDs0M,6747
+titans_pytorch/mac_transformer.py,sha256=FGShQHD-dQQdQKKzvNS_jTC_FcikdqO_s3ZKOKfr_9E,8502
 titans_pytorch/titans.py,sha256=Kx_tl_QkeDccvkMwPZ0xQ_saYjZfbKzDNPTTSHNWYcA,14304
 titans_pytorch/titans_attn_memory.py,sha256=Rwx_-riGeISBefZg5Kjic8jzmmRRys-u93D2Kgb7Mos,12691
-titans_pytorch-0.0.31.dist-info/METADATA,sha256=bN1fVL2S_vML1oqLIA92tvBhkVvnpQN11fU4e1QVI4s,3938
-titans_pytorch-0.0.31.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.0.31.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.0.31.dist-info/RECORD,,
+titans_pytorch-0.0.34.dist-info/METADATA,sha256=CNqv_jMqk7yj15IpDn2O3jBdVe4wtrSVkht7mk0wW_E,3938
+titans_pytorch-0.0.34.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.0.34.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.0.34.dist-info/RECORD,,