titans-pytorch 0.1.8__py3-none-any.whl → 0.1.10__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, exactly as they appear in their registry, and is provided for informational purposes only.
titans_pytorch/mac_transformer.py CHANGED
@@ -83,6 +83,14 @@ def identity(t):
 def round_up_multiple(seq, mult):
     return ceil(seq / mult) * mult
 
+def pack_with_inverse(t, pattern):
+    packed, packed_shape = pack(t, pattern)
+
+    def inverse(out, inv_pattern = None):
+        return unpack(out, packed_shape, default(inv_pattern, pattern))
+
+    return packed, inverse
+
 def pad_at_dim(t, pad, dim = -1, value = 0.):
     dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
     zeros = ((0, 0) * dims_from_right)
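
For reference, a self-contained sketch of the new helper in use. It repeats the `pack_with_inverse` definition from the hunk (with a stand-in `default`) so the snippet runs with only torch and einops installed; the tensor shapes are illustrative, not taken from the library.

    import torch
    from einops import pack, unpack

    def default(v, d):
        return v if v is not None else d

    def pack_with_inverse(t, pattern):
        packed, packed_shape = pack(t, pattern)

        def inverse(out, inv_pattern = None):
            return unpack(out, packed_shape, default(inv_pattern, pattern))

        return packed, inverse

    x = torch.randn(2, 128, 64)   # (batch, seq, dim)
    mems = torch.randn(2, 4, 64)  # (batch, memory tokens, dim)

    packed, inverse = pack_with_inverse((x, mems), 'b * d')  # (2, 132, 64)
    x_out, mems_out = inverse(packed)                        # shapes restored

The closure remembers both the packed shapes and the pattern, which is what lets the call sites below drop their `mem_ps` bookkeeping.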
@@ -576,7 +584,7 @@ class MemoryAsContextTransformer(Module):
         x, inverse_segment = pad_and_segment_with_inverse(x, segment_len)
 
         mems = repeat(self.longterm_mems, 'n d -> b n d', b = x.shape[0])
-        x, mem_ps = pack((x, mems), 'b * d')
+        x, inverse_pack_mems = pack_with_inverse((x, mems), 'b * d')
 
         x = inverse_segment(x)
 
@@ -585,11 +593,9 @@ class MemoryAsContextTransformer(Module):
         # apply axial positional embedding
         # so intra and inter segment can be more easily discerned by the network
 
-        neural_mem_windows = ceil(seq_len_with_mem / neural_mem_segment_len)
-
-        pos_emb = self.axial_pos_emb((neural_mem_windows, neural_mem_segment_len), flatten = True)
+        pos_emb = self.axial_pos_emb.forward_with_seq_len(seq_len_with_mem, (neural_mem_segment_len,))
 
-        x = x + pos_emb[:seq_len_with_mem]
+        x = x + pos_emb
 
         # prep flex attention
 
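
The two removed lines computed the number of neural-memory windows by hand, built a (windows, segment length) axial embedding, flattened it, and sliced it to the sequence length; `forward_with_seq_len` folds that bookkeeping into the embedding module, which is presumably why the axial-positional-embedding requirement is bumped to 0.3.6 below. A hedged sketch of the equivalent computation, using an additive two-axis embedding as a stand-in for the library's internals:

    # hedged sketch of what forward_with_seq_len plausibly computes; the
    # real library may combine the two axes differently
    from math import ceil
    import torch

    def axial_pos_emb_with_seq_len(seq_len, segment_len, dim):
        windows = ceil(seq_len / segment_len)       # round up to whole windows
        window_emb = torch.randn(windows, 1, dim)   # inter-segment axis
        pos_emb = torch.randn(1, segment_len, dim)  # intra-segment axis
        emb = (window_emb + pos_emb).reshape(windows * segment_len, dim)
        return emb[:seq_len]                        # trim the padded tail

Since the method already trims to the requested length, the caller's slice `pos_emb[:seq_len_with_mem]` becomes redundant, hence the simplified `x = x + pos_emb`.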
@@ -634,7 +640,7 @@ class MemoryAsContextTransformer(Module):
 
         x, inverse_segment = pad_and_segment_with_inverse(x, segment_len + num_longterm_mem_tokens)
 
-        x, _ = unpack(x, mem_ps, 'b * d')
+        x, _ = inverse_pack_mems(x)
 
         x = inverse_segment(x)
 
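A runnable miniature of this pack/process/unpack round trip, assuming the `pack_with_inverse` definition from the sketch above; `blocks` is a placeholder for the attention and feedforward layers, not the library's actual stack.

    import torch
    from torch import nn
    from einops import repeat

    batch, seq_len, num_mem, dim = 2, 128, 4, 64

    longterm_mems = nn.Parameter(torch.randn(num_mem, dim))
    blocks = nn.Identity()  # placeholder for the transformer layers

    x = torch.randn(batch, seq_len, dim)
    mems = repeat(longterm_mems, 'n d -> b n d', b = batch)

    x, inverse_pack_mems = pack_with_inverse((x, mems), 'b * d')  # (2, 132, 64)
    x = blocks(x)                # memory tokens attend alongside the sequence
    x, _ = inverse_pack_mems(x)  # drop the memory tokens again: (2, 128, 64)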
titans_pytorch/titans.py CHANGED
@@ -391,6 +391,8 @@ class NeuralMemory(Module):
 
         # whether to use averaging of chunks, or attention pooling
 
+        assert not (attn_pool_chunks and chunk_size == 1), '`attn_pool_chunks` cannot be set to True if `chunk_size` is set to 1'
+
         if not attn_pool_chunks:
             chunk_reduce_module = Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size)
         else:
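
One way to see why the new guard exists: with `chunk_size = 1` every chunk has length one, so even the mean branch's `Reduce` is the identity, and attention pooling over a single element could only add parameters without effect. A quick check using the same pattern as the mean branch:

    import torch
    from einops.layers.torch import Reduce

    chunk_size = 1
    reduce = Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size)

    x = torch.randn(2, 8, 16)
    assert torch.allclose(reduce(x), x)  # pooling length-1 chunks is a no-op

The assert rejects the misconfiguration at construction time instead of silently training a useless attention-pooling module.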
titans_pytorch-0.1.8.dist-info/METADATA → titans_pytorch-0.1.10.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.1.8
+Version: 0.1.10
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -35,7 +35,7 @@ Classifier: Programming Language :: Python :: 3.9
 Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
 Requires-Python: >=3.9
 Requires-Dist: accelerated-scan>=0.2.0
-Requires-Dist: axial-positional-embedding>=0.3.5
+Requires-Dist: axial-positional-embedding>=0.3.6
 Requires-Dist: einops>=0.8.0
 Requires-Dist: einx>=0.3.0
 Requires-Dist: hyper-connections>=0.1.8
@@ -62,7 +62,7 @@ Unofficial implementation of [Titans](https://arxiv.org/abs/2501.00663) in Pytorch
 
 ## Appreciation
 
-- [@sentialx](https://github.com/sentialx) for sharing his early experimental results with me
+- [Eryk](https://github.com/sentialx) for sharing his early experimental results with me, positive for 2 layer MLP
 
 ## Install
 
titans_pytorch-0.1.10.dist-info/RECORD ADDED
@@ -0,0 +1,8 @@
+titans_pytorch/__init__.py,sha256=u0tta_KqhOdfzCEDWT9P4_jejJEK2q1XxhsEzB5MnQU,223
+titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+titans_pytorch/mac_transformer.py,sha256=zxknstaI_Uz47Y8WvZ3S7geJ-TNdqKV5Rvj0Jlw8njs,19271
+titans_pytorch/titans.py,sha256=gZvYk1j6aBMp0uE6l1a2GH_4ea9W2uXKytJb3CDPTlk,21162
+titans_pytorch-0.1.10.dist-info/METADATA,sha256=o2D4Zau9GLBZmsj2qzq7agWckPnBJhDtIeTj2cMgy7Q,4769
+titans_pytorch-0.1.10.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.1.10.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.1.10.dist-info/RECORD,,
titans_pytorch-0.1.8.dist-info/RECORD REMOVED
@@ -1,8 +0,0 @@
-titans_pytorch/__init__.py,sha256=u0tta_KqhOdfzCEDWT9P4_jejJEK2q1XxhsEzB5MnQU,223
-titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/mac_transformer.py,sha256=YowJzQC2p3fWgzFCe9jXrw1b3wWPKN-dcLs5sX-3r8Y,19123
-titans_pytorch/titans.py,sha256=qRUw-Lad_dkMqV7ASMNoGLgxYwGD-maAadetAd_qmc8,21031
-titans_pytorch-0.1.8.dist-info/METADATA,sha256=0-m6h7GERineU8N9_2cW6nCuXs96twFEwVYkHVuuuLM,4747
-titans_pytorch-0.1.8.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.1.8.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.1.8.dist-info/RECORD,,