titans-pytorch 0.0.38__py3-none-any.whl → 0.0.39__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.

Potentially problematic release.


This version of titans-pytorch might be problematic. Click here for more details.

@@ -7,7 +7,7 @@ from torch import nn, cat
7
7
  import torch.nn.functional as F
8
8
  from torch.nn import Module, ModuleList, Linear
9
9
 
10
- from einops import repeat, rearrange, pack, unpack
10
+ from einops import einsum, repeat, rearrange, pack, unpack
11
11
  from einops.layers.torch import Rearrange
12
12
 
13
13
  from hyper_connections import get_init_and_expand_reduce_stream_functions
@@ -16,6 +16,7 @@ from hyper_connections import get_init_and_expand_reduce_stream_functions
16
16
 
17
17
  from axial_positional_embedding import ContinuousAxialPositionalEmbedding
18
18
  from rotary_embedding_torch import RotaryEmbedding
19
+ from x_transformers.attend import Attend
19
20
 
20
21
  # proposed neural memory
21
22
 
@@ -93,6 +94,7 @@ class SegmentedAttention(Module):
93
94
  num_longterm_mem_tokens = 0,
94
95
  dim_head = 64,
95
96
  heads = 8,
97
+ attend_kwargs: dict = dict()
96
98
  ):
97
99
  super().__init__()
98
100
  self.norm = nn.RMSNorm(dim)
@@ -101,6 +103,8 @@ class SegmentedAttention(Module):
101
103
 
102
104
  self.rotary_emb = RotaryEmbedding(dim_head)
103
105
 
106
+ self.attend = Attend(causal = True, **attend_kwargs)
107
+
104
108
  self.to_qkv = LinearNoBias(dim, dim_inner * 3)
105
109
  self.to_out = LinearNoBias(dim_inner, dim)
106
110
 
@@ -145,9 +149,9 @@ class SegmentedAttention(Module):
145
149
  k = cat((pmk, k), dim = -2)
146
150
  v = cat((pmv, v), dim = -2)
147
151
 
148
- # sdpa
152
+ # attention
149
153
 
150
- out = F.scaled_dot_product_attention(q, k, v, is_causal = True)
154
+ out, _ = self.attend(q, k, v)
151
155
 
152
156
  out = self.merge_heads(out)
153
157
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: titans-pytorch
3
- Version: 0.0.38
3
+ Version: 0.0.39
4
4
  Summary: Titans
5
5
  Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
6
6
  Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -43,6 +43,7 @@ Requires-Dist: ninja
43
43
  Requires-Dist: rotary-embedding-torch
44
44
  Requires-Dist: tensordict
45
45
  Requires-Dist: torch>=2.2
46
+ Requires-Dist: x-transformers
46
47
  Provides-Extra: examples
47
48
  Requires-Dist: local-attention>=1.10.1; extra == 'examples'
48
49
  Requires-Dist: taylor-series-linear-attention; extra == 'examples'
@@ -1,9 +1,9 @@
1
1
  titans_pytorch/__init__.py,sha256=F6pV8BamKCsbJFVo5x2hw69vzfJNLy54SwIKIueMdp4,142
2
2
  titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
3
- titans_pytorch/mac_transformer.py,sha256=5koIfEulJ841FNrs6URZfW2dp9LMuHzMkaySDrlbuP0,8393
3
+ titans_pytorch/mac_transformer.py,sha256=h58sHfufxMnSXZXyWuW-KBwzq8xwBYmFjU2XtOjUixk,8512
4
4
  titans_pytorch/titans.py,sha256=bv2Ceq-_4nNb5FNx4hLd2inC93m5MmJxO2-Mbf6PKK4,14378
5
5
  titans_pytorch/titans_attn_memory.py,sha256=Rwx_-riGeISBefZg5Kjic8jzmmRRys-u93D2Kgb7Mos,12691
6
- titans_pytorch-0.0.38.dist-info/METADATA,sha256=L6tEQTEOXCeAU_BuRLbwUO0-gmnbJE-WQNAZ83BNCWA,3938
7
- titans_pytorch-0.0.38.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
8
- titans_pytorch-0.0.38.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
9
- titans_pytorch-0.0.38.dist-info/RECORD,,
6
+ titans_pytorch-0.0.39.dist-info/METADATA,sha256=3KD2hmJ-uOyQ87Z3VB6JfaKtDcLBnoKA8037DpzJuPE,3968
7
+ titans_pytorch-0.0.39.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
8
+ titans_pytorch-0.0.39.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
9
+ titans_pytorch-0.0.39.dist-info/RECORD,,