titans-pytorch 0.1.0__tar.gz → 0.1.2__tar.gz

This diff shows the changes between two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.1.0
+Version: 0.1.2
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -43,9 +43,9 @@ Requires-Dist: ninja
 Requires-Dist: rotary-embedding-torch
 Requires-Dist: tensordict
 Requires-Dist: torch>=2.2
+Requires-Dist: tqdm
 Requires-Dist: x-transformers
 Provides-Extra: examples
-Requires-Dist: tqdm; extra == 'examples'
 Requires-Dist: wandb; extra == 'examples'
 Provides-Extra: test
 Requires-Dist: pytest; extra == 'test'
@@ -104,7 +104,12 @@ transformer = MemoryAsContextTransformer(
 
 token_ids = torch.randint(0, 256, (1, 1023))
 
-logits = transformer(token_ids) # (1, 1023, 256)
+loss = transformer(token_ids, return_loss = True) # (1, 1023, 256)
+loss.backward()
+
+# after much training
+
+sampled = transformer.sample(token_ids[:, :4], 512)
 ```
 
 ## Experiments
@@ -51,7 +51,12 @@ transformer = MemoryAsContextTransformer(
 
 token_ids = torch.randint(0, 256, (1, 1023))
 
-logits = transformer(token_ids) # (1, 1023, 256)
+loss = transformer(token_ids, return_loss = True) # (1, 1023, 256)
+loss.backward()
+
+# after much training
+
+sampled = transformer.sample(token_ids[:, :4], 512)
 ```
 
 ## Experiments
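The same change appears in both the packaged metadata README and README.md: the forward call that returned logits now returns a training loss, and generation moves to a new `.sample` method. A minimal end-to-end sketch of the updated usage follows; the constructor arguments are illustrative assumptions, since the hunk header only shows the truncated `transformer = MemoryAsContextTransformer(` line.

import torch
from titans_pytorch import MemoryAsContextTransformer

# constructor values below are assumptions for illustration; only the class name
# and the lines after the hunk header come from the diff itself
transformer = MemoryAsContextTransformer(
    num_tokens = 256,
    dim = 256,
    depth = 2,
    segment_len = 128,
    num_persist_mem_tokens = 4,
    num_longterm_mem_tokens = 4
)

token_ids = torch.randint(0, 256, (1, 1023))

loss = transformer(token_ids, return_loss = True)   # scalar training loss
loss.backward()

# after much training

sampled = transformer.sample(token_ids[:, :4], 512)  # autoregressive continuation of the 4-token prompt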
@@ -1,6 +1,6 @@
 [project]
 name = "titans-pytorch"
-version = "0.1.0"
+version = "0.1.2"
 description = "Titans"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -34,6 +34,7 @@ dependencies = [
     "rotary-embedding-torch",
     "tensordict",
     "torch>=2.2",
+    "tqdm",
     "x-transformers"
 ]
 
@@ -44,7 +45,6 @@ Repository = "https://github.com/lucidrains/titans-pytorch"
 [project.optional-dependencies]
 
 examples = [
-    "tqdm",
     "wandb"
 ]
 
@@ -3,6 +3,10 @@ from torch import nn
 
 import pytest
 from titans_pytorch import NeuralMemory
+from titans_pytorch.mac_transformer import flex_attention, SegmentedAttention
+
+def exists(v):
+    return v is not None
 
 @pytest.mark.parametrize('seq_len', (32, 1024, 77))
 @pytest.mark.parametrize('silu', (False, True))
@@ -46,10 +50,12 @@ def test_titans_attn_memory():
 
     assert seq.shape == retrieved.shape
 
+@pytest.mark.parametrize('seq_len', (1023, 17))
 @pytest.mark.parametrize('num_persist_mem_tokens', (0, 16))
 @pytest.mark.parametrize('num_longterm_mem_tokens', (0, 16))
 @pytest.mark.parametrize('neural_mem_gate_attn_output', (False, True))
 def test_mac(
+    seq_len,
     num_persist_mem_tokens,
     num_longterm_mem_tokens,
     neural_mem_gate_attn_output
@@ -66,7 +72,32 @@ def test_mac(
         neural_mem_gate_attn_output = neural_mem_gate_attn_output
     )
 
-    x = torch.randint(0, 256, (1, 1023))
+    x = torch.randint(0, 256, (1, seq_len))
 
     logits = transformer(x)
-    assert logits.shape == (1, 1023, 256)
+    assert logits.shape == (1, seq_len, 256)
+
+@pytest.mark.parametrize('seq_len', (1023, 17))
+@pytest.mark.parametrize('sliding', (True, False))
+def test_flex(
+    seq_len,
+    sliding
+):
+    if not (torch.cuda.is_available() and exists(flex_attention)):
+        pytest.skip()
+
+    attn = SegmentedAttention(
+        dim = 512,
+        segment_len = 32,
+        num_persist_mem_tokens = 1,
+        num_longterm_mem_tokens = 1,
+        use_flex_attn = True,
+        sliding = sliding
+    ).cuda()
+
+    seq = torch.randn(1, seq_len, 512).cuda()
+
+    out_flex, _ = attn(seq)
+    out_non_flex, _ = attn(seq, disable_flex_attn = True)
+
+    assert torch.allclose(out_flex, out_non_flex, atol = 1e-5)
@@ -3,6 +3,8 @@ from typing import Callable
 from math import ceil
 from functools import partial
 
+import tqdm
+
 import torch
 from torch import nn, cat
 import torch.nn.functional as F
@@ -88,12 +90,6 @@ def pad_at_dim(t, pad, dim = -1, value = 0.):
 
 def pad_and_segment_with_inverse(seq, segment_len, fold_into_batch = True):
     batch, seq_len = seq.shape[:2]
-
-    need_segment = seq_len >= segment_len
-
-    if not need_segment:
-        return seq, identity
-
     next_seq_len_mult = round_up_multiple(seq_len, segment_len)
 
     padding = next_seq_len_mult - seq_len
@@ -116,6 +112,29 @@ def pad_and_segment_with_inverse(seq, segment_len, fold_into_batch = True):
 
     return seq, inverse
 
+# sampling related
+
+def log(t, eps = 1e-20):
+    return torch.log(t.clamp(min = eps))
+
+def gumbel_noise(t):
+    noise = torch.rand_like(t)
+    return -log(-log(noise))
+
+def gumbel_sample(t, temperature = 1.):
+    if temperature > 0.:
+        t = t / temperature + gumbel_noise(t)
+    return t.argmax(dim = -1, keepdim = True)
+
+# min_p
+# https://arxiv.org/abs/2407.01082
+
+def min_p_filter(logits, min_p = 0.1):
+    probs = logits.softmax(dim = -1)
+    max_probs = probs.amax(dim = -1, keepdim = True)
+    limit = min_p * max_probs
+    return torch.where(probs < limit, float('-inf'), logits)
+
 # feedforward and attention
 
 class GEGLU(Module):
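The hunk above moves the sampling helpers into titans_pytorch.mac_transformer itself. A small sketch of what the new min_p_filter and gumbel_sample do, assuming both names remain importable from that module; the logits values are made up for illustration.

import torch
# both helpers are added at module level in the hunk above; importing them this
# way assumes they stay public names in titans_pytorch.mac_transformer
from titans_pytorch.mac_transformer import min_p_filter, gumbel_sample

logits = torch.tensor([[2.0, 1.0, -1.0, -4.0]])

probs = logits.softmax(dim = -1)                   # roughly [0.70, 0.26, 0.035, 0.002]
filtered = min_p_filter(logits, min_p = 0.1)       # threshold ~0.07, so the last two logits become -inf

token = gumbel_sample(filtered, temperature = 1.)  # shape (1, 1), index drawn from the surviving tokens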
@@ -500,6 +519,39 @@ class MemoryAsContextTransformer(Module):
         self.segment_len = segment_len
         self.num_persist_mem_tokens = num_persist_mem_tokens
 
+    @torch.no_grad()
+    def sample(
+        self,
+        prompt: Tensor,
+        seq_len: int,
+        temperature = 1.5,
+        filter_fn: Callable = min_p_filter,
+        filter_kwargs: dict = dict(
+            min_p = 0.1,
+        ),
+        show_progress = True
+    ):
+        was_training = self.training
+        self.eval()
+
+        prompt_seq_len, out = prompt.shape[-1], prompt.clone()
+        sample_num_times = max(0, seq_len - prompt_seq_len)
+
+        iter_wrap = tqdm.tqdm if show_progress else identity
+
+        for _ in iter_wrap(range(sample_num_times)):
+            logits = self.forward(out, disable_flex_attn = True)
+            logits = logits[:, -1]
+
+            logits = filter_fn(logits, **filter_kwargs)
+            sample = gumbel_sample(logits, temperature = temperature)
+
+            out = torch.cat((out, sample), dim = -1)
+
+        self.train(was_training)
+
+        return out[..., prompt_seq_len:]
+
     def forward(
         self,
         x,
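A hedged usage sketch of the new sample method with the keyword arguments visible in the hunk above; `transformer` and `token_ids` are assumed to come from the README example, and the chosen values are illustrative only.

from titans_pytorch.mac_transformer import min_p_filter  # default filter_fn, added in the earlier hunk

prompt = token_ids[:, :4]

generated = transformer.sample(
    prompt,
    512,                                  # target total length; the loop runs max(0, 512 - 4) steps
    temperature = 1.,                     # defaults to 1.5; 0. skips the gumbel noise and becomes greedy argmax
    filter_fn = min_p_filter,             # any callable taking (logits, **filter_kwargs) works
    filter_kwargs = dict(min_p = 0.1),
    show_progress = False                 # disables the tqdm progress bar
)

# only the continuation is returned, the prompt tokens are sliced off
assert generated.shape[-1] == 512 - prompt.shape[-1]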
@@ -592,7 +592,12 @@ class NeuralMemory(Module):
         batch, seq_len = seq.shape[:2]
 
         if seq_len < self.chunk_size:
-            return self.init_empty_memory_embed(batch, seq_len)
+            out = self.init_empty_memory_embed(batch, seq_len)
+
+            if not return_aux_kv_loss:
+                return out
+
+            return out, self.zero
 
         if exists(past_state):
             past_state = tuple(TensorDict(d) for d in past_state)
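The change above makes the short-sequence early return honor `return_aux_kv_loss`, handing back a (memory embeddings, zero loss) pair instead of a bare tensor, so callers that unpack two values do not break when the sequence is shorter than the chunk size. A sketch of that calling pattern, under the assumption that the edited method is NeuralMemory.forward and that the constructor takes dim and chunk_size as in the project README; the sizes are illustrative.

import torch
from titans_pytorch import NeuralMemory

# dim / chunk_size values are assumptions; return_aux_kv_loss is the flag named in the hunk
mem = NeuralMemory(dim = 384, chunk_size = 64)

short_seq = torch.randn(2, 16, 384)   # 16 < chunk_size, so the early-return branch above fires

retrieved, aux_kv_loss = mem(short_seq, return_aux_kv_loss = True)   # aux loss is the module's zero tensor here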
@@ -68,52 +68,6 @@ def decode_token(token):
 def decode_tokens(tokens):
     return ''.join(list(map(decode_token, tokens)))
 
-# sampling helpers
-
-def log(t, eps = 1e-20):
-    return torch.log(t.clamp(min = eps))
-
-def gumbel_noise(t):
-    noise = torch.zeros_like(t).uniform_(0, 1)
-    return -log(-log(noise))
-
-def gumbel_sample(t, temperature = 1., keepdim = True):
-    if temperature <= 0.:
-        return t.argmax(dim = dim, keepdim = keepdim)
-
-    return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim = -1, keepdim = keepdim)
-
-# min_p
-# https://arxiv.org/abs/2407.01082
-
-def min_p_filter(logits, min_p = 0.1):
-    probs = logits.softmax(dim = -1)
-    max_probs = probs.amax(dim = -1, keepdim = True)
-    limit = min_p * max_probs
-    return torch.where(probs < limit, float('-inf'), logits)
-
-def base_decoding(
-    net,
-    prompt: Tensor,
-    seq_len: int,
-    temperature = 1.5,
-    min_p = 1e-1,
-    filter_thres = 0.9,
-):
-    prompt_seq_len, out = prompt.shape[-1], prompt.clone()
-    sample_num_times = max(0, seq_len - prompt_seq_len)
-
-    for _ in tqdm.tqdm(range(sample_num_times)):
-        logits = net(out, disable_flex_attn = True)
-        logits = logits[:, -1]
-
-        logits = min_p_filter(logits, min_p = min_p)
-        sample = gumbel_sample(logits, temperature = temperature)
-
-        out = torch.cat((out, sample), dim = -1)
-
-    return out[..., prompt_seq_len:]
-
 # instantiate memory-as-context transformer
 
 model = MemoryAsContextTransformer(
@@ -197,6 +151,6 @@ for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
         prime = decode_tokens(inp)
         print(f'%s \n\n %s', (prime, '*' * 100))
 
-        sample = base_decoding(model, inp[None, ...], GENERATE_LENGTH)
+        sample = model.sample(inp[None, ...], GENERATE_LENGTH)
         output_str = decode_tokens(sample[0])
         print(output_str)
@@ -1,18 +0,0 @@
-import torch
-from titans_pytorch.mac_transformer import SegmentedAttention
-
-attn = SegmentedAttention(
-    dim = 512,
-    segment_len = 32,
-    num_persist_mem_tokens = 1,
-    num_longterm_mem_tokens = 1,
-    use_flex_attn = True,
-    sliding = False
-).cuda()
-
-seq = torch.randn(1, 1019, 512).cuda()
-
-out_flex, _ = attn(seq)
-out_non_flex, _ = attn(seq, disable_flex_attn = True)
-
-assert torch.allclose(out_flex, out_non_flex, atol = 1e-5)