titans-pytorch 0.1.15__tar.gz → 0.1.18__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: titans-pytorch
- Version: 0.1.15
+ Version: 0.1.18
  Summary: Titans
  Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
  Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -38,7 +38,7 @@ Requires-Dist: accelerated-scan>=0.2.0
  Requires-Dist: axial-positional-embedding>=0.3.9
  Requires-Dist: einops>=0.8.0
  Requires-Dist: einx>=0.3.0
- Requires-Dist: hyper-connections>=0.1.8
+ Requires-Dist: hyper-connections>=0.1.9
  Requires-Dist: ninja
  Requires-Dist: rotary-embedding-torch
  Requires-Dist: tensordict
@@ -1,6 +1,6 @@
  [project]
  name = "titans-pytorch"
- version = "0.1.15"
+ version = "0.1.18"
  description = "Titans"
  authors = [
      { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -29,7 +29,7 @@ dependencies = [
      "axial_positional_embedding>=0.3.9",
      "einops>=0.8.0",
      "einx>=0.3.0",
-     "hyper-connections>=0.1.8",
+     "hyper-connections>=0.1.9",
      "Ninja",
      "rotary-embedding-torch",
      "tensordict",
@@ -3,7 +3,7 @@ from torch import nn

  import pytest
  from titans_pytorch import NeuralMemory
- from titans_pytorch.mac_transformer import flex_attention, SegmentedAttention
+ from titans_pytorch.mac_transformer import flex_attention, SegmentedAttention, MemoryAsContextTransformer

  def exists(v):
      return v is not None
@@ -92,8 +92,6 @@ def test_mac(
      num_longterm_mem_tokens,
      neural_mem_gate_attn_output
  ):
-     from titans_pytorch.mac_transformer import MemoryAsContextTransformer
-
      transformer = MemoryAsContextTransformer(
          num_tokens = 256,
          dim = 256,
@@ -109,6 +107,25 @@ def test_mac(
      logits = transformer(x)
      assert logits.shape == (1, seq_len, 256)

+ def test_mac_sampling():
+     transformer = MemoryAsContextTransformer(
+         num_tokens = 256,
+         dim = 256,
+         depth = 2,
+         segment_len = 32,
+         num_persist_mem_tokens = 4,
+         num_longterm_mem_tokens = 16,
+     )
+
+     ids = torch.randint(0, 256, (1, 1023))
+
+     # after much training
+
+     sampled = transformer.sample(ids[:, :4], 53, use_cache = False, temperature = 0.)
+     sampled_with_cache = transformer.sample(ids[:, :4], 53, use_cache = True, temperature = 0.)
+
+     assert torch.allclose(sampled, sampled_with_cache)
+
  @pytest.mark.parametrize('seq_len', (1023, 17))
  @pytest.mark.parametrize('sliding', (True, False))
  def test_flex(
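
For orientation, a minimal usage sketch of the sampling path exercised by the new test_mac_sampling above. The constructor arguments and the check that cached and uncached greedy decoding agree are taken directly from the test; nothing here goes beyond what the diff shows:

    import torch
    from titans_pytorch.mac_transformer import MemoryAsContextTransformer

    # same hyperparameters as test_mac_sampling
    transformer = MemoryAsContextTransformer(
        num_tokens = 256,
        dim = 256,
        depth = 2,
        segment_len = 32,
        num_persist_mem_tokens = 4,
        num_longterm_mem_tokens = 16,
    )

    prompt = torch.randint(0, 256, (1, 4))

    # greedy decoding (temperature = 0.) with and without the new cache path
    sampled = transformer.sample(prompt, 53, use_cache = False, temperature = 0.)
    sampled_with_cache = transformer.sample(prompt, 53, use_cache = True, temperature = 0.)

    assert torch.allclose(sampled, sampled_with_cache)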
@@ -217,7 +217,8 @@ class SegmentedAttention(Module):
          self,
          seq,
          value_residual = None,
-         flex_attn_fn: Callable | None = None
+         flex_attn_fn: Callable | None = None,
+         output_gating = None
      ):

          assert not (exists(value_residual) ^ exists(self.to_learned_v_mix))
@@ -267,6 +268,9 @@ class SegmentedAttention(Module):

          out = self.to_out(out)

+         if exists(output_gating):
+             out = out * output_gating
+
          return out, orig_v

      def forward(
@@ -274,10 +278,11 @@ class SegmentedAttention(Module):
          seq,
          value_residual = None,
          flex_attn_fn: Callable | None = None,
-         disable_flex_attn = False
+         disable_flex_attn = False,
+         output_gating = None
      ):
          if seq.is_cuda and self.use_flex_attn and not disable_flex_attn:
-             return self.forward_flex(seq, value_residual, flex_attn_fn)
+             return self.forward_flex(seq, value_residual, flex_attn_fn, output_gating = output_gating)

          assert not (exists(value_residual) ^ exists(self.to_learned_v_mix))

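
The new output_gating argument lets the caller scale the attention output elementwise, on both the flex-attention and the fallback path. A minimal sketch of the effect with plain tensors (shapes are illustrative, not taken from the diff):

    import torch

    attn_out  = torch.randn(1, 8, 16)   # stand-in attention output: (batch, seq, dim)
    retrieved = torch.randn(1, 8, 16)   # stand-in neural memory retrieval, same shape

    gates = retrieved.sigmoid()         # gates in (0, 1), as built in the MAC forward pass
    gated = attn_out * gates            # what `out = out * output_gating` computes

    assert gated.shape == attn_out.shape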
@@ -361,50 +366,10 @@ class SegmentedAttention(Module):

          out = inverse_segment(out)

-         return out, orig_v
-
- # Attention + Neural Memory gating configuration, as depicted in Figure 2
-
- class NeuralMemoryGatingWrapper(Module):
-     def __init__(
-         self,
-         dim,
-         attn: SegmentedAttention,
-         neural_mem: NeuralMemory | None = None,
-         gate_attn_output = True
-     ):
-         super().__init__()
-         self.attn = attn
-         self.neural_mem = neural_mem
-         self.gate_attn_output = gate_attn_output
-
-     def forward(
-         self,
-         seq,
-         *args,
-         **kwargs
-     ):
-         batch, seq_len = seq.shape[:2]
-         mem = self.neural_mem
-
-         if not exists(mem):
-             return self.attn(seq, *args, **kwargs), 0.
-
-         # initial retrieve, still should store first, it doesn't make sense not to, unless if all layers share the same neural memory
-
-         retrieved, kv_aux_loss = mem(seq, return_aux_kv_loss = True)
-
-         if not self.gate_attn_output:
-             seq = seq + retrieved
-
-         # attention
-
-         attn_out, values = self.attn(seq, *args, **kwargs)
-
-         if self.gate_attn_output:
-             attn_out = attn_out * retrieved.sigmoid()
+         if exists(output_gating):
+             out = out * output_gating

-         return (attn_out, values), kv_aux_loss
+         return out, orig_v

  # MAC transformer

@@ -494,16 +459,10 @@ class MemoryAsContextTransformer(Module):
                      **neural_memory_kwargs
                  )

-             attn = NeuralMemoryGatingWrapper(
-                 dim,
-                 attn = attn,
-                 neural_mem = mem,
-                 gate_attn_output = neural_mem_gate_attn_output
-             )
-
              ff = FeedForward(dim = dim, mult = ff_mult)

              self.layers.append(ModuleList([
+                 init_hyper_conn(dim = dim, branch = mem, add_branch_out_to_residual = not neural_mem_gate_attn_output) if exists(mem) else None,
                  init_hyper_conn(dim = dim, branch = attn),
                  init_hyper_conn(dim = dim, branch = ff)
              ]))
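
With the wrapper removed, the neural memory becomes its own hyper-connection branch in each layer's ModuleList (None for layers without memory), and add_branch_out_to_residual decides whether the retrieved memories are added back to the residual stream (gating disabled) or withheld so they can gate the attention output instead (gating enabled). A rough sketch of the intended two wirings, with plain residual additions standing in for the hyper-connection streams and mem / attn as placeholder callables:

    import torch

    def layer_step(x, mem, attn, gate_attn_output: bool):
        # mem(x) and attn(x) are assumed to return tensors shaped like x
        retrieved = mem(x)

        if gate_attn_output:
            # memory gates attention and does not feed the residual stream
            attn_out = attn(x) * retrieved.sigmoid()
        else:
            # memory output joins the residual stream; attention is left ungated
            x = x + retrieved
            attn_out = attn(x)

        return x + attn_out

    # example: identity "memory" and "attention" on a random stream
    x = torch.randn(2, 8, 16)
    y = layer_step(x, mem = lambda t: t, attn = lambda t: t, gate_attn_output = True)
    assert y.shape == x.shape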
@@ -512,6 +471,10 @@ class MemoryAsContextTransformer(Module):

          self.to_logits = LinearNoBias(dim, num_tokens)

+         # whether to gate the attention output with the retrieved memories
+
+         self.gate_attn_output = neural_mem_gate_attn_output
+
          # auxiliary loss on kv recon

          self.has_aux_kv_recon_loss = aux_kv_recon_loss_weight > 0.
@@ -537,7 +500,8 @@ class MemoryAsContextTransformer(Module):
          filter_kwargs: dict = dict(
              min_p = 0.1,
          ),
-         show_progress = True
+         show_progress = True,
+         use_cache = False
      ):
          was_training = self.training
          self.eval()
@@ -547,8 +511,37 @@ class MemoryAsContextTransformer(Module):

          iter_wrap = tqdm.tqdm if show_progress else identity

+         # cache for axial pos, attention, and neural memory
+
+         cache = None
+         factorized_pos_emb = None
+
+         # precompute factorized pos emb
+
+         if use_cache:
+             round_up_seq_len = round_up_multiple(seq_len, self.segment_len)
+             longterm_mem_lens = (round_up_seq_len // self.segment_len) * self.num_longterm_mem_tokens
+             seq_len_with_mem = round_up_seq_len + longterm_mem_lens
+
+             axial_dims = self.axial_pos_emb.maybe_derive_outer_dim(seq_len_with_mem, (self.neural_memory_segment_len,))
+
+             factorized_pos_emb = self.axial_pos_emb(axial_dims, return_factorized = True)
+
+         # sample
+
          for _ in iter_wrap(range(sample_num_times)):
-             logits = self.forward(out, disable_flex_attn = True)
+
+             logits, next_cache = self.forward(
+                 out,
+                 disable_flex_attn = True,
+                 cache = cache,
+                 return_cache = True,
+                 factorized_pos_emb = factorized_pos_emb
+             )
+
+             if use_cache:
+                 cache = next_cache
+
              logits = logits[:, -1]

              logits = filter_fn(logits, **filter_kwargs)
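
sample now threads an optional cache through successive forward calls and, when use_cache is set, precomputes the factorized axial positional embedding once for the rounded-up sequence length. Stripped of min-p filtering and the progress bar, the loop follows the usual incremental-decoding pattern; a simplified greedy sketch (not the library's own sample, which filters and samples stochastically), assuming model is a MemoryAsContextTransformer and prompt is a (batch, seq) LongTensor of token ids:

    import torch

    @torch.no_grad()
    def greedy_sample(model, prompt, num_tokens):
        out = prompt
        cache = None

        for _ in range(num_tokens):
            # forward returns (logits, cache) when return_cache = True and return_loss = False
            logits, cache = model(out, disable_flex_attn = True, cache = cache, return_cache = True)

            next_token = logits[:, -1].argmax(dim = -1, keepdim = True)
            out = torch.cat((out, next_token), dim = -1)

        return out[:, prompt.shape[-1]:]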
@@ -565,7 +558,10 @@ class MemoryAsContextTransformer(Module):
          x,
          return_loss = False,
          return_loss_breakdown = False,
-         disable_flex_attn = False
+         disable_flex_attn = False,
+         cache = None,
+         return_cache = False,
+         factorized_pos_emb = None
      ):

          if return_loss:
@@ -593,7 +589,7 @@ class MemoryAsContextTransformer(Module):
          # apply axial positional embedding
          # so intra and inter segment can be more easily discerned by the network

-         pos_emb = self.axial_pos_emb.forward_with_seq_len(seq_len_with_mem, (neural_mem_segment_len,))
+         pos_emb = self.axial_pos_emb.forward_with_seq_len(seq_len_with_mem, (neural_mem_segment_len,), factorized = factorized_pos_emb)

          x = x + pos_emb

@@ -619,19 +615,34 @@ class MemoryAsContextTransformer(Module):

          x = self.expand_streams(x)

-         for attn, ff in self.layers:
+         for mem, attn, ff in self.layers:
+
+             retrieved = None
+             attn_out_gates = None
+             if exists(mem):
+                 retrieved, mem_kv_aux_loss = mem(x, return_aux_kv_loss = True)
+                 kv_recon_losses = kv_recon_losses + mem_kv_aux_loss

-             (x, values), maybe_mem_kv_aux_loss = attn(
+                 if self.gate_attn_output:
+                     attn_out_gates = retrieved.sigmoid()
+                 else:
+                     seq = retrieved
+
+             # attention
+
+             x, values = attn(
                  x,
                  value_residual = value_residual,
                  disable_flex_attn = disable_flex_attn,
-                 flex_attn_fn = flex_attn_fn
+                 flex_attn_fn = flex_attn_fn,
+                 output_gating = attn_out_gates
              )

-             kv_recon_losses = kv_recon_losses + maybe_mem_kv_aux_loss
-
              value_residual = default(value_residual, values)

+             # feedforward
+
              x = ff(x)

          x = self.reduce_streams(x)
@@ -651,7 +662,10 @@ class MemoryAsContextTransformer(Module):
          logits = self.to_logits(x)

          if not return_loss:
-             return logits
+             if not return_cache:
+                 return logits
+
+             return logits, cache

          ar_loss = F.cross_entropy(rearrange(logits, 'b n l -> b l n'), labels)
