titans-pytorch 0.0.50__tar.gz → 0.0.52__tar.gz

This diff shows the changes between two publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: titans-pytorch
- Version: 0.0.50
+ Version: 0.0.52
  Summary: Titans
  Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
  Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -135,3 +135,12 @@ $ python train_mac.py
  year = {2024}
  }
  ```
+
+ ```bibtex
+ @inproceedings{Yang2024GatedDN,
+ title = {Gated Delta Networks: Improving Mamba2 with Delta Rule},
+ author = {Songlin Yang and Jan Kautz and Ali Hatamizadeh},
+ year = {2024},
+ url = {https://api.semanticscholar.org/CorpusID:274598177}
+ }
+ ```
@@ -82,3 +82,12 @@ $ python train_mac.py
  year = {2024}
  }
  ```
+
+ ```bibtex
+ @inproceedings{Yang2024GatedDN,
+ title = {Gated Delta Networks: Improving Mamba2 with Delta Rule},
+ author = {Songlin Yang and Jan Kautz and Ali Hatamizadeh},
+ year = {2024},
+ url = {https://api.semanticscholar.org/CorpusID:274598177}
+ }
+ ```
@@ -1,6 +1,6 @@
  [project]
  name = "titans-pytorch"
- version = "0.0.50"
+ version = "0.0.52"
  description = "Titans"
  authors = [
      { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -1,16 +1,21 @@
  import torch
+ from torch import nn
+
  import pytest
  from titans_pytorch import NeuralMemory

  @pytest.mark.parametrize('seq_len', (32, 1024, 77))
+ @pytest.mark.parametrize('silu', (False, True))
  @pytest.mark.parametrize('max_grad_norm', (None, 2.))
  def test_titans(
      seq_len,
-     max_grad_norm
+     silu,
+     max_grad_norm,
  ):
      mem = NeuralMemory(
          dim = 384,
          chunk_size = 64,
+         activation = nn.SiLU() if silu else None,
          max_grad_norm = max_grad_norm
      )

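The new `silu` flag in the test exercises the `activation` argument that 0.0.52 adds to `NeuralMemory` (wired into the query and key/value projections further down in this diff). A minimal usage sketch mirroring the test; the shapes are illustrative and the round-trip `retrieved = mem(seq)` call follows the project README rather than anything shown in this diff:

```python
import torch
from torch import nn
from titans_pytorch import NeuralMemory

mem = NeuralMemory(
    dim = 384,
    chunk_size = 64,
    activation = nn.SiLU()   # new in 0.0.52; None keeps the plain linear projections
)

seq = torch.randn(2, 1024, 384)
retrieved = mem(seq)         # retrieve from the neural memory

assert retrieved.shape == seq.shape
```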
@@ -7,16 +7,48 @@ from torch import nn, cat
  import torch.nn.functional as F
  from torch.nn import Module, ModuleList, Linear

+ # flex attention
+ # https://pytorch.org/blog/flexattention/
+
+ flex_attention = None
+
+ try:
+     from torch.nn.attention.flex_attention import flex_attention, create_block_mask
+     if torch.cuda.is_available():
+         flex_attention = torch.compile(flex_attention)
+ except ImportError:
+     pass
+
+ def create_mac_block_mask(seq_len, window_size, persist_mem_len):
+
+     def create_mac_mask(b, h, q_idx, kv_idx):
+         is_persist_mem = kv_idx < persist_mem_len
+         causal_mask = q_idx >= (kv_idx - is_persist_mem)
+         block_diagonal = (q_idx // window_size) == ((kv_idx - is_persist_mem) // window_size)
+         return is_persist_mem | (~is_persist_mem & (causal_mask & block_diagonal))
+
+     block_mask = create_block_mask(create_mac_mask, B = None, H = None, Q_LEN = seq_len, KV_LEN = seq_len + persist_mem_len, _compile = True)
+     return block_mask
+
+ # einstein notation related
+
  from einops import einsum, repeat, rearrange, pack, unpack
  from einops.layers.torch import Rearrange

- from hyper_connections import get_init_and_expand_reduce_stream_functions
+ # b - batch
+ # n - sequence
+ # h - heads
+ # d - feature dimension

  # absolute and relative positions

  from axial_positional_embedding import ContinuousAxialPositionalEmbedding
  from rotary_embedding_torch import RotaryEmbedding
+
+ # hyper connections / attend from x-transformers, which handles different queries and key lengths better
+
  from x_transformers.attend import Attend
+ from hyper_connections import get_init_and_expand_reduce_stream_functions

  # proposed neural memory

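The heart of the new flex attention support is `create_mac_block_mask`: every query may attend to the persistent-memory prefix of the keys/values, and otherwise attends causally only within its own window. As a reading aid, here is a hypothetical dense boolean-mask sketch of that intended pattern (not package code; note it offsets `kv_idx` by the full `persist_mem_len`, where the mask mod above subtracts the boolean flag):

```python
import torch

def dense_mac_mask(seq_len, window_size, persist_mem_len):
    # queries index the sequence, keys/values index persistent memory + sequence
    q_idx = torch.arange(seq_len)[:, None]
    kv_idx = torch.arange(seq_len + persist_mem_len)[None, :]

    is_persist_mem = kv_idx < persist_mem_len
    seq_kv_idx = kv_idx - persist_mem_len                       # kv position within the sequence proper

    causal = q_idx >= seq_kv_idx
    block_diagonal = (q_idx // window_size) == (seq_kv_idx // window_size)

    # persistent memory is always visible, the rest is causal within each window
    return is_persist_mem | (~is_persist_mem & causal & block_diagonal)

print(dense_mac_mask(seq_len = 8, window_size = 4, persist_mem_len = 2).int())
```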
@@ -96,6 +128,7 @@ class SegmentedAttention(Module):
          heads = 8,
          accept_value_residual = False,
          attend_kwargs: dict = dict(),
+         use_flex_attn = False
      ):
          super().__init__()
          self.norm = nn.RMSNorm(dim)
@@ -125,11 +158,79 @@ class SegmentedAttention(Module):

          self.persistent_memory = nn.Parameter(torch.zeros(2, heads, num_persist_mem_tokens, dim_head))

+         # flex attn related
+
+         assert not (use_flex_attn and not exists(flex_attention)), 'you need to be on the latest pytorch with a cuda device available'
+         self.use_flex_attn = use_flex_attn
+
+         self.segment_len = segment_len
+         self.num_persist_mem_tokens = num_persist_mem_tokens
+
+     def forward_flex(
+         self,
+         seq,
+         value_residual = None,
+         flex_attn_fn: Callable | None = None
+     ):
+
+         assert not (exists(value_residual) ^ exists(self.to_learned_v_mix))
+
+         batch, seq_len = seq.shape[:2]
+
+         # attention
+
+         seq = self.norm(seq)
+
+         q, k, v = self.to_qkv(seq).chunk(3, dim = -1)
+         q, k, v = map(self.split_heads, (q, k, v))
+
+         # value residual
+
+         orig_v = v
+
+         if exists(self.to_learned_v_mix):
+             mix = self.to_learned_v_mix(seq)
+             v = v.lerp(value_residual, mix)
+
+         # take care of persistent memory key / values
+
+         pmk, pmv = repeat(self.persistent_memory, 'kv h n d -> kv b h n d', b = batch)
+
+         # relative positions
+
+         q, k = self.rotary_emb.rotate_queries_with_cached_keys(q, k)
+
+         # persistent memory
+
+         k = cat((pmk, k), dim = -2)
+         v = cat((pmv, v), dim = -2)
+
+         # prep flex attention
+
+         if not exists(flex_attn_fn):
+             block_mask = create_mac_block_mask(seq_len, self.segment_len, self.num_persist_mem_tokens)
+
+             flex_attn_fn = partial(flex_attention, block_mask = block_mask)
+
+         # attention
+
+         out = flex_attn_fn(q, k, v)
+
+         out = self.merge_heads(out)
+
+         out = self.to_out(out)
+
+         return out, orig_v
+
      def forward(
          self,
          seq,
-         value_residual = None
+         value_residual = None,
+         flex_attn_fn: Callable | None = None
      ):
+         if seq.is_cuda and self.use_flex_attn:
+             return self.forward_flex(seq, value_residual, flex_attn_fn)
+
          assert not (exists(value_residual) ^ exists(self.to_learned_v_mix))

          segment_len, num_longterm_mem_tokens = self.segment_len, self.num_longterm_mem_tokens
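Two steps of the new `forward_flex` are easy to check in isolation on CPU: blending in the value residual from the first attention layer with a learned per-token mix, and prepending persistent-memory keys/values so they are visible to every query. A self-contained sketch with assumed shapes (not the package defaults); `mix` stands in for the output of `to_learned_v_mix`:

```python
import torch
from einops import repeat

batch, heads, seq_len, dim_head, num_persist = 2, 8, 64, 64, 4

v = torch.randn(batch, heads, seq_len, dim_head)
value_residual = torch.randn_like(v)           # values carried over from the first layer
mix = torch.rand(batch, heads, seq_len, 1)     # stands in for the learned per-token mix

v = v.lerp(value_residual, mix)                # v * (1 - mix) + value_residual * mix

persistent_memory = torch.zeros(2, heads, num_persist, dim_head)
pmk, pmv = repeat(persistent_memory, 'kv h n d -> kv b h n d', b = batch)

k = torch.randn(batch, heads, seq_len, dim_head)
k = torch.cat((pmk, k), dim = -2)              # persistent memory keys sit in front of the sequence
v = torch.cat((pmv, v), dim = -2)

assert k.shape == v.shape == (batch, heads, seq_len + num_persist, dim_head)
```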
@@ -159,7 +260,7 @@ class SegmentedAttention(Module):

          # take care of persistent memory key / values

-         pmk, pmv = tuple(repeat(t, 'h n d -> b h n d', b = seq.shape[0]) for t in self.persistent_memory)
+         pmk, pmv = repeat(self.persistent_memory, 'kv ... -> kv b ...', b = seq.shape[0])

          # relative positions

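This hunk replaces a per-tensor loop with a single einops `repeat` over the stacked persistent-memory parameter. A quick sanity check with hypothetical shapes, showing the two forms produce identical tensors:

```python
import torch
from einops import repeat

persistent_memory = torch.randn(2, 8, 4, 64)   # (kv, heads, num persist tokens, dim head)
batch = 3

# 0.0.50: repeat each of the two stacked tensors separately
old_pmk, old_pmv = tuple(repeat(t, 'h n d -> b h n d', b = batch) for t in persistent_memory)

# 0.0.52: one repeat over the leading kv dimension
new_pmk, new_pmv = repeat(persistent_memory, 'kv ... -> kv b ...', b = batch)

assert torch.equal(old_pmk, new_pmk) and torch.equal(old_pmv, new_pmv)
```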
@@ -56,6 +56,17 @@ def pack_one_with_inverse(t, pattern):

      return packed, inverse

+ def Sequential(*modules):
+     modules = [*filter(exists, modules)]
+
+     if len(modules) == 0:
+         return nn.Identity()
+
+     if len(modules) == 1:
+         return modules[0]
+
+     return nn.Sequential(*modules)
+
  # softclamping gradients

  def softclamp_max(t, max_value):
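The new `Sequential` helper filters out `None` modules, which is what lets the optional `activation` be appended to the query and key/value projections later in this diff without special-casing. A small self-contained sketch of the same helper (the `exists` utility is assumed, as in the package):

```python
import torch
from torch import nn

def exists(v):
    return v is not None

def Sequential(*modules):
    modules = [*filter(exists, modules)]

    if len(modules) == 0:
        return nn.Identity()

    if len(modules) == 1:
        return modules[0]

    return nn.Sequential(*modules)

proj = Sequential(nn.Linear(384, 768, bias = False), None)            # None is dropped: plain Linear
proj_act = Sequential(nn.Linear(384, 768, bias = False), nn.SiLU())   # Linear followed by SiLU

x = torch.randn(2, 64, 384)
assert proj(x).shape == proj_act(x).shape == (2, 64, 768)
```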
@@ -124,9 +135,6 @@ class MemoryAttention(Module):
          ])

      def forward(self, x):
-
-         assert x.shape[-2] > 1, 'chunk size needs to be greater than 1 for using attention as memory'
-
          wq, wk, wv, ffw1, ffw2 = self.weights

          q = F.normalize(x @ wq, dim = -1)
@@ -168,6 +176,7 @@ class NeuralMemory(Module):
          post_rmsnorm = True,
          max_grad_norm: float | None = None,
          use_accelerated_scan = False,
+         activation: Module | None = None,
          default_model_kwargs: dict = dict(
              depth = 2
          )
@@ -225,11 +234,11 @@ class NeuralMemory(Module):

          # queries for retrieving from the model

-         self.to_queries = LinearNoBias(dim, dim_inner)
+         self.to_queries = Sequential(LinearNoBias(dim, dim_inner), activation)

          # keys and values for storing to the model

-         self.to_keys_values = LinearNoBias(dim, dim_inner * 2)
+         self.to_keys_values = Sequential(LinearNoBias(dim, dim_inner * 2), activation)
          self.store_memory_loss_fn = store_memory_loss_fn

          # empty memory embed
@@ -4,7 +4,7 @@ import gzip
  import numpy as np

  import torch
- from torch import nn
+ from torch import nn, Tensor
  from torch.optim import Adam
  from torch.nn import functional as F
  from torch.utils.data import DataLoader, Dataset
@@ -19,12 +19,13 @@ GRADIENT_ACCUMULATE_EVERY = 4
  LEARNING_RATE = 2e-4
  VALIDATE_EVERY = 100
  GENERATE_EVERY = 500
+ PRIME_LENGTH = 100
  GENERATE_LENGTH = 512
- SHOULD_GENERATE = False
+ SHOULD_GENERATE = True
  SEQ_LEN = 512

  PROJECT_NAME = 'titans-mac-transformer'
- WANDB_ONLINE = True # turn this on to pipe experiment to cloud
+ WANDB_ONLINE = False # turn this on to pipe experiment to cloud
  NEURAL_MEMORY_DEPTH = 2
  NUM_PERSIST_MEM = 4
  NUM_LONGTERM_MEM = 4
@@ -53,6 +54,52 @@ def decode_token(token):
  def decode_tokens(tokens):
      return ''.join(list(map(decode_token, tokens)))

+ # sampling helpers
+
+ def log(t, eps = 1e-20):
+     return torch.log(t.clamp(min = eps))
+
+ def gumbel_noise(t):
+     noise = torch.zeros_like(t).uniform_(0, 1)
+     return -log(-log(noise))
+
+ def gumbel_sample(t, temperature = 1., keepdim = True):
+     if temperature <= 0.:
+         return t.argmax(dim = -1, keepdim = keepdim)
+
+     return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim = -1, keepdim = keepdim)
+
+ # min_p
+ # https://arxiv.org/abs/2407.01082
+
+ def min_p_filter(logits, min_p = 0.1):
+     probs = logits.softmax(dim = -1)
+     max_probs = probs.amax(dim = -1, keepdim = True)
+     limit = min_p * max_probs
+     return torch.where(probs < limit, float('-inf'), logits)
+
+ def base_decoding(
+     net,
+     prompt: Tensor,
+     seq_len: int,
+     temperature = 1.5,
+     min_p = 1e-1,
+     filter_thres = 0.9,
+ ):
+     prompt_seq_len, out = prompt.shape[-1], prompt.clone()
+     sample_num_times = max(0, seq_len - prompt_seq_len)
+
+     for _ in tqdm.tqdm(range(sample_num_times)):
+         logits = net(out)
+         logits = logits[:, -1]
+
+         logits = min_p_filter(logits, min_p = min_p)
+         sample = gumbel_sample(logits, temperature = temperature)
+
+         out = torch.cat((out, sample), dim = -1)
+
+     return out[..., prompt_seq_len:]
+
  # instantiate memory-as-context transformer
  model = MemoryAsContextTransformer(
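The sampling helpers added to the training script implement min_p filtering (keep only tokens whose probability is at least `min_p` times that of the most likely token) followed by Gumbel-max sampling with temperature. A self-contained sketch of the same two steps on toy logits (values are made up for illustration):

```python
import torch

logits = torch.tensor([[2.0, 1.9, 0.1, -3.0]])

# min_p filter: drop tokens whose probability is below min_p * (top probability)
probs = logits.softmax(dim = -1)
limit = 0.1 * probs.amax(dim = -1, keepdim = True)
filtered = torch.where(probs < limit, float('-inf'), logits)

# gumbel-max sampling with temperature: perturb the scaled logits, take the argmax
noise = torch.rand_like(filtered).clamp(min = 1e-20)
gumbel = -torch.log(-torch.log(noise))
sample = ((filtered / 1.5) + gumbel).argmax(dim = -1, keepdim = True)

print(filtered)   # only the lowest-probability token is masked to -inf
print(sample)     # shape (1, 1), ready to be concatenated onto the running sequence
```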
@@ -127,10 +174,10 @@ for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):

      if SHOULD_GENERATE and i % GENERATE_EVERY == 0:
          model.eval()
-         inp = random.choice(val_dataset)[:-1]
+         inp = random.choice(val_dataset)[:PRIME_LENGTH]
          prime = decode_tokens(inp)
          print(f'%s \n\n %s', (prime, '*' * 100))

-         sample = model.generate(inp[None, ...], GENERATE_LENGTH, use_kv_cache = False)
+         sample = base_decoding(model, inp[None, ...], GENERATE_LENGTH)
          output_str = decode_tokens(sample[0])
          print(output_str)