titans-pytorch 0.0.3__tar.gz → 0.0.4__tar.gz

This diff covers the publicly released contents of the two package versions and reflects the changes between them as they appear in the public registry. It is provided for informational purposes only.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.0.3
+Version: 0.0.4
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -49,7 +49,7 @@ Description-Content-Type: text/markdown

 ## Titans - Pytorch (wip)

-Unofficial implementation of [Titans](https://arxiv.org/abs/2501.00663) in Pytorch. Will also contain some explorations into architectures beyond their simple 1-4 layer MLP for the neural memory module.
+Unofficial implementation of [Titans](https://arxiv.org/abs/2501.00663) in Pytorch. Will also contain some explorations into architectures beyond their simple 1-4 layer MLP for the neural memory module, if it works well to any degree.

 ## Install

@@ -4,7 +4,7 @@

 ## Titans - Pytorch (wip)

-Unofficial implementation of [Titans](https://arxiv.org/abs/2501.00663) in Pytorch. Will also contain some explorations into architectures beyond their simple 1-4 layer MLP for the neural memory module.
+Unofficial implementation of [Titans](https://arxiv.org/abs/2501.00663) in Pytorch. Will also contain some explorations into architectures beyond their simple 1-4 layer MLP for the neural memory module, if it works well to any degree.

 ## Install

@@ -1,6 +1,6 @@
 [project]
 name = "titans-pytorch"
-version = "0.0.3"
+version = "0.0.4"
 description = "Titans"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -16,7 +16,7 @@ from titans_pytorch.associative_scan import (

 import einx
 from einops import rearrange, pack, unpack
-from einops.layers.torch import Rearrange
+from einops.layers.torch import Rearrange, Reduce

 """
 ein notation:
@@ -84,6 +84,7 @@ class NeuralMemory(Module):
     def __init__(
         self,
         dim,
+        chunk_size = 1,
         model: Module | None = None,
         store_memory_loss_fn: Callable = default_loss_fn
     ):
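For context, a minimal construction sketch of the new argument, with arbitrary sizes that are not taken from this diff; it assumes `NeuralMemory` is re-exported at the package root and that a default memory MLP is built internally when `model` is left as `None`:

```python
import torch
from titans_pytorch import NeuralMemory

# arbitrary sizes for illustration; chunk_size = 1 (the default) keeps the 0.0.3 per-token behavior,
# while larger values share one adaptive step / momentum / decay across each chunk of tokens
mem = NeuralMemory(dim = 384, chunk_size = 2)

seq = torch.randn(2, 64, 384)  # (batch, tokens, dim); tokens ideally a multiple of chunk_size
```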
@@ -98,11 +99,15 @@ class NeuralMemory(Module):

         self.memory_model = model

+        # the chunk size within the paper where adaptive step, momentum, weight decay are shared
+
+        self.chunk_size = chunk_size
+
         # prepare function for per sample gradients from model above, using torch.func

         def forward_and_loss(params, inputs, target):
             pred = functional_call(self.memory_model, params, inputs)
-            loss = self.store_memory_loss_fn(pred, target) # simple mse loss in paper - eq (12) - |M(k) == v|²
+            loss = self.store_memory_loss_fn(pred, target) # simple mse loss in paper - eq (12) - |M(k) - v|²
             return loss

         self.per_sample_grad_and_value_fn = vmap(grad_and_value(forward_and_loss), in_dims = (None, 0, 0))
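`default_loss_fn` itself is unchanged and not part of this diff; per the comment it implements the paper's eq (12), a simple mean squared error between the memory model's prediction for the keys and the target values. A hedged stand-in, for illustration only (the exact reduction used by the package may differ):

```python
import torch.nn.functional as F

def default_loss_fn(pred, target):
    # |M(k) - v|^2 from eq (12); an assumed stand-in, not the package's definition
    return F.mse_loss(pred, target)
```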
@@ -119,9 +124,23 @@ class NeuralMemory(Module):
         # learned adaptive learning rate and momentum
         # todo - explore mlp layerwise learned lr / momentum

-        self.to_momentum = LinearNoBias(dim, 1)
-        self.to_adaptive_step = nn.Sequential(LinearNoBias(dim, 1), Rearrange('... 1 -> ...'))
-        self.to_decay_factor = LinearNoBias(dim, 1) # weight decay factor
+        self.to_momentum = nn.Sequential(
+            Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size),
+            LinearNoBias(dim, 1)
+        )
+
+        self.to_adaptive_step = nn.Sequential(
+            Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size),
+            LinearNoBias(dim, 1),
+            Rearrange('... 1 -> ...')
+        )
+
+        # weight decay factor
+
+        self.to_decay_factor = nn.Sequential(
+            Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size),
+            LinearNoBias(dim, 1)
+        )

     def init_weights_and_momentum(self):
         params = TensorDict(dict(self.memory_model.named_parameters()))
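To see what the new `Reduce` stage does, here is a small self-contained sketch with arbitrary sizes, and with `nn.Linear(..., bias = False)` standing in for the package's `LinearNoBias` helper: each chunk of `c` tokens is mean-pooled to a single position, so the adaptive step, momentum and decay factor are predicted once per chunk rather than once per token.

```python
import torch
from torch import nn
from einops.layers.torch import Reduce

chunk_size = 4
dim = 8

# same pattern as the diff: mean-pool each chunk of tokens, then project to one scalar per chunk
to_momentum = nn.Sequential(
    Reduce('b (n c) d -> b n d', 'mean', c = chunk_size),
    nn.Linear(dim, 1, bias = False)
)

seq = torch.randn(2, 16, dim)    # 16 tokens = 4 chunks of 4
print(to_momentum(seq).shape)    # torch.Size([2, 4, 1])
```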
@@ -137,6 +156,16 @@ class NeuralMemory(Module):
         past_state: tuple[dict[str, Tensor], dict[str, Tensor]]
     ):

+        # curtail sequence by multiple of the chunk size
+        # only a complete chunk of the sequence provides the memory for the next chunk
+
+        seq_len = seq.shape[-2]
+        round_down_seq_len = round_down_multiple(seq_len, self.chunk_size)
+
+        seq = seq[:, :round_down_seq_len]
+
+        # curr weights + past weights, in the case that the initial weights are learned
+
         curr_weights = TensorDict(dict(self.memory_model.named_parameters()))

         past_state = tuple(TensorDict(d) for d in past_state)
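`round_down_multiple` is an existing helper that is not shown in this diff; given how it is used here, it presumably amounts to flooring the sequence length to the nearest multiple of the chunk size, along these lines:

```python
def round_down_multiple(seq_len, multiple):
    # e.g. round_down_multiple(10, 4) == 8: only the two complete chunks of 4 are kept
    return seq_len // multiple * multiple
```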
@@ -148,16 +177,19 @@ class NeuralMemory(Module):

         batch = seq.shape[0]

-        adaptive_lr = self.to_adaptive_step(seq).tanh() * 0.5 + 0.5.
+        adaptive_lr = self.to_adaptive_step(seq).tanh() * 0.5 + 0.5

         adaptive_momentum = self.to_momentum(seq).sigmoid()
         decay_factor = self.to_decay_factor(seq).sigmoid()

         # keys and values

-        seq = rearrange(seq, 'b n d -> (b n) d')
         keys, values = self.to_keys_values(seq).chunk(2, dim = -1)

+        # take care of chunking
+
+        keys, values = tuple(rearrange(t, 'b (n c) d -> (b n) c d', c = self.chunk_size) for t in (keys, values))
+
         # get grads and extra auxiliary loss (for backwarding through qkv projection in base neural memory module)

         grads, aux_store_loss = self.per_sample_grad_and_value_fn(dict(curr_weights), keys, values)
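Shape-wise, the new rearrange folds the chunk index into the batch dimension, so the vmapped per-sample gradient call now treats each chunk, rather than each single token, as one training example for the memory model. A toy check of the shapes (sizes are arbitrary):

```python
import torch
from einops import rearrange

batch, num_chunks, chunk_size, dim = 2, 3, 4, 8

keys = torch.randn(batch, num_chunks * chunk_size, dim)            # (b, n*c, d)
keys = rearrange(keys, 'b (n c) d -> (b n) c d', c = chunk_size)   # fold chunks into the batch
print(keys.shape)                                                  # torch.Size([6, 4, 8])
```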
@@ -172,31 +204,24 @@ class NeuralMemory(Module):

         surprises = grads.apply(lambda t: einx.multiply('b n ..., b n -> b n ...', t, -adaptive_lr))

-        # derive momentum with associative scan - eq (10)
+        # momentum + weight decay - momentum is the new contribution, as most linear RNNs have learned forgetting gates

         next_momentum = TensorDict()
+        updates = TensorDict()

         for param_name, surprise in surprises.items():
             surprise, inverse_pack = pack_one_with_inverse(surprise, 'b n *')

-            _, momentum = associative_scan(binary_operator, (adaptive_momentum, surprise)) # momentum is S / surprise in the paper
-
-            momentum = inverse_pack(momentum)
-
-            next_momentum[param_name] = momentum
+            # derive momentum with associative scan - eq (10)

-        # use associative scan again for learned forgetting (weight decay) - eq (13)
-
-        updates = TensorDict()
+            _, momentum = associative_scan(binary_operator, (adaptive_momentum, surprise)) # momentum is S / surprise in the paper

-        for param_name, momentum in next_momentum.items():
-            momentum, inverse_pack = pack_one_with_inverse(momentum, 'b n *')
+            # use associative scan again for learned forgetting (weight decay) - eq (13)

             _, update = associative_scan(binary_operator, (1. - decay_factor, momentum)) # momentum is S / surprise in the paper

-            update = inverse_pack(update)
-
-            updates[param_name] = update
+            updates[param_name] = inverse_pack(update)
+            next_momentum[param_name] = inverse_pack(momentum)

         # compute the next weight per batch

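For readers new to the scan formulation: both calls to `associative_scan` realize a first-order linear recurrence over the chunk axis, eq (10) for momentum and eq (13) for weight decay. A plain-loop sketch of the same semantics on a single packed tensor, assuming a zero initial state (this is a reference illustration, not the package's `associative_scan`):

```python
import torch

def linear_recurrence(gates, inputs):
    # h_t = gates_t * h_{t-1} + inputs_t, scanned along dim 1 (the chunk axis)
    out, h = [], torch.zeros_like(inputs[:, 0])
    for t in range(inputs.shape[1]):
        h = gates[:, t] * h + inputs[:, t]
        out.append(h)
    return torch.stack(out, dim = 1)

# eq (10): momentum_t = adaptive_momentum_t * momentum_{t-1} + surprise_t
#          (surprise is already the gradient scaled by -adaptive_lr)
# eq (13): update_t = (1 - decay_factor_t) * update_{t-1} + momentum_t
surprise          = torch.randn(2, 8, 16)   # (batch, chunks, packed param dims)
adaptive_momentum = torch.rand(2, 8, 1)
decay_factor      = torch.rand(2, 8, 1)

momentum = linear_recurrence(adaptive_momentum, surprise)
updates  = linear_recurrence(1. - decay_factor, momentum)
```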