titans-pytorch 0.0.63__py3-none-any.whl → 0.0.65__py3-none-any.whl

This diff covers publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in the public registry.
titans_pytorch/mac_transformer.py

@@ -20,13 +20,21 @@ try:
 except ImportError:
     pass
 
-def create_mac_block_mask(seq_len, window_size, persist_mem_len):
+def create_mac_block_mask(seq_len, window_size, persist_mem_len, sliding = False):
 
-    def create_mac_mask(b, h, q_idx, kv_idx):
+    def create_mac_mask(_, __, q_idx, kv_idx):
         is_persist_mem = kv_idx < persist_mem_len
-        causal_mask = q_idx >= (kv_idx - persist_mem_len)
-        block_diagonal = (q_idx // window_size) == ((kv_idx - persist_mem_len) // window_size)
-        return is_persist_mem | (~is_persist_mem & (causal_mask & block_diagonal))
+        kv_without_mem = kv_idx - persist_mem_len
+        causal_mask = q_idx >= kv_without_mem
+
+        if not sliding:
+            block_diagonal = (q_idx // window_size) == (kv_without_mem // window_size)
+            causal_mask = causal_mask & block_diagonal
+        else:
+            sliding_mask = (q_idx - kv_without_mem) <= window_size
+            causal_mask = causal_mask & sliding_mask
+
+        return is_persist_mem | (~is_persist_mem & causal_mask)
 
     block_mask = create_block_mask(create_mac_mask, B = None, H = None, Q_LEN = seq_len, KV_LEN = seq_len + persist_mem_len, _compile = True)
     return block_mask
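Note: the reworked `create_mac_mask` above adds a `sliding` branch. In the default mode a query only attends to keys inside its own block-diagonal segment; in sliding mode it may look back up to `window_size` positions regardless of segment boundaries, and persistent-memory positions stay visible in both modes. Below is a minimal standalone sketch of the same predicate, evaluated densely with plain torch rather than through flex attention's `create_block_mask`; the name `dense_mac_mask` is illustrative and not part of the package.

    import torch

    def dense_mac_mask(seq_len, window_size, persist_mem_len, sliding = False):
        q_idx = torch.arange(seq_len)[:, None]
        kv_idx = torch.arange(seq_len + persist_mem_len)[None, :]

        is_persist_mem = kv_idx < persist_mem_len
        kv_without_mem = kv_idx - persist_mem_len
        causal_mask = q_idx >= kv_without_mem

        if not sliding:
            # block-diagonal: queries stay within their own window
            causal_mask = causal_mask & ((q_idx // window_size) == (kv_without_mem // window_size))
        else:
            # sliding: queries see at most `window_size` positions behind them
            causal_mask = causal_mask & ((q_idx - kv_without_mem) <= window_size)

        # persistent memory tokens are always attendable
        return is_persist_mem | (~is_persist_mem & causal_mask)

    print(dense_mac_mask(8, 4, 2, sliding = True).int())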
@@ -73,7 +81,12 @@ def identity(t):
 def round_up_multiple(seq, mult):
     return ceil(seq / mult) * mult
 
-def pad_and_segment_with_inverse(seq, segment_len):
+def pad_at_dim(t, pad, dim = -1, value = 0.):
+    dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
+    zeros = ((0, 0) * dims_from_right)
+    return F.pad(t, (*zeros, *pad), value = value)
+
+def pad_and_segment_with_inverse(seq, segment_len, fold_into_batch = True):
     batch, seq_len = seq.shape[:2]
 
     need_segment = seq_len >= segment_len
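Note: `pad_at_dim` is a small new helper that pads an arbitrary dimension by building the flat pad tuple `F.pad` expects (pairs ordered from the last dimension backwards), and `pad_and_segment_with_inverse` can now skip folding windows into the batch dimension so the caller can perform that fold itself after rotary embeddings. A quick standalone illustration of the helper; the example tensors are arbitrary.

    import torch
    import torch.nn.functional as F

    def pad_at_dim(t, pad, dim = -1, value = 0.):
        dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
        zeros = ((0, 0) * dims_from_right)
        return F.pad(t, (*zeros, *pad), value = value)

    x = torch.ones(2, 3, 5)
    print(pad_at_dim(x, (1, 0), dim = 1).shape)   # torch.Size([2, 4, 5]) - one leading pad on dim 1
    print(pad_at_dim(x, (0, 2), dim = -1).shape)  # torch.Size([2, 3, 7]) - two trailing pads on the last dim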
@@ -89,13 +102,15 @@ def pad_and_segment_with_inverse(seq, segment_len):
     if needs_pad:
         seq = F.pad(seq, (0, 0, 0, padding))
 
-    seq = rearrange(seq, 'b (w n) d -> (b w) n d', n = segment_len)
+    if fold_into_batch:
+        seq = rearrange(seq, 'b (w n) d -> (b w) n d', n = segment_len)
 
     def inverse(out):
-        out = rearrange(out, '(b w) n d -> b (w n) d', b = batch)
+        if fold_into_batch:
+            out = rearrange(out, '(b w) n d -> b (w n) d', b = batch)
 
         if needs_pad:
-            out = out[:, :-padding]
+            out = out[..., :-padding, :]
 
         return out
 
@@ -127,6 +142,7 @@ class SegmentedAttention(Module):
         num_longterm_mem_tokens = 0,
         dim_head = 64,
         heads = 8,
+        sliding = False,
         accept_value_residual = False,
         attend_kwargs: dict = dict(),
         use_flex_attn = False
@@ -153,6 +169,9 @@ class SegmentedAttention(Module):
         self.num_longterm_mem_tokens = num_longterm_mem_tokens
 
         total_segment_len = segment_len + num_longterm_mem_tokens
+        self.total_segment_len = total_segment_len
+
+        self.sliding = sliding # sliding window attn - doubt their non-sliding results being the best. local attention with overlapping windows is very strong
 
         self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
         self.merge_heads = Rearrange('b h n d -> b n (h d)')
@@ -209,7 +228,7 @@ class SegmentedAttention(Module):
         # prep flex attention
 
         if not exists(flex_attn_fn):
-            block_mask = create_mac_block_mask(seq_len, self.segment_len, self.num_persist_mem_tokens)
+            block_mask = create_mac_block_mask(seq_len, self.total_segment_len, self.num_persist_mem_tokens, self.sliding)
 
             flex_attn_fn = partial(flex_attention, block_mask = block_mask)
 
@@ -241,9 +260,8 @@
         batch, seq_len = seq.shape[:2]
 
         # auto pad to multiple
-        # todo - get rid of logic with flex attention
 
-        seq, inverse_segment = pad_and_segment_with_inverse(seq, total_segment_len)
+        seq, inverse_segment = pad_and_segment_with_inverse(seq, total_segment_len, fold_into_batch = False)
 
         # attention
 
@@ -260,14 +278,45 @@
             mix = self.to_learned_v_mix(seq)
             v = v.lerp(value_residual, mix)
 
-        # take care of persistent memory key / values
-
-        pmk, pmv = repeat(self.persistent_memory, 'kv ... -> kv b ...', b = seq.shape[0])
-
         # relative positions
 
         q, k = self.rotary_emb.rotate_queries_with_cached_keys(q, k)
 
+        # fold
+
+        q, k, v = tuple(rearrange(t, 'b h (w n) d -> (b w) h n d', n = total_segment_len) for t in (q, k, v))
+
+        # maybe sliding for cpu
+
+        attend_kwargs = dict()
+
+        if self.sliding:
+            k, v = tuple(rearrange(t, '(b w) ... -> b w ...', b = batch) for t in (k, v))
+            k, v = tuple(pad_at_dim(t, (1, 0), value = 0., dim = 1) for t in (k, v))
+            k = cat((k[:, :-1], k[:, 1:]), dim = -2)
+            v = cat((v[:, :-1], v[:, 1:]), dim = -2)
+            k, v = tuple(rearrange(t, 'b w ... -> (b w) ...') for t in (k, v))
+
+            # take care of masking
+
+            idx = torch.arange(seq.shape[-2], device = seq.device)
+            q_idx = rearrange(idx, '(w n) -> w n', n = total_segment_len)
+            k_idx = pad_at_dim(q_idx, (1, 0), dim = 0, value = -1e4)
+            k_idx = cat((k_idx[:-1], k_idx[1:]), dim = -1)
+
+            q_idx = rearrange(q_idx, 'w i -> w i 1')
+            k_idx = rearrange(k_idx, 'w j -> w 1 j')
+
+            sliding_mask = (q_idx - k_idx) <= total_segment_len
+            sliding_mask = F.pad(sliding_mask, (self.num_persist_mem_tokens, 0), value = True)
+
+            sliding_mask = repeat(sliding_mask, 'w i j -> (b w) 1 i j', b = batch)
+            attend_kwargs.update(mask = sliding_mask)
+
+        # take care of persistent memory key / values
+
+        pmk, pmv = repeat(self.persistent_memory, 'kv ... -> kv b ...', b = k.shape[0])
+
         # persistent memory
 
         k = cat((pmk, k), dim = -2)
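Note: the `if self.sliding:` block above implements sliding window attention on the non-flex (CPU) path by concatenating each window's keys and values with those of the previous window, so every query can reach back up to one full window, and by building an explicit mask (prefixed with always-true positions for persistent memory) so it cannot reach further. A standalone sketch of just the key/value overlap step, with illustrative names and shapes:

    import torch
    import torch.nn.functional as F
    from einops import rearrange

    def overlap_previous_window(kv, batch):
        # kv: (batch * windows, heads, window_len, dim_head)
        kv = rearrange(kv, '(b w) h n d -> b w h n d', b = batch)
        kv = F.pad(kv, (0, 0, 0, 0, 0, 0, 1, 0), value = 0.)  # prepend an all-zero window along the window dim
        kv = torch.cat((kv[:, :-1], kv[:, 1:]), dim = -2)      # pair every window with its predecessor
        return rearrange(kv, 'b w h n d -> (b w) h n d')

    k = torch.randn(2 * 3, 8, 4, 64)                           # 2 batches, 3 windows of length 4, 8 heads
    print(overlap_previous_window(k, batch = 2).shape)         # torch.Size([6, 8, 8, 64])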
@@ -275,12 +324,14 @@
 
         # attention
 
-        out, _ = self.attend(q, k, v)
+        out, _ = self.attend(q, k, v, **attend_kwargs)
 
         out = self.merge_heads(out)
 
         out = self.to_out(out)
 
+        out = rearrange(out, '(b w) n d -> b (w n) d', b = batch)
+
         out = inverse_segment(out)
 
         return out, orig_v
@@ -349,7 +400,8 @@ class MemoryAsContextTransformer(Module):
         neural_memory_kwargs: dict = dict(),
         neural_memory_layers: tuple[int, ...] | None = None,
         aux_kv_recon_loss_weight = 0.,
-        use_flex_attn = False
+        use_flex_attn = False,
+        sliding_window_attn = False
     ):
         super().__init__()
 
@@ -366,6 +418,10 @@ class MemoryAsContextTransformer(Module):
 
         self.longterm_mems = nn.Parameter(torch.randn(num_longterm_mem_tokens, dim) * 0.02)
 
+        # maybe sliding window attn
+
+        self.sliding_window_attn = sliding_window_attn
+
         # hyper conection
         init_hyper_conn, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(num_residual_streams, disable = num_residual_streams == 1)
 
@@ -396,7 +452,8 @@ class MemoryAsContextTransformer(Module):
                 use_flex_attn = use_flex_attn,
                 accept_value_residual = not is_first,
                 num_longterm_mem_tokens = num_longterm_mem_tokens,
-                num_persist_mem_tokens = num_persist_mem_tokens
+                num_persist_mem_tokens = num_persist_mem_tokens,
+                sliding = sliding_window_attn
             )
 
             mem = None
@@ -489,7 +546,7 @@ class MemoryAsContextTransformer(Module):
         flex_attn_fn = None
 
         if use_flex_attn:
-            block_mask = create_mac_block_mask(seq_len_with_mem, segment_len + num_longterm_mem_tokens, self.num_persist_mem_tokens)
+            block_mask = create_mac_block_mask(seq_len_with_mem, segment_len + num_longterm_mem_tokens, self.num_persist_mem_tokens, self.sliding_window_attn)
             flex_attn_fn = partial(flex_attention, block_mask = block_mask)
 
         # value residual
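Note: with the two constructor changes above, sliding window attention can be switched on for the whole MAC transformer from the top level. A rough usage sketch; apart from `sliding_window_attn`, `use_flex_attn`, `segment_len`, `num_persist_mem_tokens` and `num_longterm_mem_tokens`, which appear in this diff, the remaining argument names and the forward signature are assumptions drawn from the repository and may differ in 0.0.65.

    import torch
    from titans_pytorch.mac_transformer import MemoryAsContextTransformer

    transformer = MemoryAsContextTransformer(
        num_tokens = 256,              # assumed vocabulary-size argument
        dim = 256,                     # assumed model-dimension argument
        depth = 2,                     # assumed depth argument
        segment_len = 128,
        num_persist_mem_tokens = 4,
        num_longterm_mem_tokens = 4,
        sliding_window_attn = True     # new in 0.0.65: overlapping (sliding) window attention
    )

    ids = torch.randint(0, 256, (1, 1023))
    logits = transformer(ids)          # assumed: token ids in, logits out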
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.0.63
+Version: 0.0.65
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -0,0 +1,8 @@
+titans_pytorch/__init__.py,sha256=I-4oF1JPEmUvhLszEocM5cDgHYNFcNp0Q9nbDYSPFqU,195
+titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+titans_pytorch/mac_transformer.py,sha256=7PHBCbeB1LhHY5s3zAyYF0L3Mm7CNy4TOBbcpLX6bNE,17686
+titans_pytorch/titans.py,sha256=y6lJRErIoM6T2aTVFlf1GxSB0cpsmBZdSIj1DCHUCQ8,17486
+titans_pytorch-0.0.65.dist-info/METADATA,sha256=oDjEiufwOninsFDoCGbu691LXc1mey2OT7j6PNzkz0Q,4457
+titans_pytorch-0.0.65.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.0.65.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.0.65.dist-info/RECORD,,
@@ -1,8 +0,0 @@
-titans_pytorch/__init__.py,sha256=I-4oF1JPEmUvhLszEocM5cDgHYNFcNp0Q9nbDYSPFqU,195
-titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/mac_transformer.py,sha256=R9Xu-BjqAe9ZY60IGk4aNXBx_L8THsjJ4QrkbTnRNHo,15346
-titans_pytorch/titans.py,sha256=y6lJRErIoM6T2aTVFlf1GxSB0cpsmBZdSIj1DCHUCQ8,17486
-titans_pytorch-0.0.63.dist-info/METADATA,sha256=-CImQ-4hVNDFWczTb0V1dWL0QkHS-1c6XyntI1ULrms,4457
-titans_pytorch-0.0.63.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.0.63.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.0.63.dist-info/RECORD,,