titans-pytorch 0.0.51__tar.gz → 0.0.53__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.0.51
+Version: 0.0.53
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -37,7 +37,6 @@ Requires-Python: >=3.9
 Requires-Dist: accelerated-scan>=0.2.0
 Requires-Dist: axial-positional-embedding>=0.3.5
 Requires-Dist: einops>=0.8.0
-Requires-Dist: einx>=0.3.0
 Requires-Dist: hyper-connections>=0.1.8
 Requires-Dist: ninja
 Requires-Dist: rotary-embedding-torch
@@ -1,6 +1,6 @@
 [project]
 name = "titans-pytorch"
-version = "0.0.51"
+version = "0.0.53"
 description = "Titans"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -27,7 +27,6 @@ classifiers=[
 dependencies = [
     "accelerated-scan>=0.2.0",
     "axial_positional_embedding>=0.3.5",
-    "einx>=0.3.0",
     "einops>=0.8.0",
     "hyper-connections>=0.1.8",
     "Ninja",
@@ -1,4 +1,5 @@
 from __future__ import annotations
+from typing import Callable
 from math import ceil
 from functools import partial
 
@@ -32,7 +33,7 @@ def create_mac_block_mask(seq_len, window_size, persist_mem_len):
 
 # einstein notation related
 
-from einops import einsum, repeat, rearrange, pack, unpack
+from einops import repeat, rearrange, pack, unpack
 from einops.layers.torch import Rearrange
 
 # b - batch
@@ -128,6 +129,7 @@ class SegmentedAttention(Module):
         heads = 8,
         accept_value_residual = False,
         attend_kwargs: dict = dict(),
+        use_flex_attn = False
     ):
         super().__init__()
         self.norm = nn.RMSNorm(dim)
@@ -157,11 +159,79 @@ class SegmentedAttention(Module):
 
         self.persistent_memory = nn.Parameter(torch.zeros(2, heads, num_persist_mem_tokens, dim_head))
 
+        # flex attn related
+
+        assert not (use_flex_attn and not exists(flex_attention)), 'you need to be on the latest pytorch with a cuda device available'
+        self.use_flex_attn = use_flex_attn
+
+        self.segment_len = segment_len
+        self.num_persist_mem_tokens = num_persist_mem_tokens
+
+    def forward_flex(
+        self,
+        seq,
+        value_residual = None,
+        flex_attn_fn: Callable | None = None
+    ):
+
+        assert not (exists(value_residual) ^ exists(self.to_learned_v_mix))
+
+        batch, seq_len = seq.shape[:2]
+
+        # attention
+
+        seq = self.norm(seq)
+
+        q, k, v = self.to_qkv(seq).chunk(3, dim = -1)
+        q, k, v = map(self.split_heads, (q, k, v))
+
+        # value residual
+
+        orig_v = v
+
+        if exists(self.to_learned_v_mix):
+            mix = self.to_learned_v_mix(seq)
+            v = v.lerp(value_residual, mix)
+
+        # take care of persistent memory key / values
+
+        pmk, pmv = repeat(self.persistent_memory, 'kv h n d -> kv b h n d', b = batch)
+
+        # relative positions
+
+        q, k = self.rotary_emb.rotate_queries_with_cached_keys(q, k)
+
+        # persistent memory
+
+        k = cat((pmk, k), dim = -2)
+        v = cat((pmv, v), dim = -2)
+
+        # prep flex attention
+
+        if not exists(flex_attn_fn):
+            block_mask = create_mac_block_mask(seq_len, self.segment_len, self.num_persist_mem_tokens)
+
+            flex_attn_fn = partial(flex_attention, block_mask = block_mask)
+
+        # attention
+
+        out = flex_attn_fn(q, k, v)
+
+        out = self.merge_heads(out)
+
+        out = self.to_out(out)
+
+        return out, orig_v
+
     def forward(
         self,
         seq,
-        value_residual = None
+        value_residual = None,
+        flex_attn_fn: Callable | None = None
     ):
+        if seq.is_cuda and self.use_flex_attn:
+            return self.forward_flex(seq, value_residual, flex_attn_fn)
+
         assert not (exists(value_residual) ^ exists(self.to_learned_v_mix))
 
         segment_len, num_longterm_mem_tokens = self.segment_len, self.num_longterm_mem_tokens
@@ -191,7 +261,7 @@ class SegmentedAttention(Module):
 
         # take care of persistent memory key / values
 
-        pmk, pmv = tuple(repeat(t, 'h n d -> b h n d', b = seq.shape[0]) for t in self.persistent_memory)
+        pmk, pmv = repeat(self.persistent_memory, 'kv ... -> kv b ...', b = seq.shape[0])
 
         # relative positions
 
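Note on the hunks above: they add an optional flex-attention path to SegmentedAttention. When use_flex_attn is set and the input is on CUDA, forward dispatches to the new forward_flex, which prepends the persistent-memory key/values and runs PyTorch flex attention under a block mask from create_mac_block_mask, building the attention callable as a partial over flex_attention when none is passed in. The sketch below only illustrates that pattern; the mask rule (causal attention within a fixed window, plus unrestricted access to the persistent-memory prefix) and all shapes are assumptions inferred from the diff, not the package's actual create_mac_block_mask.

# Hedged sketch of the flex-attention pattern used by forward_flex above.
# The mask rule and the shapes are assumptions, not code from titans-pytorch.

from functools import partial

import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

def make_mac_mask_mod(window_size, persist_mem_len):
    # kv positions [0, persist_mem_len) are the persistent memory prefix and are
    # always attendable; remaining kv positions map back to sequence positions
    # and must be causal and inside the same window as the query.
    def mask_mod(b, h, q_idx, kv_idx):
        is_persist = kv_idx < persist_mem_len
        kv_seq_idx = kv_idx - persist_mem_len
        same_window = (q_idx // window_size) == (kv_seq_idx // window_size)
        causal = kv_seq_idx <= q_idx
        return is_persist | (same_window & causal)
    return mask_mod

# toy dimensions (assumed); flex attention needs a CUDA device, matching the assert in the diff
batch, heads, seq_len, dim_head = 2, 8, 256, 64
window_size, persist_mem_len = 64, 4

q = torch.randn(batch, heads, seq_len, dim_head, device = 'cuda')
k = torch.randn(batch, heads, seq_len + persist_mem_len, dim_head, device = 'cuda')
v = torch.randn(batch, heads, seq_len + persist_mem_len, dim_head, device = 'cuda')

block_mask = create_block_mask(
    make_mac_mask_mod(window_size, persist_mem_len),
    B = None, H = None,
    Q_LEN = seq_len,
    KV_LEN = seq_len + persist_mem_len,
    device = 'cuda'
)

# mirrors forward_flex: carry the block mask in a partial, then attend
flex_attn_fn = partial(flex_attention, block_mask = block_mask)
out = flex_attn_fn(q, k, v)   # (batch, heads, seq_len, dim_head)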
@@ -1,4 +1,5 @@
 from __future__ import annotations
+from typing import Callable
 import math
 from functools import partial
 
@@ -16,7 +17,6 @@ from titans_pytorch.associative_scan import (
     pad_at_dim
 )
 
-import einx
 from einops import rearrange, repeat, pack, unpack
 from einops.layers.torch import Rearrange, Reduce
 
@@ -338,9 +338,9 @@ class NeuralMemory(Module):
 
         # take care of chunking
 
-        keys, values = tuple(rearrange(t, 'b (n c) d -> (b n) c d', c = self.chunk_size) for t in (keys, values))
+        keys, values = tuple(rearrange(t, 'b (n c) d -> (b n) c d', c = chunk_size) for t in (keys, values))
 
-        adaptive_lr = rearrange(adaptive_lr, 'b (n c) -> (b n) c', c = self.chunk_size)
+        adaptive_lr = rearrange(adaptive_lr, 'b (n c) -> (b n) c', c = chunk_size)
 
         # get grads and extra auxiliary loss (for backwarding through qkv projection in base neural memory module)
 
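Note on the last hunk: it swaps self.chunk_size for the local chunk_size when folding keys, values, and the per-token adaptive learning rates into chunks. A toy illustration of the einops rearrange pattern in those lines (the shapes here are made up):

# Hedged illustration of the chunking rearrange above, with assumed toy shapes.
import torch
from einops import rearrange

chunk_size = 4
keys = torch.randn(2, 16, 32)      # (batch, seq, dim), seq divisible by chunk_size
adaptive_lr = torch.rand(2, 16)    # one value per token

# fold the sequence into chunks of length chunk_size, merging the chunk index into batch
keys_chunked = rearrange(keys, 'b (n c) d -> (b n) c d', c = chunk_size)
lr_chunked = rearrange(adaptive_lr, 'b (n c) -> (b n) c', c = chunk_size)

print(keys_chunked.shape)   # torch.Size([8, 4, 32])
print(lr_chunked.shape)     # torch.Size([8, 4])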