titans-pytorch 0.0.27__tar.gz → 0.0.29__tar.gz

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.0.27
+Version: 0.0.29
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -35,10 +35,12 @@ Classifier: Programming Language :: Python :: 3.9
 Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
 Requires-Python: >=3.9
 Requires-Dist: accelerated-scan>=0.2.0
+Requires-Dist: axial-positional-embedding>=0.3.5
 Requires-Dist: einops>=0.8.0
 Requires-Dist: einx>=0.3.0
 Requires-Dist: hyper-connections>=0.1.8
 Requires-Dist: ninja
+Requires-Dist: rotary-embedding-torch
 Requires-Dist: tensordict
 Requires-Dist: torch>=2.2
 Provides-Extra: examples
@@ -1,6 +1,6 @@
 [project]
 name = "titans-pytorch"
-version = "0.0.27"
+version = "0.0.29"
 description = "Titans"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -26,10 +26,12 @@ classifiers=[
 
 dependencies = [
     "accelerated-scan>=0.2.0",
+    "axial_positional_embedding>=0.3.5",
     "einx>=0.3.0",
     "einops>=0.8.0",
     "hyper-connections>=0.1.8",
     "Ninja",
+    "rotary-embedding-torch",
     "tensordict",
     "torch>=2.2",
 ]
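
The two new runtime dependencies, declared both in the package metadata above and in pyproject.toml here, back the positional-embedding changes further down in this diff: axial-positional-embedding supplies ContinuousAxialPositionalEmbedding for absolute positions over a (windows, segment_len) grid, and rotary-embedding-torch supplies RotaryEmbedding for relative positions inside attention. A minimal import-and-construct sketch, using only the constructor arguments visible in the hunks below; the dim and dim_head values are illustrative, not taken from the package:

    from axial_positional_embedding import ContinuousAxialPositionalEmbedding
    from rotary_embedding_torch import RotaryEmbedding

    dim, dim_head = 256, 64    # illustrative sizes; dim = 256 matches the new test below

    # absolute positions, factored over a 2D (windows, segment_len) grid
    axial_pos_emb = ContinuousAxialPositionalEmbedding(dim = dim, num_axial_dims = 2)

    # relative positions, applied to attention queries and keys per head dimension
    rotary_emb = RotaryEmbedding(dim_head)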
@@ -33,3 +33,19 @@ def test_titans_attn_memory():
     retrieved = mem(seq)
 
     assert seq.shape == retrieved.shape
+
+def test_mac():
+    from titans_pytorch.mac_transformer import MemoryAsContextTransformer
+
+    transformer = MemoryAsContextTransformer(
+        num_tokens = 256,
+        dim = 256,
+        depth = 2,
+        num_persist_mem_tokens = 16,
+        segment_len = 128,
+    )
+
+    x = torch.randint(0, 256, (1, 1023))
+
+    logits = transformer(x)
+    assert logits.shape == (1, 1023, 256)
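
The new test runs MemoryAsContextTransformer end to end on a 1023-token sequence, deliberately not a multiple of segment_len = 128, so the window rounding introduced in this release is exercised. As a rough usage sketch (not part of the package), the same forward pass can be wired into a next-token training step, assuming only what the test shows: token ids in, logits of shape (batch, seq_len, num_tokens) out:

    import torch
    import torch.nn.functional as F
    from titans_pytorch.mac_transformer import MemoryAsContextTransformer

    transformer = MemoryAsContextTransformer(
        num_tokens = 256,
        dim = 256,
        depth = 2,
        num_persist_mem_tokens = 16,
        segment_len = 128,
    )

    ids = torch.randint(0, 256, (1, 1024))

    logits = transformer(ids[:, :-1])    # (1, 1023, 256), as asserted in test_mac
    loss = F.cross_entropy(
        logits.transpose(1, 2),          # cross_entropy expects (batch, classes, positions)
        ids[:, 1:],                      # shifted targets for next-token prediction
    )
    loss.backward()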
@@ -1,5 +1,5 @@
 from __future__ import annotations
-import math
+from math import ceil
 from functools import partial
 
 import torch
@@ -10,8 +10,14 @@ from torch.nn import Module, ModuleList, Linear
 from einops import repeat
 from einops.layers.torch import Rearrange
 
+
 from hyper_connections import get_init_and_expand_reduce_stream_functions
 
+# absolute and relative positions
+
+from axial_positional_embedding import ContinuousAxialPositionalEmbedding
+from rotary_embedding_torch import RotaryEmbedding
+
 # constants
 
 LinearNoBias = partial(Linear, bias = False)
@@ -25,7 +31,7 @@ def default(v, d):
     return v if exists(v) else d
 
 def round_up_multiple(seq, mult):
-    return math.ceil(seq / mult) * mult
+    return ceil(seq / mult) * mult
 
 # feedforward and attention
 
@@ -58,6 +64,8 @@ class SegmentedAttention(Module):
 
         dim_inner = dim_head * heads
 
+        self.rotary_emb = RotaryEmbedding(dim_head)
+
         self.to_qkv = LinearNoBias(dim, dim_inner * 3)
         self.to_out = LinearNoBias(dim_inner, dim)
 
@@ -99,6 +107,12 @@ class SegmentedAttention(Module):
 
         pmk, pmv = tuple(repeat(t, 'h n d -> b h n d', b = seq.shape[0]) for t in self.persistent_memory)
 
+        # relative positions
+
+        q, k = self.rotary_emb.rotate_queries_with_cached_keys(q, k)
+
+        # persistent memory
+
         k = cat((pmk, k), dim = -2)
         v = cat((pmv, v), dim = -2)
 
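In SegmentedAttention, the rotary embedding is applied to the windowed queries and keys before the persistent memory keys and values are prepended, so the persistent memory tokens carry no positional information. A standalone sketch of that ordering with made-up tensor sizes (batch 1, 8 heads, one 128-token window, head dim 64); only RotaryEmbedding and rotate_queries_with_cached_keys come from the hunk itself:

    import torch
    from rotary_embedding_torch import RotaryEmbedding

    b, h, n, d = 1, 8, 128, 64                 # illustrative: one window of 128 tokens
    num_persist = 16

    q, k, v = (torch.randn(b, h, n, d) for _ in range(3))
    pmk = torch.randn(b, h, num_persist, d)    # persistent memory keys
    pmv = torch.randn(b, h, num_persist, d)    # persistent memory values

    rotary_emb = RotaryEmbedding(d)

    # rotate queries and keys first ...
    q, k = rotary_emb.rotate_queries_with_cached_keys(q, k)

    # ... then prepend the position-free persistent memory before attention
    k = torch.cat((pmk, k), dim = -2)
    v = torch.cat((pmv, v), dim = -2)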
 
@@ -133,6 +147,9 @@ class MemoryAsContextTransformer(Module):
     ):
         super().__init__()
 
+        self.segment_len = segment_len
+        self.axial_pos_emb = ContinuousAxialPositionalEmbedding(dim = dim, num_axial_dims = 2)
+
         self.token_emb = nn.Embedding(num_tokens, dim)
 
         init_hyper_conn, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(num_residual_streams, disable = num_residual_streams == 1)
@@ -160,9 +177,19 @@ class MemoryAsContextTransformer(Module):
         self.to_logits = LinearNoBias(dim, num_tokens)
 
     def forward(self, x):
+        seq_len, segment_len = x.shape[-1], self.segment_len
+        windows = ceil(seq_len / segment_len)
 
         x = self.token_emb(x)
 
+        # apply axial positional embedding
+        # so intra and inter segment can be more easily discerned by the network
+
+        pos_emb = self.axial_pos_emb((windows, segment_len), flatten = True)
+        x = x + pos_emb[:seq_len]
+
+        # expand and reduce streams for hyper connections
+
         x = self.expand_streams(x)
 
         for attn, ff in self.layers:
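
For absolute positions, the forward pass factors the sequence into (windows, segment_len) axial coordinates, asks ContinuousAxialPositionalEmbedding for a flattened embedding covering windows * segment_len positions, and slices it back down to the actual sequence length. A worked sketch with the numbers from test_mac (seq_len = 1023, segment_len = 128, so windows = ceil(1023 / 128) = 8 and the flattened grid spans 1024 positions); the (shape, flatten = True) call follows the hunk above, and dim = 256 is illustrative:

    from math import ceil

    import torch
    from axial_positional_embedding import ContinuousAxialPositionalEmbedding

    dim, segment_len, seq_len = 256, 128, 1023

    axial_pos_emb = ContinuousAxialPositionalEmbedding(dim = dim, num_axial_dims = 2)

    windows = ceil(seq_len / segment_len)                            # 8 windows of up to 128 tokens
    pos_emb = axial_pos_emb((windows, segment_len), flatten = True)  # one embedding per grid position, flattened

    x = torch.randn(1, seq_len, dim)      # stand-in for the token embeddings
    x = x + pos_emb[:seq_len]             # trim the 1024-position grid back to 1023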