titans-pytorch 0.0.43__py3-none-any.whl → 0.0.45__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
titans_pytorch/mac_transformer.py CHANGED
@@ -224,7 +224,7 @@ class MemoryAsContextTransformer(Module):
         self.layers = ModuleList([])

         self.neural_mem_layers = ModuleList([])
-        neural_memory_segment_len = default(neural_memory_segment_len, num_longterm_mem_tokens + segment_len)
+        self.neural_memory_segment_len = default(neural_memory_segment_len, num_longterm_mem_tokens + segment_len)

         layers = tuple(range(1, depth + 1))
@@ -245,7 +245,7 @@ class MemoryAsContextTransformer(Module):

             mem = NeuralMemory(
                 dim = dim,
-                chunk_size = neural_memory_segment_len,
+                chunk_size = self.neural_memory_segment_len,
                 **neural_memory_kwargs
             )
@@ -287,10 +287,7 @@ class MemoryAsContextTransformer(Module):

         # math

-        batch, seq_len, segment_len, num_longterm_mem_tokens = *x.shape, self.segment_len, self.num_longterm_mem_tokens
-
-        windows = ceil(seq_len / segment_len)
-        total_segment_len = segment_len + num_longterm_mem_tokens
+        batch, seq_len, neural_mem_segment_len, segment_len, num_longterm_mem_tokens = *x.shape, self.neural_memory_segment_len, self.segment_len, self.num_longterm_mem_tokens

         # token embedding
@@ -305,11 +302,16 @@ class MemoryAsContextTransformer(Module):

         x = inverse_segment(x)

+        seq_len_with_mem = x.shape[-2]
+
         # apply axial positional embedding
         # so intra and inter segment can be more easily discerned by the network

-        pos_emb = self.axial_pos_emb((windows, total_segment_len), flatten = True)
-        x = x + pos_emb[:x.shape[-2]]
+        neural_mem_windows = ceil(seq_len_with_mem / neural_mem_segment_len)
+
+        pos_emb = self.axial_pos_emb((neural_mem_windows, neural_mem_segment_len), flatten = True)
+
+        pos_emb = pos_emb[:seq_len_with_mem]

         # value residual
@@ -322,7 +324,7 @@ class MemoryAsContextTransformer(Module):
         for (attn, ff), maybe_neural_mem in zip(self.layers, self.neural_mem_layers):

             if exists(maybe_neural_mem):
-                x = maybe_neural_mem(x)
+                x = maybe_neural_mem(x, pos_emb = pos_emb)

             x, values = attn(x, value_residual = value_residual)
@@ -334,7 +336,7 @@ class MemoryAsContextTransformer(Module):

         # excise out the memories

-        x, inverse_segment = pad_and_segment_with_inverse(x, total_segment_len)
+        x, inverse_segment = pad_and_segment_with_inverse(x, segment_len + num_longterm_mem_tokens)

         x, _ = unpack(x, mem_ps, 'b * d')
titans_pytorch/titans.py CHANGED
@@ -484,10 +484,14 @@ class NeuralMemory(Module):
         seq,
         store_seq = None,
         past_state: tuple[dict[str, Tensor], dict[str, Tensor]] | None = None,
-        return_next_memories = False
+        return_next_memories = False,
+        pos_emb: Tensor | None = None
     ):
         batch, seq_len = seq.shape[:2]

+        if exists(pos_emb):
+            seq = seq + pos_emb
+
         if seq_len < self.chunk_size:
             return self.init_empty_memory_embed(batch, seq_len)
{titans_pytorch-0.0.43.dist-info → titans_pytorch-0.0.45.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.0.43
+Version: 0.0.45
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
titans_pytorch-0.0.45.dist-info/RECORD ADDED
@@ -0,0 +1,8 @@
+titans_pytorch/__init__.py,sha256=wnv_Cjdjqh_h5IqLkQ8xrTtA2K663ITEn-1JeeHofTo,150
+titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+titans_pytorch/mac_transformer.py,sha256=SFB7sXDt1bYpwt_PVrXM0-1vXKEemBTAfnfboU66A7M,9586
+titans_pytorch/titans.py,sha256=7LZIbaavC0bk85UBPzNzZP6YxKeFb0ujZ9k5IU048aI,15360
+titans_pytorch-0.0.45.dist-info/METADATA,sha256=EqrDXchEvzFbz1BqSdAB8HkPMjRY3KYyBSu16hbKTUs,4210
+titans_pytorch-0.0.45.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.0.45.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.0.45.dist-info/RECORD,,
titans_pytorch-0.0.43.dist-info/RECORD DELETED
@@ -1,8 +0,0 @@
-titans_pytorch/__init__.py,sha256=wnv_Cjdjqh_h5IqLkQ8xrTtA2K663ITEn-1JeeHofTo,150
-titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/mac_transformer.py,sha256=kSdfWGWwEk6d0lbb0WLVKQwdmG8LAzDg36QZm7aIio0,9451
-titans_pytorch/titans.py,sha256=qxQ8pZCz8GEDhKeJMEaeAEzH66GAGVBNaRdNam_-czg,15260
-titans_pytorch-0.0.43.dist-info/METADATA,sha256=3Rlt_5CIeDUkYEK5tcLiWTseWv48gg4OH5vMoSVNS2w,4210
-titans_pytorch-0.0.43.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.0.43.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.0.43.dist-info/RECORD,,