titans-pytorch 0.1.17__tar.gz → 0.1.20__tar.gz

This diff shows the changes between two publicly released versions of the package, as published to the public registries. It is provided for informational purposes only.
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: titans-pytorch
- Version: 0.1.17
+ Version: 0.1.20
  Summary: Titans
  Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
  Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -38,7 +38,7 @@ Requires-Dist: accelerated-scan>=0.2.0
  Requires-Dist: axial-positional-embedding>=0.3.9
  Requires-Dist: einops>=0.8.0
  Requires-Dist: einx>=0.3.0
- Requires-Dist: hyper-connections>=0.1.8
+ Requires-Dist: hyper-connections>=0.1.9
  Requires-Dist: ninja
  Requires-Dist: rotary-embedding-torch
  Requires-Dist: tensordict
@@ -1,6 +1,6 @@
  [project]
  name = "titans-pytorch"
- version = "0.1.17"
+ version = "0.1.20"
  description = "Titans"
  authors = [
      { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -29,7 +29,7 @@ dependencies = [
      "axial_positional_embedding>=0.3.9",
      "einops>=0.8.0",
      "einx>=0.3.0",
-     "hyper-connections>=0.1.8",
+     "hyper-connections>=0.1.9",
      "Ninja",
      "rotary-embedding-torch",
      "tensordict",
@@ -107,14 +107,18 @@ def test_mac(
      logits = transformer(x)
      assert logits.shape == (1, seq_len, 256)

- def test_mac_sampling():
+ @pytest.mark.parametrize('sliding', (False, True))
+ def test_mac_sampling(sliding):
      transformer = MemoryAsContextTransformer(
          num_tokens = 256,
          dim = 256,
          depth = 2,
          segment_len = 32,
          num_persist_mem_tokens = 4,
-         num_longterm_mem_tokens = 16,
+         num_longterm_mem_tokens = 0,
+         sliding_window_attn = sliding,
+         neural_memory_layers = (),
+         neural_mem_gate_attn_output = False
      )

      ids = torch.randint(0, 256, (1, 1023))
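
The sampling test is now parametrized over sliding-window attention and uses a stripped-down configuration (no long-term memory tokens, no neural-memory layers), which keeps the new cached decoding path easy to compare across attention modes. A minimal sketch of the same configuration outside pytest, assuming the top-level export used in the project README and constructor arguments taken verbatim from the test:

    import torch
    from titans_pytorch import MemoryAsContextTransformer

    for sliding in (False, True):
        transformer = MemoryAsContextTransformer(
            num_tokens = 256,
            dim = 256,
            depth = 2,
            segment_len = 32,
            num_persist_mem_tokens = 4,
            num_longterm_mem_tokens = 0,
            sliding_window_attn = sliding,
            neural_memory_layers = (),
            neural_mem_gate_attn_output = False
        )

        ids = torch.randint(0, 256, (1, 1023))
        logits = transformer(ids)   # (1, 1023, 256), matching the assertion in test_mac above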
@@ -1,12 +1,14 @@
  from __future__ import annotations
  from typing import Callable
+
  from math import ceil
  from functools import partial
+ from collections import namedtuple

  import tqdm

  import torch
- from torch import nn, cat
+ from torch import nn, stack, cat
  import torch.nn.functional as F
  from torch.nn import Module, ModuleList, Linear

@@ -69,6 +71,8 @@ from titans_pytorch.titans import NeuralMemory

  LinearNoBias = partial(Linear, bias = False)

+ AttnIntermediates = namedtuple('AttnIntermediates', ('value_residual', 'cached_key_values'))
+
  # helpers

  def exists(v):
@@ -80,6 +84,9 @@ def default(v, d):
  def identity(t):
      return t

+ def divisible_by(num, den):
+     return (num % den) == 0
+
  def round_up_multiple(seq, mult):
      return ceil(seq / mult) * mult

@@ -111,7 +118,7 @@ def pad_and_segment_with_inverse(seq, segment_len, fold_into_batch = True):

      def inverse(out):
          if fold_into_batch:
-             out = rearrange(out, '(b w) n d -> b (w n) d', b = batch)
+             out = rearrange(out, '(b w) ... n d -> b ... (w n) d', b = batch)

          if needs_pad:
              out = out[..., :-padding, :]
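
The only change to `pad_and_segment_with_inverse` is the broadened rearrange pattern: the added `...` lets the inverse fold handle tensors with extra leading axes (for example a heads axis) between the folded batch-window axis and the trailing `(n, d)` axes, which the new KV-caching path appears to rely on when un-segmenting per-head keys and values. A small self-contained illustration of the pattern, with made-up shapes:

    import torch
    from einops import rearrange

    batch, windows, heads, n, d = 2, 3, 4, 5, 8
    t = torch.randn(batch * windows, heads, n, d)

    # '(b w) ... n d -> b ... (w n) d' tolerates the extra heads axis via the ellipsis
    out = rearrange(t, '(b w) ... n d -> b ... (w n) d', b = batch)
    assert out.shape == (batch, heads, windows * n, d)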
@@ -213,11 +220,75 @@ class SegmentedAttention(Module):
          self.segment_len = segment_len
          self.num_persist_mem_tokens = num_persist_mem_tokens

+     def forward_inference(
+         self,
+         token,
+         cache,
+         value_residual = None,
+         output_gating = None,
+     ):
+         batch = token.shape[0]
+
+         # attention
+
+         token = self.norm(token)
+
+         q, k, v = self.to_qkv(token).chunk(3, dim = -1)
+         q, k, v = map(self.split_heads, (q, k, v))
+
+         # value residual
+
+         orig_v = v
+
+         if exists(self.to_learned_v_mix):
+             mix = self.to_learned_v_mix(token)
+             v = v.lerp(value_residual, mix)
+
+         # caching
+
+         ck, cv = cache
+         k = cat((ck, k), dim = -2)
+         v = cat((cv, v), dim = -2)
+
+         next_cache = (k, v)
+
+         # relative positions
+
+         q, k = self.rotary_emb.rotate_queries_with_cached_keys(q, k)
+
+         # fold
+
+         q, k, v = tuple(rearrange(t, 'b h n d -> b h n d') for t in (q, k, v))
+
+         # take care of persistent memory key / values
+
+         pmk, pmv = repeat(self.persistent_memory, 'kv ... -> kv b ...', b = k.shape[0])
+
+         # persistent memory
+
+         k = cat((pmk, k), dim = -2)
+         v = cat((pmv, v), dim = -2)
+
+         # attention
+
+         out, _ = self.attend(q, k, v)
+
+         out = self.merge_heads(out)
+
+         out = self.to_out(out)
+
+         if exists(output_gating):
+             out = out * output_gating
+
+         return out, AttnIntermediates(orig_v, next_cache)
+
      def forward_flex(
          self,
          seq,
          value_residual = None,
-         flex_attn_fn: Callable | None = None
+         flex_attn_fn: Callable | None = None,
+         output_gating = None,
+         cache = None
      ):

          assert not (exists(value_residual) ^ exists(self.to_learned_v_mix))
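
`forward_inference` is the new single-token decode path: it normalizes the incoming token, projects q/k/v, appends the step's keys and values to the cached ones, applies rotary positions against the cached keys, prepends the persistent-memory key/values, and returns the output together with an `AttnIntermediates(value_residual, cached_key_values)` tuple. The cache mechanics in isolation look roughly like the following generic sketch in plain PyTorch (not the library's classes):

    import torch
    import torch.nn.functional as F

    def decode_step(q, k, v, cache = None):
        # q, k, v: (batch, heads, 1, dim_head) for the newest token only
        if cache is not None:
            ck, cv = cache
            k = torch.cat((ck, k), dim = -2)   # grow keys along the sequence axis
            v = torch.cat((cv, v), dim = -2)   # grow values along the sequence axis

        next_cache = (k, v)

        # the new token attends over everything accumulated so far
        out = F.scaled_dot_product_attention(q, k, v)
        return out, next_cache

    q = k = v = torch.randn(1, 8, 1, 64)
    out, cache = decode_step(q, k, v)
    out, cache = decode_step(q, k, v, cache)   # second step reuses the cache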
@@ -239,6 +310,10 @@ class SegmentedAttention(Module):
              mix = self.to_learned_v_mix(seq)
              v = v.lerp(value_residual, mix)

+         # caching
+
+         next_cache = tuple(map(inverse_segment, (k, v)))
+
          # take care of persistent memory key / values

          pmk, pmv = repeat(self.persistent_memory, 'kv h n d -> kv b h n d', b = batch)
@@ -267,17 +342,28 @@ class SegmentedAttention(Module):

          out = self.to_out(out)

-         return out, orig_v
+         if exists(output_gating):
+             out = out * output_gating
+
+         return out, AttnIntermediates(orig_v, next_cache)

      def forward(
          self,
          seq,
          value_residual = None,
          flex_attn_fn: Callable | None = None,
-         disable_flex_attn = False
+         disable_flex_attn = False,
+         output_gating = None,
+         cache = None
      ):
+         is_inferencing = exists(cache)
+
+         if is_inferencing:
+             assert seq.shape[-2] == 1
+             return self.forward_inference(seq, cache, value_residual, output_gating = output_gating)
+
          if seq.is_cuda and self.use_flex_attn and not disable_flex_attn:
-             return self.forward_flex(seq, value_residual, flex_attn_fn)
+             return self.forward_flex(seq, value_residual, flex_attn_fn, output_gating = output_gating, cache = cache)

          assert not (exists(value_residual) ^ exists(self.to_learned_v_mix))

@@ -305,6 +391,10 @@ class SegmentedAttention(Module):
              mix = self.to_learned_v_mix(seq)
              v = v.lerp(value_residual, mix)

+         # caching
+
+         next_cache = tuple(map(inverse_segment, (k, v)))
+
          # relative positions

          q, k = self.rotary_emb.rotate_queries_with_cached_keys(q, k)
@@ -361,50 +451,10 @@ class SegmentedAttention(Module):

          out = inverse_segment(out)

-         return out, orig_v
-
- # Attention + Neural Memory gating configuration, as depicted in Figure 2
-
- class NeuralMemoryGatingWrapper(Module):
-     def __init__(
-         self,
-         dim,
-         attn: SegmentedAttention,
-         neural_mem: NeuralMemory | None = None,
-         gate_attn_output = True
-     ):
-         super().__init__()
-         self.attn = attn
-         self.neural_mem = neural_mem
-         self.gate_attn_output = gate_attn_output
-
-     def forward(
-         self,
-         seq,
-         *args,
-         **kwargs
-     ):
-         batch, seq_len = seq.shape[:2]
-         mem = self.neural_mem
-
-         if not exists(mem):
-             return self.attn(seq, *args, **kwargs), 0.
-
-         # initial retrieve, still should store first, it doesn't make sense not to, unless if all layers share the same neural memory
-
-         retrieved, kv_aux_loss = mem(seq, return_aux_kv_loss = True)
-
-         if not self.gate_attn_output:
-             seq = seq + retrieved
+         if exists(output_gating):
+             out = out * output_gating

-         # attention
-
-         attn_out, values = self.attn(seq, *args, **kwargs)
-
-         if self.gate_attn_output:
-             attn_out = attn_out * retrieved.sigmoid()
-
-         return (attn_out, values), kv_aux_loss
+         return out, AttnIntermediates(orig_v, next_cache)

  # MAC transformer

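This hunk deletes `NeuralMemoryGatingWrapper` entirely. Its two behaviours survive in different places: gating the attention output with `retrieved.sigmoid()` now happens inside `SegmentedAttention` via the new `output_gating` argument, and the retrieve-plus-aux-loss bookkeeping moves into the transformer's layer loop further down. The gate itself is just an elementwise sigmoid product, for example:

    import torch

    retrieved = torch.randn(1, 16, 256)   # hypothetical neural-memory read for a segment
    attn_out  = torch.randn(1, 16, 256)   # hypothetical attention output, same shape

    # previously: attn_out * retrieved.sigmoid() inside the wrapper;
    # now the sigmoid is computed by the caller and passed in as output_gating
    gated = attn_out * retrieved.sigmoid()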
@@ -448,6 +498,7 @@ class MemoryAsContextTransformer(Module):
          # maybe sliding window attn

          self.sliding_window_attn = sliding_window_attn
+         self.attn_window_size = segment_len + num_longterm_mem_tokens

          # hyper conection

@@ -494,16 +545,10 @@ class MemoryAsContextTransformer(Module):
                      **neural_memory_kwargs
                  )

-             attn = NeuralMemoryGatingWrapper(
-                 dim,
-                 attn = attn,
-                 neural_mem = mem,
-                 gate_attn_output = neural_mem_gate_attn_output
-             )
-
              ff = FeedForward(dim = dim, mult = ff_mult)

              self.layers.append(ModuleList([
+                 init_hyper_conn(dim = dim, branch = mem, add_branch_out_to_residual = not neural_mem_gate_attn_output) if exists(mem) else None,
                  init_hyper_conn(dim = dim, branch = attn),
                  init_hyper_conn(dim = dim, branch = ff)
              ]))
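
The wrapper's job is taken over by a third, optional entry in each layer's `ModuleList`: the neural memory is registered as its own hyper-connection branch, and `add_branch_out_to_residual = not neural_mem_gate_attn_output` decides whether the memory output is added back to the residual stream (additive mode) or withheld and only used to gate the attention output (gating mode). Read conceptually, and ignoring the hyper-connections stream bookkeeping, the wiring is roughly the following sketch; the class and argument names here are illustrative, not the package's API:

    import torch
    from torch import nn

    class MemThenAttnBlock(nn.Module):
        """Conceptual wiring only; the real code routes these through hyper-connections."""

        def __init__(self, mem: nn.Module, attn: nn.Module, gate_attn_output: bool = True):
            super().__init__()
            self.mem = mem
            self.attn = attn
            self.gate_attn_output = gate_attn_output

        def forward(self, x):
            retrieved = self.mem(x)

            if self.gate_attn_output:
                # memory output is *not* added to the residual; it only gates attention
                gates = retrieved.sigmoid()
            else:
                # memory output joins the residual stream instead
                x = x + retrieved
                gates = None

            attn_out = self.attn(x)

            if gates is not None:
                attn_out = attn_out * gates

            return x + attn_out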
@@ -512,6 +557,10 @@ class MemoryAsContextTransformer(Module):

          self.to_logits = LinearNoBias(dim, num_tokens)

+         # whether to gate the attention output with the retrieved memories
+
+         self.gate_attn_output = neural_mem_gate_attn_output
+
          # auxiliary loss on kv recon

          self.has_aux_kv_recon_loss = aux_kv_recon_loss_weight > 0.
@@ -524,7 +573,6 @@ class MemoryAsContextTransformer(Module):
          assert not (use_flex_attn and not exists(flex_attention)), 'you need to be on the latest pytorch with a cuda device available'
          self.use_flex_attn = use_flex_attn

-         self.segment_len = segment_len
          self.num_persist_mem_tokens = num_persist_mem_tokens

      @torch.no_grad()
@@ -606,7 +654,7 @@ class MemoryAsContextTransformer(Module):

          # math

-         batch, seq_len, neural_mem_segment_len, segment_len, num_longterm_mem_tokens = *x.shape, self.neural_memory_segment_len, self.segment_len, self.num_longterm_mem_tokens
+         batch, seq_len, neural_mem_segment_len, segment_len, num_longterm_mem_tokens, attn_window_size = *x.shape, self.neural_memory_segment_len, self.segment_len, self.num_longterm_mem_tokens, self.attn_window_size

          # token embedding

@@ -640,6 +688,12 @@ class MemoryAsContextTransformer(Module):
              block_mask = create_mac_block_mask(seq_len_with_mem, segment_len + num_longterm_mem_tokens, self.num_persist_mem_tokens, self.sliding_window_attn)
              flex_attn_fn = partial(flex_attention, block_mask = block_mask)

+         # kv caching
+
+         is_inferencing = exists(cache)
+         cache = iter(default(cache, []))
+         next_kv_caches = []
+
          # value residual

          value_residual = None
@@ -648,23 +702,48 @@ class MemoryAsContextTransformer(Module):

          kv_recon_losses = self.zero

+         # when inferencing, only do one token at a time
+
+         if is_inferencing:
+             x = x[:, -1:]
+
          # expand and reduce streams for hyper connections

          x = self.expand_streams(x)

-         for attn, ff in self.layers:
+         for mem, attn, ff in self.layers:

-             (x, values), maybe_mem_kv_aux_loss = attn(
+             retrieved = None
+             attn_out_gates = None
+
+             # maybe neural memory
+
+             if exists(mem):
+                 retrieved, mem_kv_aux_loss = mem(x, return_aux_kv_loss = True)
+                 kv_recon_losses = kv_recon_losses + mem_kv_aux_loss
+
+                 if self.gate_attn_output:
+                     attn_out_gates = retrieved.sigmoid()
+                 else:
+                     seq = retrieved
+
+             # attention
+
+             x, (values, next_kv_cache) = attn(
                  x,
                  value_residual = value_residual,
                  disable_flex_attn = disable_flex_attn,
-                 flex_attn_fn = flex_attn_fn
+                 flex_attn_fn = flex_attn_fn,
+                 output_gating = attn_out_gates,
+                 cache = next(cache, None)
              )

-             kv_recon_losses = kv_recon_losses + maybe_mem_kv_aux_loss
-
              value_residual = default(value_residual, values)

+             next_kv_caches.append(next_kv_cache)
+
+             # feedforward
+
              x = ff(x)

          x = self.reduce_streams(x)
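
In the layer loop, the incoming `cache` is wrapped with `iter(...)` and each attention layer pulls its own entry with `next(cache, None)`, so a training pass with no cache and the first decode step both transparently hand every layer `None`. The hand-off pattern on its own:

    num_layers = 4
    prev_caches = None                      # e.g. first forward pass, nothing cached yet

    cache_iter = iter(prev_caches if prev_caches is not None else [])
    per_layer = [next(cache_iter, None) for _ in range(num_layers)]
    assert per_layer == [None] * num_layers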
@@ -687,7 +766,16 @@ class MemoryAsContextTransformer(Module):

              if not return_cache:
                  return logits

-             return logits, cache
+             next_kv_caches = stack([stack(kv_cache) for kv_cache in next_kv_caches])
+
+             # handle kv cache length depending on local attention type
+
+             next_kv_caches = next_kv_caches[..., -attn_window_size:, :]
+
+             if not self.sliding_window_attn and divisible_by(seq_len_with_mem, attn_window_size):
+                 next_kv_caches = next_kv_caches[..., 0:0, :]
+
+             return logits, next_kv_caches

          ar_loss = F.cross_entropy(rearrange(logits, 'b n l -> b l n'), labels)