titans-pytorch 0.0.27__py3-none-any.whl → 0.0.29__py3-none-any.whl
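Version 0.0.29 touches one source file, titans_pytorch/mac_transformer.py: SegmentedAttention gains rotary embeddings for relative positions, MemoryAsContextTransformer gains a continuous axial positional embedding for absolute (window, segment) positions, and the METADATA and RECORD entries are updated accordingly, including two new runtime dependencies.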


titans_pytorch/mac_transformer.py

@@ -1,5 +1,5 @@
 from __future__ import annotations
-import math
+from math import ceil
 from functools import partial
 
 import torch
@@ -10,8 +10,14 @@ from torch.nn import Module, ModuleList, Linear
 from einops import repeat
 from einops.layers.torch import Rearrange
 
+
 from hyper_connections import get_init_and_expand_reduce_stream_functions
 
+# absolute and relative positions
+
+from axial_positional_embedding import ContinuousAxialPositionalEmbedding
+from rotary_embedding_torch import RotaryEmbedding
+
 # constants
 
 LinearNoBias = partial(Linear, bias = False)
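The two new imports divide positional information as the added comment suggests: RotaryEmbedding (from rotary-embedding-torch) injects relative positions inside each attention call, while ContinuousAxialPositionalEmbedding (from axial-positional-embedding) supplies an absolute embedding factored over two axes, window index and position within a segment. The later hunks wire these into SegmentedAttention and MemoryAsContextTransformer respectively.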
@@ -25,7 +31,7 @@ def default(v, d):
     return v if exists(v) else d
 
 def round_up_multiple(seq, mult):
-    return math.ceil(seq / mult) * mult
+    return ceil(seq / mult) * mult
 
 # feedforward and attention
 
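The change above is behavior-neutral, swapping math.ceil for the directly imported ceil; round_up_multiple still rounds a sequence length up to the nearest multiple. A quick worked example (the numbers are illustrative, not from the diff):

from math import ceil

def round_up_multiple(seq, mult):
    return ceil(seq / mult) * mult

print(round_up_multiple(10, 4))  # 12: ceil(10 / 4) = 3, then 3 * 4
print(round_up_multiple(8, 4))   # 8: already a multiple, returned unchanged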
@@ -58,6 +64,8 @@ class SegmentedAttention(Module):
 
         dim_inner = dim_head * heads
 
+        self.rotary_emb = RotaryEmbedding(dim_head)
+
         self.to_qkv = LinearNoBias(dim, dim_inner * 3)
         self.to_out = LinearNoBias(dim_inner, dim)
 
@@ -99,6 +107,12 @@ class SegmentedAttention(Module):
 
         pmk, pmv = tuple(repeat(t, 'h n d -> b h n d', b = seq.shape[0]) for t in self.persistent_memory)
 
+        # relative positions
+
+        q, k = self.rotary_emb.rotate_queries_with_cached_keys(q, k)
+
+        # persistent memory
+
         k = cat((pmk, k), dim = -2)
         v = cat((pmv, v), dim = -2)
 
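The ordering in the hunk above matters: rotary positions are applied to the real queries and keys first, and the learned persistent memory keys and values are concatenated afterwards, so the memory tokens stay position-free. A minimal sketch of that ordering using rotary-embedding-torch (the tensor sizes and the pmk tensor are illustrative assumptions, not values from the diff):

import torch
from rotary_embedding_torch import RotaryEmbedding

rotary = RotaryEmbedding(64)         # rotary over the head dimension, as in RotaryEmbedding(dim_head)

q = torch.randn(1, 8, 1024, 64)      # (batch, heads, seq, dim_head) - illustrative shape
k = torch.randn(1, 8, 1024, 64)

# rotate real queries and keys; with equal q/k lengths this matches plain rotary application
q, k = rotary.rotate_queries_with_cached_keys(q, k)

pmk = torch.randn(1, 8, 16, 64)      # hypothetical persistent memory keys
k = torch.cat((pmk, k), dim = -2)    # memory prepended after rotation, so it carries no position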
@@ -133,6 +147,9 @@ class MemoryAsContextTransformer(Module):
     ):
         super().__init__()
 
+        self.segment_len = segment_len
+        self.axial_pos_emb = ContinuousAxialPositionalEmbedding(dim = dim, num_axial_dims = 2)
+
         self.token_emb = nn.Embedding(num_tokens, dim)
 
         init_hyper_conn, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(num_residual_streams, disable = num_residual_streams == 1)
@@ -160,9 +177,19 @@ class MemoryAsContextTransformer(Module):
         self.to_logits = LinearNoBias(dim, num_tokens)
 
     def forward(self, x):
+        seq_len, segment_len = x.shape[-1], self.segment_len
+        windows = ceil(seq_len / segment_len)
 
         x = self.token_emb(x)
 
+        # apply axial positional embedding
+        # so intra and inter segment can be more easily discerned by the network
+
+        pos_emb = self.axial_pos_emb((windows, segment_len), flatten = True)
+        x = x + pos_emb[:seq_len]
+
+        # expand and reduce streams for hyper connections
+
         x = self.expand_streams(x)
 
         for attn, ff in self.layers:
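To see what the new forward logic does with shapes, here is a sketch using the same calls as the diff; dim, seq_len, and segment_len are illustrative choices, and the flattened output shape of (windows * segment_len, dim) is inferred from the pos_emb[:seq_len] slice and the broadcast add:

import torch
from math import ceil
from axial_positional_embedding import ContinuousAxialPositionalEmbedding

dim, seq_len, segment_len = 512, 1000, 128

axial_pos_emb = ContinuousAxialPositionalEmbedding(dim = dim, num_axial_dims = 2)

windows = ceil(seq_len / segment_len)      # ceil(1000 / 128) = 8 windows

# one embedding per (window index, offset within segment) pair,
# flattened to (windows * segment_len, dim) = (1024, 512)
pos_emb = axial_pos_emb((windows, segment_len), flatten = True)

x = torch.randn(2, seq_len, dim)           # stand-in for token embeddings
x = x + pos_emb[:seq_len]                  # slice off the 24 positions past seq_len

Factoring the absolute position into (window, offset) axes is what lets the network discern intra- versus inter-segment structure, per the comment in the diff.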
{titans_pytorch-0.0.27.dist-info → titans_pytorch-0.0.29.dist-info}/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.0.27
+Version: 0.0.29
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -35,10 +35,12 @@ Classifier: Programming Language :: Python :: 3.9
 Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
 Requires-Python: >=3.9
 Requires-Dist: accelerated-scan>=0.2.0
+Requires-Dist: axial-positional-embedding>=0.3.5
 Requires-Dist: einops>=0.8.0
 Requires-Dist: einx>=0.3.0
 Requires-Dist: hyper-connections>=0.1.8
 Requires-Dist: ninja
+Requires-Dist: rotary-embedding-torch
 Requires-Dist: tensordict
 Requires-Dist: torch>=2.2
 Provides-Extra: examples
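Both additions are regular runtime dependencies, so upgrading with pip install -U titans-pytorch resolves axial-positional-embedding>=0.3.5 and rotary-embedding-torch automatically; no manual install step is needed.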
{titans_pytorch-0.0.27.dist-info → titans_pytorch-0.0.29.dist-info}/RECORD

@@ -1,9 +1,9 @@
 titans_pytorch/__init__.py,sha256=nB0873FZ_OyCda3qFeWTdpO4LbbrXDEvtAefVLzh6g0,71
 titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/mac_transformer.py,sha256=4xUSfGDdVsR-WmeXX7yXoFfybROvNCjOxL_EHDJ_Wlk,4681
+titans_pytorch/mac_transformer.py,sha256=JBrJah7gfQDPizYRcBvpUKinrd2I9KMB997f3RIR8TA,5568
 titans_pytorch/titans.py,sha256=Kx_tl_QkeDccvkMwPZ0xQ_saYjZfbKzDNPTTSHNWYcA,14304
 titans_pytorch/titans_attn_memory.py,sha256=Rwx_-riGeISBefZg5Kjic8jzmmRRys-u93D2Kgb7Mos,12691
-titans_pytorch-0.0.27.dist-info/METADATA,sha256=qJp1IDbphEUfW7EyNvQ7RfmHuvB7SH5h_tlnCVwV4EY,3851
-titans_pytorch-0.0.27.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.0.27.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.0.27.dist-info/RECORD,,
+titans_pytorch-0.0.29.dist-info/METADATA,sha256=EhS4E9SAoqzDa0PIjZpQmUSYAo5IS-XePofWlZZnIS0,3938
+titans_pytorch-0.0.29.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.0.29.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.0.29.dist-info/RECORD,,