titans-pytorch 0.0.50__tar.gz → 0.0.51__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.0.50
+Version: 0.0.51
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -135,3 +135,12 @@ $ python train_mac.py
 year = {2024}
 }
 ```
+
+```bibtex
+@inproceedings{Yang2024GatedDN,
+title = {Gated Delta Networks: Improving Mamba2 with Delta Rule},
+author = {Songlin Yang and Jan Kautz and Ali Hatamizadeh},
+year = {2024},
+url = {https://api.semanticscholar.org/CorpusID:274598177}
+}
+```
@@ -82,3 +82,12 @@ $ python train_mac.py
 year = {2024}
 }
 ```
+
+```bibtex
+@inproceedings{Yang2024GatedDN,
+title = {Gated Delta Networks: Improving Mamba2 with Delta Rule},
+author = {Songlin Yang and Jan Kautz and Ali Hatamizadeh},
+year = {2024},
+url = {https://api.semanticscholar.org/CorpusID:274598177}
+}
+```
@@ -1,6 +1,6 @@
 [project]
 name = "titans-pytorch"
-version = "0.0.50"
+version = "0.0.51"
 description = "Titans"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -1,16 +1,21 @@
 import torch
+from torch import nn
+
 import pytest
 from titans_pytorch import NeuralMemory
 
 @pytest.mark.parametrize('seq_len', (32, 1024, 77))
+@pytest.mark.parametrize('silu', (False, True))
 @pytest.mark.parametrize('max_grad_norm', (None, 2.))
 def test_titans(
     seq_len,
-    max_grad_norm
+    silu,
+    max_grad_norm,
 ):
     mem = NeuralMemory(
         dim = 384,
         chunk_size = 64,
+        activation = nn.SiLU() if silu else None,
         max_grad_norm = max_grad_norm
     )
 
@@ -7,16 +7,48 @@ from torch import nn, cat
 import torch.nn.functional as F
 from torch.nn import Module, ModuleList, Linear
 
+# flex attention
+# https://pytorch.org/blog/flexattention/
+
+flex_attention = None
+
+try:
+    from torch.nn.attention.flex_attention import flex_attention, create_block_mask
+    if torch.cuda.is_available():
+        flex_attention = torch.compile(flex_attention)
+except ImportError:
+    pass
+
+def create_mac_block_mask(seq_len, window_size, persist_mem_len):
+
+    def create_mac_mask(b, h, q_idx, kv_idx):
+        is_persist_mem = kv_idx < persist_mem_len
+        causal_mask = q_idx >= (kv_idx - is_persist_mem)
+        block_diagonal = (q_idx // window_size) == ((kv_idx - is_persist_mem) // window_size)
+        return is_persist_mem | (~is_persist_mem & (causal_mask & block_diagonal))
+
+    block_mask = create_block_mask(create_mac_mask, B = None, H = None, Q_LEN = seq_len, KV_LEN = seq_len + persist_mem_len, _compile = True)
+    return block_mask
+
+# einstein notation related
+
 from einops import einsum, repeat, rearrange, pack, unpack
 from einops.layers.torch import Rearrange
 
-from hyper_connections import get_init_and_expand_reduce_stream_functions
+# b - batch
+# n - sequence
+# h - heads
+# d - feature dimension
 
 # absolute and relative positions
 
 from axial_positional_embedding import ContinuousAxialPositionalEmbedding
 from rotary_embedding_torch import RotaryEmbedding
+
+# hyper connections / attend from x-transformers, which handles different queries and key lengths better
+
 from x_transformers.attend import Attend
+from hyper_connections import get_init_and_expand_reduce_stream_functions
 
 # proposed neural memory
 
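The `create_mac_mask` predicate above lets a query attend to a key either because the key is one of the persistent-memory slots, or because the key is causally visible and falls in the same window block as the query (key positions run over the persistent memories followed by the sequence). A minimal standalone sketch, not part of the package, that evaluates the same predicate as a dense boolean tensor so the pattern can be inspected without flex attention or a GPU:

```python
import torch

# dense re-creation of the create_mac_mask predicate above, for inspection only
def dense_mac_mask(seq_len, window_size, persist_mem_len):
    q_idx = torch.arange(seq_len)[:, None]                     # query positions over the sequence
    kv_idx = torch.arange(seq_len + persist_mem_len)[None, :]  # key positions: [persistent mem | sequence]

    is_persist_mem = kv_idx < persist_mem_len
    causal_mask = q_idx >= (kv_idx - is_persist_mem.long())
    block_diagonal = (q_idx // window_size) == ((kv_idx - is_persist_mem.long()) // window_size)

    return is_persist_mem | (~is_persist_mem & (causal_mask & block_diagonal))

print(dense_mac_mask(seq_len = 8, window_size = 4, persist_mem_len = 2).int())
```
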
@@ -56,6 +56,17 @@ def pack_one_with_inverse(t, pattern):
 
     return packed, inverse
 
+def Sequential(*modules):
+    modules = [*filter(exists, modules)]
+
+    if len(modules) == 0:
+        return nn.Identity()
+
+    if len(modules) == 1:
+        return modules[0]
+
+    return nn.Sequential(*modules)
+
 # softclamping gradients
 
 def softclamp_max(t, max_value):
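The `Sequential` helper above filters out `None` entries before wrapping, which is what lets an optional activation be threaded through unconditionally later on. A short standalone illustration of its three cases (`exists` is assumed to be the usual `v is not None` helper defined earlier in the file):

```python
from torch import nn

def exists(v):
    return v is not None

def Sequential(*modules):
    modules = [*filter(exists, modules)]
    if len(modules) == 0:
        return nn.Identity()
    if len(modules) == 1:
        return modules[0]
    return nn.Sequential(*modules)

print(Sequential(None))                        # nn.Identity(): nothing to run
print(Sequential(nn.Linear(4, 8), None))       # the bare Linear, no wrapper
print(Sequential(nn.Linear(4, 8), nn.SiLU()))  # nn.Sequential(Linear, SiLU)
```
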
@@ -124,9 +135,6 @@ class MemoryAttention(Module):
         ])
 
     def forward(self, x):
-
-        assert x.shape[-2] > 1, 'chunk size needs to be greater than 1 for using attention as memory'
-
         wq, wk, wv, ffw1, ffw2 = self.weights
 
         q = F.normalize(x @ wq, dim = -1)
@@ -168,6 +176,7 @@ class NeuralMemory(Module):
         post_rmsnorm = True,
         max_grad_norm: float | None = None,
         use_accelerated_scan = False,
+        activation: Module | None = None,
         default_model_kwargs: dict = dict(
             depth = 2
         )
@@ -225,11 +234,11 @@
 
         # queries for retrieving from the model
 
-        self.to_queries = LinearNoBias(dim, dim_inner)
+        self.to_queries = Sequential(LinearNoBias(dim, dim_inner), activation)
 
         # keys and values for storing to the model
 
-        self.to_keys_values = LinearNoBias(dim, dim_inner * 2)
+        self.to_keys_values = Sequential(LinearNoBias(dim, dim_inner * 2), activation)
         self.store_memory_loss_fn = store_memory_loss_fn
 
         # empty memory embed
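With both projections wrapped in `Sequential`, the new `activation` argument adds an optional non-linearity after the query and key/value projections, while `None` (the default) keeps the previous purely linear behaviour. A hedged usage sketch mirroring the parameters in the updated test; that the forward call returns a retrieved tensor of the same shape as the input is an assumption based on the test's usage, not code shown in this diff:

```python
import torch
from torch import nn
from titans_pytorch import NeuralMemory

mem = NeuralMemory(
    dim = 384,
    chunk_size = 64,
    activation = nn.SiLU(),  # new keyword in this version; None reproduces the old behaviour
)

seq = torch.randn(2, 1024, 384)
retrieved = mem(seq)         # assumption: retrieved memories come back with the same shape as seq
print(retrieved.shape)
```
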
@@ -4,7 +4,7 @@ import gzip
 import numpy as np
 
 import torch
-from torch import nn
+from torch import nn, Tensor
 from torch.optim import Adam
 from torch.nn import functional as F
 from torch.utils.data import DataLoader, Dataset
@@ -19,12 +19,13 @@ GRADIENT_ACCUMULATE_EVERY = 4
 LEARNING_RATE = 2e-4
 VALIDATE_EVERY = 100
 GENERATE_EVERY = 500
+PRIME_LENGTH = 100
 GENERATE_LENGTH = 512
-SHOULD_GENERATE = False
+SHOULD_GENERATE = True
 SEQ_LEN = 512
 
 PROJECT_NAME = 'titans-mac-transformer'
-WANDB_ONLINE = True # turn this on to pipe experiment to cloud
+WANDB_ONLINE = False # turn this on to pipe experiment to cloud
 NEURAL_MEMORY_DEPTH = 2
 NUM_PERSIST_MEM = 4
 NUM_LONGTERM_MEM = 4
@@ -53,6 +54,52 @@ def decode_token(token):
 def decode_tokens(tokens):
     return ''.join(list(map(decode_token, tokens)))
 
+# sampling helpers
+
+def log(t, eps = 1e-20):
+    return torch.log(t.clamp(min = eps))
+
+def gumbel_noise(t):
+    noise = torch.zeros_like(t).uniform_(0, 1)
+    return -log(-log(noise))
+
+def gumbel_sample(t, temperature = 1., keepdim = True):
+    if temperature <= 0.:
+        return t.argmax(dim = -1, keepdim = keepdim)
+
+    return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim = -1, keepdim = keepdim)
+
+# min_p
+# https://arxiv.org/abs/2407.01082
+
+def min_p_filter(logits, min_p = 0.1):
+    probs = logits.softmax(dim = -1)
+    max_probs = probs.amax(dim = -1, keepdim = True)
+    limit = min_p * max_probs
+    return torch.where(probs < limit, float('-inf'), logits)
+
+def base_decoding(
+    net,
+    prompt: Tensor,
+    seq_len: int,
+    temperature = 1.5,
+    min_p = 1e-1,
+    filter_thres = 0.9,
+):
+    prompt_seq_len, out = prompt.shape[-1], prompt.clone()
+    sample_num_times = max(0, seq_len - prompt_seq_len)
+
+    for _ in tqdm.tqdm(range(sample_num_times)):
+        logits = net(out)
+        logits = logits[:, -1]
+
+        logits = min_p_filter(logits, min_p = min_p)
+        sample = gumbel_sample(logits, temperature = temperature)
+
+        out = torch.cat((out, sample), dim = -1)
+
+    return out[..., prompt_seq_len:]
+
 # instantiate memory-as-context transformer
 
 model = MemoryAsContextTransformer(
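`min_p_filter` keeps only tokens whose probability is at least `min_p` times that of the most likely token, so the cutoff adapts to how peaked the distribution is before the gumbel sampling step draws from it. A small standalone sketch on toy logits (the values are illustrative only):

```python
import torch

def min_p_filter(logits, min_p = 0.1):
    probs = logits.softmax(dim = -1)
    max_probs = probs.amax(dim = -1, keepdim = True)
    limit = min_p * max_probs
    return torch.where(probs < limit, float('-inf'), logits)

logits = torch.tensor([[4.0, 3.5, 1.0, -2.0]])
print(min_p_filter(logits, min_p = 0.1))
# the last two tokens fall below 10% of the top token's probability,
# so they are pushed to -inf and can never be sampled
```
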
@@ -127,10 +174,10 @@ for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
 
     if SHOULD_GENERATE and i % GENERATE_EVERY == 0:
         model.eval()
-        inp = random.choice(val_dataset)[:-1]
+        inp = random.choice(val_dataset)[:PRIME_LENGTH]
         prime = decode_tokens(inp)
         print(f'%s \n\n %s', (prime, '*' * 100))
 
-        sample = model.generate(inp[None, ...], GENERATE_LENGTH, use_kv_cache = False)
+        sample = base_decoding(model, inp[None, ...], GENERATE_LENGTH)
         output_str = decode_tokens(sample[0])
         print(output_str)