titans-pytorch 0.0.18__tar.gz → 0.0.20__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.0.18
+Version: 0.0.20
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -1,6 +1,6 @@
 [project]
 name = "titans-pytorch"
-version = "0.0.18"
+version = "0.0.20"
 description = "Titans"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -1,7 +1,11 @@
 import torch
 import pytest
 
-def test_titans():
+@pytest.mark.parametrize('seq_len', (32, 1024, 77))
+def test_titans(
+    seq_len
+):
+
     from titans_pytorch import NeuralMemory
 
     mem = NeuralMemory(
@@ -9,7 +13,7 @@ def test_titans():
         chunk_size = 64,
     )
 
-    seq = torch.randn(2, 1024, 384)
+    seq = torch.randn(2, seq_len, 384)
     retrieved = mem(seq)
 
     assert seq.shape == retrieved.shape
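The three parametrized lengths appear to cover the interesting cases relative to `chunk_size = 64`: shorter than one chunk (32), an exact multiple of the chunk size (1024), and a length that is not a multiple (77). A minimal sketch of what the updated test exercises, assuming `dim = 384` (implied by the test's `torch.randn(2, seq_len, 384)` but not visible in this hunk):

```python
import torch
from titans_pytorch import NeuralMemory

# dim = 384 is an assumption inferred from the test tensor's last dimension
mem = NeuralMemory(
    dim = 384,
    chunk_size = 64
)

for seq_len in (32, 1024, 77):
    seq = torch.randn(2, seq_len, 384)
    retrieved = mem(seq)

    # retrieval is expected to preserve the input shape for all three lengths
    assert retrieved.shape == seq.shape
```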
@@ -17,7 +17,7 @@ from titans_pytorch.associative_scan import (
 )
 
 import einx
-from einops import rearrange, pack, unpack
+from einops import rearrange, repeat, pack, unpack
 from einops.layers.torch import Rearrange, Reduce
 
 """
@@ -55,6 +55,21 @@ def pack_one_with_inverse(t, pattern):
 
     return packed, inverse
 
+# softclamping gradients
+
+def softclamp_max(t, max_value):
+    half_max_value = max_value / 2
+    return ((t / half_max_value).tanh() * half_max_value) + half_max_value
+
+def softclamp_grad_norm(t, max_value):
+    t, inverse = pack_one_with_inverse(t, 'bn *')
+
+    norm = t.norm(dim = -1, keepdim = True)
+    clamped_norm = softclamp_max(norm, max_value)
+
+    t = t * (clamped_norm / norm)
+    return inverse(t)
+
 # classes
 
 class MemoryMLP(Module):
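`softclamp_max` squashes a positive norm into the open interval `(0, max_value)` with `tanh`, so `softclamp_grad_norm` rescales a gradient tensor such that its norm can approach but never exceed `max_value` (the diff's version first flattens each parameter's gradient with `pack_one_with_inverse` before taking the norm). A standalone sketch of the same rescaling on a plain 2-D tensor, with illustrative values:

```python
import torch

def softclamp_max(t, max_value):
    # smooth squash into (0, max_value), as in the diff
    half_max_value = max_value / 2
    return ((t / half_max_value).tanh() * half_max_value) + half_max_value

# hypothetical per-row gradients with deliberately large norms
grad = torch.randn(6, 512) * 100.

norm = grad.norm(dim = -1, keepdim = True)
clamped_norm = softclamp_max(norm, max_value = 10.)

# rescale each row so its norm is soft-clamped, mirroring softclamp_grad_norm
soft_clipped = grad * (clamped_norm / norm)

print(norm.amax())                         # raw norms, typically far above 10
print(soft_clipped.norm(dim = -1).amax())  # soft-clamped norms, strictly below 10
```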
@@ -96,6 +111,7 @@ class NeuralMemory(Module):
         store_memory_loss_fn: Callable = default_loss_fn,
         pre_rmsnorm = True,
         post_rmsnorm = True,
+        max_grad_norm: float | None = None,
         use_accelerated_scan = False,
         default_mlp_kwargs: dict = dict(
             depth = 4
@@ -152,6 +168,11 @@ class NeuralMemory(Module):
         self.to_keys_values = LinearNoBias(dim, dim_inner * 2)
         self.store_memory_loss_fn = store_memory_loss_fn
 
+        # empty memory embed
+
+        self.empty_memory_embed = nn.Parameter(torch.zeros(dim))
+        nn.init.normal_(self.empty_memory_embed, std = 0.02)
+
         # learned adaptive learning rate and momentum
         # todo - explore mlp layerwise learned lr / momentum
 
@@ -167,6 +188,10 @@ class NeuralMemory(Module):
             Rearrange('b n h -> (b h) n')
         )
 
+        # allow for softclamp the gradient norms for storing memories
+
+        self.max_grad_norm = max_grad_norm
+
         # weight decay factor
 
         self.to_decay_factor = nn.Sequential(
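Taken together with the new constructor argument above, opting into the gradient-norm soft-clamp looks roughly like the following; `dim = 384` is again an assumed value and `10.` is an arbitrary illustrative threshold:

```python
from titans_pytorch import NeuralMemory

mem = NeuralMemory(
    dim = 384,            # assumed model dimension
    chunk_size = 64,
    max_grad_norm = 10.   # introduced in this diff; defaults to None, which disables the clamp
)
```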
@@ -187,6 +212,9 @@ class NeuralMemory(Module):
 
         return init_weights, init_momentum
 
+    def init_empty_memory_embed(self, batch, seq_len):
+        return repeat(self.empty_memory_embed, 'd -> b n d', b = batch, n = seq_len)
+
     def store_memories(
         self,
         seq,
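The new `init_empty_memory_embed` helper simply broadcasts the learned `(dim,)` parameter added earlier to a `(batch, seq_len, dim)` block with `einops.repeat` (hence the new `repeat` import). A self-contained sketch of that broadcast outside the module, with `dim = 384` assumed:

```python
import torch
from torch import nn
from einops import repeat

dim = 384  # assumed

# learned null-memory embedding, initialized as in the diff
empty_memory_embed = nn.Parameter(torch.zeros(dim))
nn.init.normal_(empty_memory_embed, std = 0.02)

# broadcast to one "empty memory" vector per position
batch, seq_len = 2, 32
empty = repeat(empty_memory_embed, 'd -> b n d', b = batch, n = seq_len)

assert empty.shape == (batch, seq_len, dim)
```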
@@ -239,6 +267,11 @@ class NeuralMemory(Module):
 
         grads = TensorDict(grads)
 
+        # maybe softclamp grad norm
+
+        if exists(self.max_grad_norm):
+            grads = grads.apply(lambda t: softclamp_grad_norm(t, self.max_grad_norm))
+
         # restore batch and sequence dimension
 
         grads = grads.apply(lambda t: rearrange(t, '(b n) ... -> b n ...', b = batch))
@@ -372,11 +405,12 @@ class NeuralMemory(Module):
 
         values = self.post_rmsnorm(values)
 
-        # restore
+        # restore, pad with empty memory embed
 
-        values = pad_at_dim(values, (chunk_size - 1, 0), dim = 1, value = 0.) # todo, used a learned null memory embedding instead of 0s for retrieving from empty neural memory
-        values = values[:, :-padding]
+        empty_memory_embeds = self.init_empty_memory_embed(values.shape[0], chunk_size - 1)
+        values = torch.cat((empty_memory_embeds, values), dim = -2)
 
+        values = values[:, :-padding]
         return values
 
     def forward(
@@ -389,7 +423,7 @@ class NeuralMemory(Module):
         batch, seq_len = seq.shape[:2]
 
         if seq_len < self.chunk_size:
-            return torch.zeros_like(seq)
+            return self.init_empty_memory_embed(batch, seq_len)
 
         if exists(past_state):
             past_state = tuple(TensorDict(d) for d in past_state)
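With the two retrieval changes above, a sequence shorter than `chunk_size` (and the padded positions at the front of every retrieval) now yields the learned empty-memory embedding rather than zeros. A hedged check of the short-sequence path, again assuming `dim = 384`:

```python
import torch
from titans_pytorch import NeuralMemory

mem = NeuralMemory(dim = 384, chunk_size = 64)  # dim assumed

seq = torch.randn(2, 32, 384)  # shorter than chunk_size, so memory is still "empty"
retrieved = mem(seq)

# every position should now be the learned empty-memory embedding, not zeros
assert retrieved.shape == seq.shape
assert torch.allclose(retrieved, mem.empty_memory_embed.expand_as(retrieved))
```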