titans-pytorch 0.0.18__tar.gz → 0.0.20__tar.gz

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.0.18
+Version: 0.0.20
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -1,6 +1,6 @@
 [project]
 name = "titans-pytorch"
-version = "0.0.18"
+version = "0.0.20"
 description = "Titans"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -1,7 +1,11 @@
 import torch
 import pytest
 
-def test_titans():
+@pytest.mark.parametrize('seq_len', (32, 1024, 77))
+def test_titans(
+    seq_len
+):
+
     from titans_pytorch import NeuralMemory
 
     mem = NeuralMemory(
@@ -9,7 +13,7 @@ def test_titans():
         chunk_size = 64,
     )
 
-    seq = torch.randn(2, 1024, 384)
+    seq = torch.randn(2, seq_len, 384)
     retrieved = mem(seq)
 
     assert seq.shape == retrieved.shape
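
The parametrization above is the substance of the test change: seq_len = 1024 preserves the original case, 32 falls below the chunk_size of 64 configured in the test, and 77 is not a multiple of it. A minimal sketch of the short-sequence case, assuming the constructor arguments shown in the test (dim = 384, chunk_size = 64):

import torch
from titans_pytorch import NeuralMemory

mem = NeuralMemory(
    dim = 384,        # assumed from the 384-dim sequences in the test
    chunk_size = 64
)

# a sequence shorter than chunk_size: with 0.0.20 the module returns the
# learned empty-memory embedding (see the NeuralMemory changes below) rather than zeros
short_seq = torch.randn(2, 32, 384)
retrieved = mem(short_seq)

assert retrieved.shape == short_seq.shape
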
@@ -17,7 +17,7 @@ from titans_pytorch.associative_scan import (
 )
 
 import einx
-from einops import rearrange, pack, unpack
+from einops import rearrange, repeat, pack, unpack
 from einops.layers.torch import Rearrange, Reduce
 
 """
@@ -55,6 +55,21 @@ def pack_one_with_inverse(t, pattern):
 
     return packed, inverse
 
+# softclamping gradients
+
+def softclamp_max(t, max_value):
+    half_max_value = max_value / 2
+    return ((t / half_max_value).tanh() * half_max_value) + half_max_value
+
+def softclamp_grad_norm(t, max_value):
+    t, inverse = pack_one_with_inverse(t, 'bn *')
+
+    norm = t.norm(dim = -1, keepdim = True)
+    clamped_norm = softclamp_max(norm, max_value)
+
+    t = t * (clamped_norm / norm)
+    return inverse(t)
+
 # classes
 
 class MemoryMLP(Module):
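
For orientation, a self-contained sketch of the same soft-clamping idea outside the package (function and variable names here are illustrative, not part of titans-pytorch): the norm of a flattened tensor is squashed through tanh so it saturates below max_value, and the tensor is rescaled to that clamped norm.

import torch

def softclamp_norm_sketch(t, max_value):
    # treat everything past the first dimension as one vector per row
    flat = t.reshape(t.shape[0], -1)
    norm = flat.norm(dim = -1, keepdim = True)

    # same formula as softclamp_max above: tanh saturates at 1,
    # so the clamped norm never exceeds max_value
    half = max_value / 2
    clamped_norm = (norm / half).tanh() * half + half

    flat = flat * (clamped_norm / norm)
    return flat.reshape_as(t)

grads = torch.randn(4, 64, 64) * 100.
clamped = softclamp_norm_sketch(grads, max_value = 10.)
print(clamped.reshape(4, -1).norm(dim = -1))  # every row norm sits just under 10
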
@@ -96,6 +111,7 @@ class NeuralMemory(Module):
         store_memory_loss_fn: Callable = default_loss_fn,
         pre_rmsnorm = True,
         post_rmsnorm = True,
+        max_grad_norm: float | None = None,
         use_accelerated_scan = False,
         default_mlp_kwargs: dict = dict(
             depth = 4
@@ -152,6 +168,11 @@ class NeuralMemory(Module):
         self.to_keys_values = LinearNoBias(dim, dim_inner * 2)
         self.store_memory_loss_fn = store_memory_loss_fn
 
+        # empty memory embed
+
+        self.empty_memory_embed = nn.Parameter(torch.zeros(dim))
+        nn.init.normal_(self.empty_memory_embed, std = 0.02)
+
         # learned adaptive learning rate and momentum
         # todo - explore mlp layerwise learned lr / momentum
 
@@ -167,6 +188,10 @@ class NeuralMemory(Module):
             Rearrange('b n h -> (b h) n')
         )
 
+        # allow for softclamp the gradient norms for storing memories
+
+        self.max_grad_norm = max_grad_norm
+
         # weight decay factor
 
         self.to_decay_factor = nn.Sequential(
@@ -187,6 +212,9 @@ class NeuralMemory(Module):
 
         return init_weights, init_momentum
 
+    def init_empty_memory_embed(self, batch, seq_len):
+        return repeat(self.empty_memory_embed, 'd -> b n d', b = batch, n = seq_len)
+
     def store_memories(
         self,
         seq,
@@ -239,6 +267,11 @@ class NeuralMemory(Module):
 
         grads = TensorDict(grads)
 
+        # maybe softclamp grad norm
+
+        if exists(self.max_grad_norm):
+            grads = grads.apply(lambda t: softclamp_grad_norm(t, self.max_grad_norm))
+
         # restore batch and sequence dimension
 
         grads = grads.apply(lambda t: rearrange(t, '(b n) ... -> b n ...', b = batch))
@@ -372,11 +405,12 @@ class NeuralMemory(Module):
 
         values = self.post_rmsnorm(values)
 
-        # restore
+        # restore, pad with empty memory embed
 
-        values = pad_at_dim(values, (chunk_size - 1, 0), dim = 1, value = 0.) # todo, used a learned null memory embedding instead of 0s for retrieving from empty neural memory
-        values = values[:, :-padding]
+        empty_memory_embeds = self.init_empty_memory_embed(values.shape[0], chunk_size - 1)
+        values = torch.cat((empty_memory_embeds, values), dim = -2)
 
+        values = values[:, :-padding]
         return values
 
     def forward(
@@ -389,7 +423,7 @@ class NeuralMemory(Module):
         batch, seq_len = seq.shape[:2]
 
         if seq_len < self.chunk_size:
-            return torch.zeros_like(seq)
+            return self.init_empty_memory_embed(batch, seq_len)
 
         if exists(past_state):
             past_state = tuple(TensorDict(d) for d in past_state)
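
Taken together, a usage sketch of what 0.0.20 adds; the dim and chunk_size values mirror the test above and the max_grad_norm value is an arbitrary choice for illustration:

import torch
from titans_pytorch import NeuralMemory

mem = NeuralMemory(
    dim = 384,
    chunk_size = 64,
    max_grad_norm = 10.   # new: softclamp the gradient norms used for storing memories
)

seq = torch.randn(2, 1024, 384)
retrieved = mem(seq)               # same shape as seq, as the test asserts

short = torch.randn(2, 16, 384)    # shorter than chunk_size
empty = mem(short)                 # now the learned empty-memory embedding instead of zeros
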