titans-pytorch 0.0.38__py3-none-any.whl → 0.0.40__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
titans_pytorch/mac_transformer.py

@@ -7,7 +7,7 @@ from torch import nn, cat
 import torch.nn.functional as F
 from torch.nn import Module, ModuleList, Linear
 
-from einops import repeat, rearrange, pack, unpack
+from einops import einsum, repeat, rearrange, pack, unpack
 from einops.layers.torch import Rearrange
 
 from hyper_connections import get_init_and_expand_reduce_stream_functions
@@ -16,6 +16,7 @@ from hyper_connections import get_init_and_expand_reduce_stream_functions
 
 from axial_positional_embedding import ContinuousAxialPositionalEmbedding
 from rotary_embedding_torch import RotaryEmbedding
+from x_transformers.attend import Attend
 
 # proposed neural memory
 
@@ -93,6 +94,8 @@ class SegmentedAttention(Module):
         num_longterm_mem_tokens = 0,
         dim_head = 64,
         heads = 8,
+        accept_value_residual = False,
+        attend_kwargs: dict = dict()
     ):
         super().__init__()
         self.norm = nn.RMSNorm(dim)
@@ -101,9 +104,17 @@ class SegmentedAttention(Module):
 
         self.rotary_emb = RotaryEmbedding(dim_head)
 
+        self.attend = Attend(causal = True, **attend_kwargs)
+
         self.to_qkv = LinearNoBias(dim, dim_inner * 3)
         self.to_out = LinearNoBias(dim_inner, dim)
 
+        self.to_learned_v_mix = nn.Sequential(
+            nn.Linear(dim, heads),
+            Rearrange('b n h -> b h n 1'),
+            nn.Sigmoid()
+        ) if accept_value_residual else None
+
         self.segment_len = segment_len
         self.num_longterm_mem_tokens = num_longterm_mem_tokens
 
@@ -114,7 +125,13 @@ class SegmentedAttention(Module):
 
         self.persistent_memory = nn.Parameter(torch.zeros(2, heads, num_persist_mem_tokens, dim_head))
 
-    def forward(self, seq):
+    def forward(
+        self,
+        seq,
+        value_residual = None
+    ):
+        assert not (exists(value_residual) ^ exists(self.to_learned_v_mix))
+
         segment_len, num_longterm_mem_tokens = self.segment_len, self.num_longterm_mem_tokens
         total_segment_len = segment_len + num_longterm_mem_tokens
 
@@ -132,6 +149,14 @@ class SegmentedAttention(Module):
         q, k, v = self.to_qkv(seq).chunk(3, dim = -1)
         q, k, v = map(self.split_heads, (q, k, v))
 
+        # value residual
+
+        orig_v = v
+
+        if exists(self.to_learned_v_mix):
+            mix = self.to_learned_v_mix(seq)
+            v = v.lerp(value_residual, mix)
+
         # take care of persistent memory key / values
 
         pmk, pmv = tuple(repeat(t, 'h n d -> b h n d', b = seq.shape[0]) for t in self.persistent_memory)
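
The block added above implements value residual learning: the first attention layer's values are returned untouched (orig_v), and every later layer blends them into its own values through a learned per-head sigmoid gate. A minimal sketch of the mixing step, using illustrative tensor sizes rather than anything from the package:

import torch
from torch import nn
from einops.layers.torch import Rearrange

batch, heads, seq_len, dim, dim_head = 2, 8, 16, 512, 64    # illustrative sizes

# learned gate: one sigmoid weight per head and position, broadcast over dim_head
to_learned_v_mix = nn.Sequential(
    nn.Linear(dim, heads),
    Rearrange('b n h -> b h n 1'),
    nn.Sigmoid()
)

seq = torch.randn(batch, seq_len, dim)                         # layer input
v = torch.randn(batch, heads, seq_len, dim_head)               # this layer's values
value_residual = torch.randn(batch, heads, seq_len, dim_head)  # values from the first layer

mix = to_learned_v_mix(seq)       # (batch, heads, seq_len, 1), entries in [0, 1]
v = v.lerp(value_residual, mix)   # v * (1 - mix) + value_residual * mix

When mix is 0 a head keeps its own values; when it is 1 it reuses the first layer's values wholesale.
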
@@ -145,9 +170,9 @@ class SegmentedAttention(Module):
         k = cat((pmk, k), dim = -2)
         v = cat((pmv, v), dim = -2)
 
-        # sdpa
+        # attention
 
-        out = F.scaled_dot_product_attention(q, k, v, is_causal = True)
+        out, _ = self.attend(q, k, v)
 
         out = self.merge_heads(out)
 
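
The attention kernel itself now goes through x-transformers' Attend module (constructed earlier as Attend(causal = True, **attend_kwargs)) instead of a direct F.scaled_dot_product_attention call; Attend performs the same causal softmax attention but returns the output together with intermediates, and attend_kwargs lets callers pass options through. A rough before/after sketch; the flash flag in the comment is one option attend_kwargs could carry, not something this diff enables:

import torch
import torch.nn.functional as F
from x_transformers.attend import Attend

q = torch.randn(2, 8, 16, 64)   # (batch, heads, seq_len, dim_head)
k = torch.randn(2, 8, 16, 64)
v = torch.randn(2, 8, 16, 64)

# before: hand-rolled causal scaled dot product attention
out_before = F.scaled_dot_product_attention(q, k, v, is_causal = True)

# after: Attend wraps the same computation and returns (output, intermediates)
attend = Attend(causal = True)   # e.g. Attend(causal = True, flash = True) via attend_kwargs
out_after, _ = attend(q, k, v)
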
@@ -155,7 +180,7 @@ class SegmentedAttention(Module):
 
         out = inverse_segment(out)
 
-        return out
+        return out, orig_v
 
 # MAC transformer
 
@@ -206,6 +231,7 @@ class MemoryAsContextTransformer(Module):
         assert not (num_longterm_mem_tokens > 0 and len(neural_memory_layers) == 0), 'empty `neural_memory_layers` when longterm memory tokens are present'
 
         for layer in layers:
+            is_first = layer == 1
 
             # neural memory
 
@@ -231,6 +257,7 @@ class MemoryAsContextTransformer(Module):
                 dim_head = dim_head,
                 heads = heads,
                 segment_len = segment_len,
+                accept_value_residual = not is_first,
                 num_longterm_mem_tokens = num_longterm_mem_tokens,
                 num_persist_mem_tokens = num_persist_mem_tokens
             )
@@ -281,6 +308,10 @@ class MemoryAsContextTransformer(Module):
         pos_emb = self.axial_pos_emb((windows, total_segment_len), flatten = True)
         x = x + pos_emb[:x.shape[-2]]
 
+        # value residual
+
+        value_residual = None
+
         # expand and reduce streams for hyper connections
 
         x = self.expand_streams(x)
@@ -291,7 +322,9 @@ class MemoryAsContextTransformer(Module):
             x = maybe_neural_mem(x)
 
 
-            x = attn(x)
+            x, values = attn(x, value_residual = value_residual)
+
+            value_residual = default(value_residual, values)
 
             x = ff(x)
 
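
At the MemoryAsContextTransformer level the residual is threaded through the layer loop: value_residual starts as None, the first attention block (built with accept_value_residual = not is_first) hands back its values, default(...) captures them once, and every subsequent block receives them. A stripped-down sketch of that wiring; ToyAttn and its internals are stand-ins, not the package's SegmentedAttention:

import torch
from torch import nn

def exists(v):
    return v is not None

def default(v, d):
    return v if exists(v) else d

class ToyAttn(nn.Module):
    # stand-in for SegmentedAttention: returns (output, original values)
    def __init__(self, dim, accept_value_residual = False):
        super().__init__()
        self.accept_value_residual = accept_value_residual
        self.proj = nn.Linear(dim, dim)

    def forward(self, x, value_residual = None):
        # a residual must be supplied exactly when the layer was built to accept one
        assert not (exists(value_residual) ^ self.accept_value_residual)
        values = x   # pretend these are the pre-attention values
        return self.proj(x), values

# only the first layer refuses a residual, mirroring accept_value_residual = not is_first
layers = nn.ModuleList([
    ToyAttn(64, accept_value_residual = not is_first)
    for is_first in (True, False, False)
])

x = torch.randn(2, 16, 64)
value_residual = None

for attn in layers:
    x, values = attn(x, value_residual = value_residual)
    value_residual = default(value_residual, values)   # set once, by the first layer
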
titans_pytorch-0.0.38.dist-info/METADATA → titans_pytorch-0.0.40.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.0.38
+Version: 0.0.40
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -43,6 +43,7 @@ Requires-Dist: ninja
 Requires-Dist: rotary-embedding-torch
 Requires-Dist: tensordict
 Requires-Dist: torch>=2.2
+Requires-Dist: x-transformers
 Provides-Extra: examples
 Requires-Dist: local-attention>=1.10.1; extra == 'examples'
 Requires-Dist: taylor-series-linear-attention; extra == 'examples'
titans_pytorch-0.0.38.dist-info/RECORD → titans_pytorch-0.0.40.dist-info/RECORD

@@ -1,9 +1,9 @@
 titans_pytorch/__init__.py,sha256=F6pV8BamKCsbJFVo5x2hw69vzfJNLy54SwIKIueMdp4,142
 titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/mac_transformer.py,sha256=5koIfEulJ841FNrs6URZfW2dp9LMuHzMkaySDrlbuP0,8393
+titans_pytorch/mac_transformer.py,sha256=y9sruSvGCEL4flu_RW7bCdvIe-S9dEdGacbmPYL1kqA,9311
 titans_pytorch/titans.py,sha256=bv2Ceq-_4nNb5FNx4hLd2inC93m5MmJxO2-Mbf6PKK4,14378
 titans_pytorch/titans_attn_memory.py,sha256=Rwx_-riGeISBefZg5Kjic8jzmmRRys-u93D2Kgb7Mos,12691
-titans_pytorch-0.0.38.dist-info/METADATA,sha256=L6tEQTEOXCeAU_BuRLbwUO0-gmnbJE-WQNAZ83BNCWA,3938
-titans_pytorch-0.0.38.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.0.38.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.0.38.dist-info/RECORD,,
+titans_pytorch-0.0.40.dist-info/METADATA,sha256=JCJ5aG9_-rVUErW6u-DXkJtVQ52Bf3XQDN3puirXAXo,3968
+titans_pytorch-0.0.40.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.0.40.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.0.40.dist-info/RECORD,,