titans_pytorch-0.0.27-py3-none-any.whl → titans_pytorch-0.0.29-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
--- titans_pytorch/mac_transformer.py
+++ titans_pytorch/mac_transformer.py
@@ -1,5 +1,5 @@
  from __future__ import annotations
- import math
+ from math import ceil
  from functools import partial

  import torch
@@ -10,8 +10,14 @@ from torch.nn import Module, ModuleList, Linear
  from einops import repeat
  from einops.layers.torch import Rearrange

+
  from hyper_connections import get_init_and_expand_reduce_stream_functions

+ # absolute and relative positions
+
+ from axial_positional_embedding import ContinuousAxialPositionalEmbedding
+ from rotary_embedding_torch import RotaryEmbedding
+
  # constants

  LinearNoBias = partial(Linear, bias = False)
@@ -25,7 +31,7 @@ def default(v, d):
      return v if exists(v) else d

  def round_up_multiple(seq, mult):
-     return math.ceil(seq / mult) * mult
+     return ceil(seq / mult) * mult

  # feedforward and attention

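For reference, the helper touched in this hunk rounds a sequence length up to the nearest multiple of the segment length; the same ceil-based windowing reappears in `forward()` further down. A minimal sketch with an illustrative value:

```python
# sketch of the round_up_multiple helper from the hunk above;
# the numbers are illustrative only
from math import ceil

def round_up_multiple(seq, mult):
    return ceil(seq / mult) * mult

# e.g. a 100-token sequence with segment length 32 pads out to 4 full windows
assert round_up_multiple(100, 32) == 128
```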
@@ -58,6 +64,8 @@ class SegmentedAttention(Module):

          dim_inner = dim_head * heads

+         self.rotary_emb = RotaryEmbedding(dim_head)
+
          self.to_qkv = LinearNoBias(dim, dim_inner * 3)
          self.to_out = LinearNoBias(dim_inner, dim)

@@ -99,6 +107,12 @@ class SegmentedAttention(Module):

          pmk, pmv = tuple(repeat(t, 'h n d -> b h n d', b = seq.shape[0]) for t in self.persistent_memory)

+         # relative positions
+
+         q, k = self.rotary_emb.rotate_queries_with_cached_keys(q, k)
+
+         # persistent memory
+
          k = cat((pmk, k), dim = -2)
          v = cat((pmv, v), dim = -2)

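The ordering in this hunk is the interesting part: the rotary embedding is applied to the queries and keys before the persistent memory keys and values are concatenated, so the persistent tokens carry no positional information. A minimal standalone sketch of that ordering, using the same rotary-embedding-torch calls as the hunk (batch size, head count, and sequence lengths here are illustrative, not taken from the package):

```python
# sketch of the ordering introduced above: rotate q/k first, then prepend
# position-free persistent memory keys; shapes are illustrative
import torch
from rotary_embedding_torch import RotaryEmbedding

dim_head, heads = 64, 8
rotary_emb = RotaryEmbedding(dim_head)

q = torch.randn(1, heads, 128, dim_head)  # (batch, heads, seq, dim_head)
k = torch.randn(1, heads, 128, dim_head)
pmk = torch.randn(1, heads, 4, dim_head)  # 4 persistent memory keys, never rotated

q, k = rotary_emb.rotate_queries_with_cached_keys(q, k)  # relative positions via rotation
k = torch.cat((pmk, k), dim = -2)                        # persistent memory prepended afterwards
```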
@@ -133,6 +147,9 @@ class MemoryAsContextTransformer(Module):
      ):
          super().__init__()

+         self.segment_len = segment_len
+         self.axial_pos_emb = ContinuousAxialPositionalEmbedding(dim = dim, num_axial_dims = 2)
+
          self.token_emb = nn.Embedding(num_tokens, dim)

          init_hyper_conn, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(num_residual_streams, disable = num_residual_streams == 1)
@@ -160,9 +177,19 @@ class MemoryAsContextTransformer(Module):
          self.to_logits = LinearNoBias(dim, num_tokens)

      def forward(self, x):
+         seq_len, segment_len = x.shape[-1], self.segment_len
+         windows = ceil(seq_len / segment_len)

          x = self.token_emb(x)

+         # apply axial positional embedding
+         # so intra and inter segment can be more easily discerned by the network
+
+         pos_emb = self.axial_pos_emb((windows, segment_len), flatten = True)
+         x = x + pos_emb[:seq_len]
+
+         # expand and reduce streams for hyper connections
+
          x = self.expand_streams(x)

          for attn, ff in self.layers:
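The forward changes above give every token a two-axis position, (window index, position within segment), so the same intra-segment offsets repeat across windows and the network can distinguish intra- from inter-segment structure. A minimal standalone sketch of that factorization, mirroring the calls in the hunk (dim, segment length, and sequence length are illustrative values):

```python
# sketch of the axial position scheme added above: positions factorize into
# (window, intra-segment) coordinates, then flatten; values are illustrative
from math import ceil
import torch
from axial_positional_embedding import ContinuousAxialPositionalEmbedding

dim, segment_len, seq_len = 384, 32, 100
axial_pos_emb = ContinuousAxialPositionalEmbedding(dim = dim, num_axial_dims = 2)

windows = ceil(seq_len / segment_len)  # 100 tokens -> 4 windows of up to 32 each

x = torch.randn(2, seq_len, dim)  # stand-in for token embeddings
pos_emb = axial_pos_emb((windows, segment_len), flatten = True)  # (windows * segment_len, dim)
x = x + pos_emb[:seq_len]  # trim the padding from the final partial window
```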
--- titans_pytorch-0.0.27.dist-info/METADATA
+++ titans_pytorch-0.0.29.dist-info/METADATA
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: titans-pytorch
- Version: 0.0.27
+ Version: 0.0.29
  Summary: Titans
  Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
  Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -35,10 +35,12 @@ Classifier: Programming Language :: Python :: 3.9
  Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
  Requires-Python: >=3.9
  Requires-Dist: accelerated-scan>=0.2.0
+ Requires-Dist: axial-positional-embedding>=0.3.5
  Requires-Dist: einops>=0.8.0
  Requires-Dist: einx>=0.3.0
  Requires-Dist: hyper-connections>=0.1.8
  Requires-Dist: ninja
+ Requires-Dist: rotary-embedding-torch
  Requires-Dist: tensordict
  Requires-Dist: torch>=2.2
  Provides-Extra: examples
--- titans_pytorch-0.0.27.dist-info/RECORD
+++ titans_pytorch-0.0.29.dist-info/RECORD
@@ -1,9 +1,9 @@
  titans_pytorch/__init__.py,sha256=nB0873FZ_OyCda3qFeWTdpO4LbbrXDEvtAefVLzh6g0,71
  titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
- titans_pytorch/mac_transformer.py,sha256=4xUSfGDdVsR-WmeXX7yXoFfybROvNCjOxL_EHDJ_Wlk,4681
+ titans_pytorch/mac_transformer.py,sha256=JBrJah7gfQDPizYRcBvpUKinrd2I9KMB997f3RIR8TA,5568
  titans_pytorch/titans.py,sha256=Kx_tl_QkeDccvkMwPZ0xQ_saYjZfbKzDNPTTSHNWYcA,14304
  titans_pytorch/titans_attn_memory.py,sha256=Rwx_-riGeISBefZg5Kjic8jzmmRRys-u93D2Kgb7Mos,12691
- titans_pytorch-0.0.27.dist-info/METADATA,sha256=qJp1IDbphEUfW7EyNvQ7RfmHuvB7SH5h_tlnCVwV4EY,3851
- titans_pytorch-0.0.27.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
- titans_pytorch-0.0.27.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
- titans_pytorch-0.0.27.dist-info/RECORD,,
+ titans_pytorch-0.0.29.dist-info/METADATA,sha256=EhS4E9SAoqzDa0PIjZpQmUSYAo5IS-XePofWlZZnIS0,3938
+ titans_pytorch-0.0.29.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+ titans_pytorch-0.0.29.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+ titans_pytorch-0.0.29.dist-info/RECORD,,