titans-pytorch 0.0.38__py3-none-any.whl → 0.0.39__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -7,7 +7,7 @@ from torch import nn, cat
7
7
  import torch.nn.functional as F
8
8
  from torch.nn import Module, ModuleList, Linear
9
9
 
10
- from einops import repeat, rearrange, pack, unpack
10
+ from einops import einsum, repeat, rearrange, pack, unpack
11
11
  from einops.layers.torch import Rearrange
12
12
 
13
13
  from hyper_connections import get_init_and_expand_reduce_stream_functions
@@ -16,6 +16,7 @@ from hyper_connections import get_init_and_expand_reduce_stream_functions
16
16
 
17
17
  from axial_positional_embedding import ContinuousAxialPositionalEmbedding
18
18
  from rotary_embedding_torch import RotaryEmbedding
19
+ from x_transformers.attend import Attend
19
20
 
20
21
  # proposed neural memory
21
22
 
@@ -93,6 +94,7 @@ class SegmentedAttention(Module):
93
94
  num_longterm_mem_tokens = 0,
94
95
  dim_head = 64,
95
96
  heads = 8,
97
+ attend_kwargs: dict = dict()
96
98
  ):
97
99
  super().__init__()
98
100
  self.norm = nn.RMSNorm(dim)
@@ -101,6 +103,8 @@ class SegmentedAttention(Module):
101
103
 
102
104
  self.rotary_emb = RotaryEmbedding(dim_head)
103
105
 
106
+ self.attend = Attend(causal = True, **attend_kwargs)
107
+
104
108
  self.to_qkv = LinearNoBias(dim, dim_inner * 3)
105
109
  self.to_out = LinearNoBias(dim_inner, dim)
106
110
 
@@ -145,9 +149,9 @@ class SegmentedAttention(Module):
145
149
  k = cat((pmk, k), dim = -2)
146
150
  v = cat((pmv, v), dim = -2)
147
151
 
148
- # sdpa
152
+ # attention
149
153
 
150
- out = F.scaled_dot_product_attention(q, k, v, is_causal = True)
154
+ out, _ = self.attend(q, k, v)
151
155
 
152
156
  out = self.merge_heads(out)
153
157
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: titans-pytorch
3
- Version: 0.0.38
3
+ Version: 0.0.39
4
4
  Summary: Titans
5
5
  Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
6
6
  Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -43,6 +43,7 @@ Requires-Dist: ninja
43
43
  Requires-Dist: rotary-embedding-torch
44
44
  Requires-Dist: tensordict
45
45
  Requires-Dist: torch>=2.2
46
+ Requires-Dist: x-transformers
46
47
  Provides-Extra: examples
47
48
  Requires-Dist: local-attention>=1.10.1; extra == 'examples'
48
49
  Requires-Dist: taylor-series-linear-attention; extra == 'examples'
@@ -1,9 +1,9 @@
1
1
  titans_pytorch/__init__.py,sha256=F6pV8BamKCsbJFVo5x2hw69vzfJNLy54SwIKIueMdp4,142
2
2
  titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
3
- titans_pytorch/mac_transformer.py,sha256=5koIfEulJ841FNrs6URZfW2dp9LMuHzMkaySDrlbuP0,8393
3
+ titans_pytorch/mac_transformer.py,sha256=h58sHfufxMnSXZXyWuW-KBwzq8xwBYmFjU2XtOjUixk,8512
4
4
  titans_pytorch/titans.py,sha256=bv2Ceq-_4nNb5FNx4hLd2inC93m5MmJxO2-Mbf6PKK4,14378
5
5
  titans_pytorch/titans_attn_memory.py,sha256=Rwx_-riGeISBefZg5Kjic8jzmmRRys-u93D2Kgb7Mos,12691
6
- titans_pytorch-0.0.38.dist-info/METADATA,sha256=L6tEQTEOXCeAU_BuRLbwUO0-gmnbJE-WQNAZ83BNCWA,3938
7
- titans_pytorch-0.0.38.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
8
- titans_pytorch-0.0.38.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
9
- titans_pytorch-0.0.38.dist-info/RECORD,,
6
+ titans_pytorch-0.0.39.dist-info/METADATA,sha256=3KD2hmJ-uOyQ87Z3VB6JfaKtDcLBnoKA8037DpzJuPE,3968
7
+ titans_pytorch-0.0.39.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
8
+ titans_pytorch-0.0.39.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
9
+ titans_pytorch-0.0.39.dist-info/RECORD,,