titans-pytorch 0.1.14.tar.gz → 0.1.17.tar.gz

This diff shows the content differences between publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.1.14
+Version: 0.1.17
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -35,7 +35,7 @@ Classifier: Programming Language :: Python :: 3.9
 Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
 Requires-Python: >=3.9
 Requires-Dist: accelerated-scan>=0.2.0
-Requires-Dist: axial-positional-embedding>=0.3.7
+Requires-Dist: axial-positional-embedding>=0.3.9
 Requires-Dist: einops>=0.8.0
 Requires-Dist: einx>=0.3.0
 Requires-Dist: hyper-connections>=0.1.8
@@ -1,6 +1,6 @@
 [project]
 name = "titans-pytorch"
-version = "0.1.14"
+version = "0.1.17"
 description = "Titans"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -26,7 +26,7 @@ classifiers=[
 
 dependencies = [
     "accelerated-scan>=0.2.0",
-    "axial_positional_embedding>=0.3.7",
+    "axial_positional_embedding>=0.3.9",
     "einops>=0.8.0",
     "einx>=0.3.0",
     "hyper-connections>=0.1.8",
@@ -3,7 +3,7 @@ from torch import nn
 
 import pytest
 from titans_pytorch import NeuralMemory
-from titans_pytorch.mac_transformer import flex_attention, SegmentedAttention
+from titans_pytorch.mac_transformer import flex_attention, SegmentedAttention, MemoryAsContextTransformer
 
 def exists(v):
     return v is not None
@@ -92,8 +92,6 @@ def test_mac(
     num_longterm_mem_tokens,
     neural_mem_gate_attn_output
 ):
-    from titans_pytorch.mac_transformer import MemoryAsContextTransformer
-
     transformer = MemoryAsContextTransformer(
         num_tokens = 256,
         dim = 256,
@@ -109,6 +107,25 @@ def test_mac(
     logits = transformer(x)
     assert logits.shape == (1, seq_len, 256)
 
+def test_mac_sampling():
+    transformer = MemoryAsContextTransformer(
+        num_tokens = 256,
+        dim = 256,
+        depth = 2,
+        segment_len = 32,
+        num_persist_mem_tokens = 4,
+        num_longterm_mem_tokens = 16,
+    )
+
+    ids = torch.randint(0, 256, (1, 1023))
+
+    # after much training
+
+    sampled = transformer.sample(ids[:, :4], 53, use_cache = False, temperature = 0.)
+    sampled_with_cache = transformer.sample(ids[:, :4], 53, use_cache = True, temperature = 0.)
+
+    assert torch.allclose(sampled, sampled_with_cache)
+
 @pytest.mark.parametrize('seq_len', (1023, 17))
 @pytest.mark.parametrize('sliding', (True, False))
 def test_flex(
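Note: sampling with temperature = 0. is what makes this test meaningful: the filter-then-sample step degenerates to a deterministic argmax over the vocabulary, so the cached and uncached decodes of the same prompt must agree token for token. A minimal sketch of the degenerate case (illustrative, not the package's filter code):

    import torch

    logits = torch.tensor([[2.0, 0.5, -1.0]])
    # at temperature = 0. the sample is just the argmax, hence reproducible
    assert logits.argmax(dim = -1).item() == 0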
@@ -537,7 +537,8 @@ class MemoryAsContextTransformer(Module):
         filter_kwargs: dict = dict(
             min_p = 0.1,
         ),
-        show_progress = True
+        show_progress = True,
+        use_cache = False
     ):
         was_training = self.training
         self.eval()
@@ -547,8 +548,37 @@ class MemoryAsContextTransformer(Module):
 
         iter_wrap = tqdm.tqdm if show_progress else identity
 
+        # cache for axial pos, attention, and neural memory
+
+        cache = None
+        factorized_pos_emb = None
+
+        # precompute factorized pos emb
+
+        if use_cache:
+            round_up_seq_len = round_up_multiple(seq_len, self.segment_len)
+            longterm_mem_lens = (round_up_seq_len // self.segment_len) * self.num_longterm_mem_tokens
+            seq_len_with_mem = round_up_seq_len + longterm_mem_lens
+
+            axial_dims = self.axial_pos_emb.maybe_derive_outer_dim(seq_len_with_mem, (self.neural_memory_segment_len,))
+
+            factorized_pos_emb = self.axial_pos_emb(axial_dims, return_factorized = True)
+
+        # sample
+
         for _ in iter_wrap(range(sample_num_times)):
-            logits = self.forward(out, disable_flex_attn = True)
+
+            logits, next_cache = self.forward(
+                out,
+                disable_flex_attn = True,
+                cache = cache,
+                return_cache = True,
+                factorized_pos_emb = factorized_pos_emb
+            )
+
+            if use_cache:
+                cache = next_cache
+
             logits = logits[:, -1]
 
             logits = filter_fn(logits, **filter_kwargs)
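Note: the precompute block rounds the prompt-plus-generation length up to a whole number of segments so the factorized embedding covers every decode step, then adds the longterm memory tokens that get interleaved per segment. round_up_multiple is a helper from the package; a minimal sketch of its assumed behavior:

    def round_up_multiple(n, mult):
        # smallest multiple of mult that is >= n (assumed semantics)
        return ((n + mult - 1) // mult) * mult

    # e.g. segment_len = 32, num_longterm_mem_tokens = 16:
    assert round_up_multiple(1023, 32) == 1024    # 32 segments
    assert (1024 // 32) * 16 == 512               # interleaved memory tokens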
@@ -565,7 +595,10 @@ class MemoryAsContextTransformer(Module):
         x,
         return_loss = False,
         return_loss_breakdown = False,
-        disable_flex_attn = False
+        disable_flex_attn = False,
+        cache = None,
+        return_cache = False,
+        factorized_pos_emb = None
     ):
 
         if return_loss:
@@ -593,7 +626,7 @@ class MemoryAsContextTransformer(Module):
         # apply axial positional embedding
         # so intra and inter segment can be more easily discerned by the network
 
-        pos_emb = self.axial_pos_emb.forward_with_seq_len(seq_len_with_mem, (neural_mem_segment_len,))
+        pos_emb = self.axial_pos_emb.forward_with_seq_len(seq_len_with_mem, (neural_mem_segment_len,), factorized = factorized_pos_emb)
 
         x = x + pos_emb
 
@@ -651,7 +684,10 @@ class MemoryAsContextTransformer(Module):
         logits = self.to_logits(x)
 
         if not return_loss:
-            return logits
+            if not return_cache:
+                return logits
+
+            return logits, cache
 
         ar_loss = F.cross_entropy(rearrange(logits, 'b n l -> b l n'), labels)
 
@@ -301,6 +301,45 @@ class MemoryAttention(Module):
 
         return out
 
+# associative scan wrapper
+
+class AssocScan(Module):
+    def __init__(
+        self,
+        use_accelerated = False
+    ):
+        super().__init__()
+        self.use_accelerated = use_accelerated
+
+    def forward(self, gates, inputs):
+
+        if not self.use_accelerated:
+            _, outputs = associative_scan(binary_operator, (gates, inputs))
+            return outputs
+
+        from accelerated_scan.triton import scan as triton_scan
+        from accelerated_scan.warp import scan as warp_scan
+
+        scan = triton_scan if gates.is_cuda else warp_scan
+
+        def accelerate_scan_fn(gates, inputs):
+            gates = gates.expand_as(inputs)
+            gates, inputs = tuple(rearrange(t, 'b n d -> b d n') for t in (gates, inputs))
+
+            seq_len = gates.shape[-1]
+            next_power_two_seq_len = 2 ** max(5, int(math.ceil(math.log2(seq_len))))
+
+            gates = F.pad(gates, (0, next_power_two_seq_len - seq_len))
+            inputs = F.pad(inputs, (0, next_power_two_seq_len - seq_len))
+
+            outputs = scan(gates.contiguous(), inputs.contiguous())
+
+            outputs = outputs[..., :seq_len]
+            outputs = rearrange(outputs, 'b d n -> b n d')
+            return outputs
+
+        return accelerate_scan_fn(gates, inputs)
+
 # main neural memory
 
 def default_adaptive_step_transform(adaptive_step, max_lr = 1e-2):
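Note: AssocScan folds the two existing scan strategies behind one Module: the default path runs the pure-PyTorch associative_scan with binary_operator, while the accelerated path pads the sequence length up to a power of two (minimum 32) before calling the accelerated-scan kernels. Both compute the same first-order linear recurrence, out[t] = gates[t] * out[t-1] + inputs[t]. A naive reference loop, handy as a correctness oracle (a sketch, not part of the package):

    import torch

    def naive_scan(gates, inputs):
        # out[t] = gates[t] * out[t - 1] + inputs[t], scanned left to right
        out, outs = torch.zeros_like(inputs[:, 0]), []
        for t in range(inputs.shape[1]):
            out = gates[:, t] * out + inputs[:, t]
            outs.append(out)
        return torch.stack(outs, dim = 1)

    gates, inputs = torch.rand(1, 8, 4), torch.randn(1, 8, 4)
    ref = naive_scan(gates, inputs)   # compare against AssocScan()(gates, inputs)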
@@ -339,6 +378,10 @@ class NeuralMemory(Module):
 
         self.retrieve_chunk_size, self.store_chunk_size = pair(chunk_size)
 
+        # associative scan
+
+        self.assoc_scan = AssocScan(use_accelerated = use_accelerated_scan)
+
         # norms
 
         self.retrieve_norm = nn.RMSNorm(dim) if pre_rmsnorm else nn.Identity()
@@ -564,38 +607,6 @@ class NeuralMemory(Module):
 
         surprises = grads.apply(lambda t: -t)
 
-        # determine scan function
-
-        def default_associative_scan(gates, inputs):
-            _, outputs = associative_scan(binary_operator, (gates, inputs))
-            return outputs
-
-        if self.use_accelerated_scan:
-            from accelerated_scan.triton import scan as triton_scan
-            from accelerated_scan.warp import scan as warp_scan
-
-            scan = triton_scan if seq.is_cuda else warp_scan
-
-            def accelerate_scan_fn(gates, inputs):
-                gates = gates.expand_as(inputs)
-                gates, inputs = tuple(rearrange(t, 'b n d -> b d n') for t in (gates, inputs))
-
-                seq_len = gates.shape[-1]
-                next_power_two_seq_len = 2 ** max(5, int(math.ceil(math.log2(seq_len))))
-
-                gates = F.pad(gates, (0, next_power_two_seq_len - seq_len))
-                inputs = F.pad(inputs, (0, next_power_two_seq_len - seq_len))
-
-                outputs = scan(gates.contiguous(), inputs.contiguous())
-
-                outputs = outputs[..., :seq_len]
-                outputs = rearrange(outputs, 'b d n -> b n d')
-                return outputs
-
-            scan_fn = accelerate_scan_fn
-        else:
-            scan_fn = default_associative_scan
-
         # momentum + weight decay - momentum is the new contribution, as most linear RNNs have learned forgetting gates
 
         next_momentum = TensorDict() if has_momentum else None
@@ -610,12 +621,12 @@ class NeuralMemory(Module):
             # derive momentum with associative scan - eq (10)
 
             if has_momentum:
-                update = scan_fn(adaptive_momentum, surprise) # momentum is S / surprise in the paper
+                update = self.assoc_scan(adaptive_momentum, surprise) # momentum is S / surprise in the paper
                 momentum = update
 
             # use associative scan again for learned forgetting (weight decay) - eq (13)
 
-            update = scan_fn(1. - decay_factor, update)
+            update = self.assoc_scan(1. - decay_factor, update)
 
             updates[param_name] = inverse_pack(update)
 
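Note: the two assoc_scan calls realize equations (10) and (13) of the Titans paper as linear recurrences: momentum accumulates the negated, lr-scaled gradient surprise, and a learned forgetting gate decays the accumulated update. Unrolled, with eta the adaptive momentum and alpha the decay factor (a sketch of the recurrences, not the package code):

    # eq (10): S_t = eta_t * S_{t-1} + surprise_t     (momentum scan)
    # eq (13): M_t = (1 - alpha_t) * M_{t-1} + S_t    (forgetting scan)
    def unrolled_update(etas, alphas, surprises):
        S = M = 0.
        for eta, alpha, surprise in zip(etas, alphas, surprises):
            S = eta * S + surprise
            M = (1. - alpha) * M + S
        return M

    assert abs(unrolled_update([0.9, 0.9], [0.1, 0.1], [1., 1.]) - 2.8) < 1e-9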