titans-pytorch 0.0.45__py3-none-any.whl → 0.0.46__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -311,7 +311,7 @@ class MemoryAsContextTransformer(Module):
311
311
 
312
312
  pos_emb = self.axial_pos_emb((neural_mem_windows, neural_mem_segment_len), flatten = True)
313
313
 
314
- pos_emb = pos_emb[:seq_len_with_mem]
314
+ x = x + pos_emb[:seq_len_with_mem]
315
315
 
316
316
  # value residual
317
317
 
@@ -324,7 +324,7 @@ class MemoryAsContextTransformer(Module):
324
324
  for (attn, ff), maybe_neural_mem in zip(self.layers, self.neural_mem_layers):
325
325
 
326
326
  if exists(maybe_neural_mem):
327
- x = maybe_neural_mem(x, pos_emb = pos_emb)
327
+ x = maybe_neural_mem(x)
328
328
 
329
329
  x, values = attn(x, value_residual = value_residual)
330
330
 
titans_pytorch/titans.py CHANGED
@@ -484,14 +484,10 @@ class NeuralMemory(Module):
484
484
  seq,
485
485
  store_seq = None,
486
486
  past_state: tuple[dict[str, Tensor], dict[str, Tensor]] | None = None,
487
- return_next_memories = False,
488
- pos_emb: Tensor | None = None
487
+ return_next_memories = False
489
488
  ):
490
489
  batch, seq_len = seq.shape[:2]
491
490
 
492
- if exists(pos_emb):
493
- seq = seq + pos_emb
494
-
495
491
  if seq_len < self.chunk_size:
496
492
  return self.init_empty_memory_embed(batch, seq_len)
497
493
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: titans-pytorch
3
- Version: 0.0.45
3
+ Version: 0.0.46
4
4
  Summary: Titans
5
5
  Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
6
6
  Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -0,0 +1,8 @@
1
+ titans_pytorch/__init__.py,sha256=wnv_Cjdjqh_h5IqLkQ8xrTtA2K663ITEn-1JeeHofTo,150
2
+ titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
3
+ titans_pytorch/mac_transformer.py,sha256=mF8PYAjeAjLas1gkYybgzZX1AVK82A_ps_LY00ofYYs,9565
4
+ titans_pytorch/titans.py,sha256=qxQ8pZCz8GEDhKeJMEaeAEzH66GAGVBNaRdNam_-czg,15260
5
+ titans_pytorch-0.0.46.dist-info/METADATA,sha256=Gg1-_Mmp9u_sJYEvaRt5GzKhhJoTNHjBL3efjSSDLL0,4210
6
+ titans_pytorch-0.0.46.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
7
+ titans_pytorch-0.0.46.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
8
+ titans_pytorch-0.0.46.dist-info/RECORD,,
@@ -1,8 +0,0 @@
1
- titans_pytorch/__init__.py,sha256=wnv_Cjdjqh_h5IqLkQ8xrTtA2K663ITEn-1JeeHofTo,150
2
- titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
3
- titans_pytorch/mac_transformer.py,sha256=SFB7sXDt1bYpwt_PVrXM0-1vXKEemBTAfnfboU66A7M,9586
4
- titans_pytorch/titans.py,sha256=7LZIbaavC0bk85UBPzNzZP6YxKeFb0ujZ9k5IU048aI,15360
5
- titans_pytorch-0.0.45.dist-info/METADATA,sha256=EqrDXchEvzFbz1BqSdAB8HkPMjRY3KYyBSu16hbKTUs,4210
6
- titans_pytorch-0.0.45.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
7
- titans_pytorch-0.0.45.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
8
- titans_pytorch-0.0.45.dist-info/RECORD,,