titans-pytorch 0.0.18__py3-none-any.whl → 0.0.20__py3-none-any.whl

titans_pytorch/titans.py CHANGED
@@ -17,7 +17,7 @@ from titans_pytorch.associative_scan import (
 )
 
 import einx
-from einops import rearrange, pack, unpack
+from einops import rearrange, repeat, pack, unpack
 from einops.layers.torch import Rearrange, Reduce
 
 """
@@ -55,6 +55,21 @@ def pack_one_with_inverse(t, pattern):
 
     return packed, inverse
 
+# softclamping gradients
+
+def softclamp_max(t, max_value):
+    half_max_value = max_value / 2
+    return ((t / half_max_value).tanh() * half_max_value) + half_max_value
+
+def softclamp_grad_norm(t, max_value):
+    t, inverse = pack_one_with_inverse(t, 'bn *')
+
+    norm = t.norm(dim = -1, keepdim = True)
+    clamped_norm = softclamp_max(norm, max_value)
+
+    t = t * (clamped_norm / norm)
+    return inverse(t)
+
 # classes
 
 class MemoryMLP(Module):
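The two helpers above softly bound the L2 norm of each gradient tensor instead of hard-clipping it. Below is a minimal, hypothetical sketch (plain torch only; the pack_one_with_inverse flattening is replaced by a simple reshape, and the helper name softclamp_grad_norm_single is invented for illustration) showing that the rescaled norm stays just below max_value:

    import torch

    def softclamp_max(t, max_value):
        # smooth clamp: tanh keeps the output strictly below max_value
        half_max_value = max_value / 2
        return ((t / half_max_value).tanh() * half_max_value) + half_max_value

    def softclamp_grad_norm_single(t, max_value):
        # simplified single-tensor variant: rescale so the flattened L2 norm is softly bounded
        flat = t.reshape(t.shape[0], -1)
        norm = flat.norm(dim = -1, keepdim = True)
        flat = flat * (softclamp_max(norm, max_value) / norm)
        return flat.reshape(t.shape)

    grads = torch.randn(4, 64, 64) * 100.              # exaggerated "gradients" with huge norms
    clamped = softclamp_grad_norm_single(grads, 10.)
    print(clamped.reshape(4, -1).norm(dim = -1))       # every norm is now just under 10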
@@ -96,6 +111,7 @@ class NeuralMemory(Module):
         store_memory_loss_fn: Callable = default_loss_fn,
         pre_rmsnorm = True,
         post_rmsnorm = True,
+        max_grad_norm: float | None = None,
         use_accelerated_scan = False,
         default_mlp_kwargs: dict = dict(
             depth = 4
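The new max_grad_norm argument turns the soft clamping on when storing memories. A hedged usage sketch (the dim and chunk_size values are made up; the other constructor arguments are assumed from the surrounding signature):

    import torch
    from titans_pytorch import NeuralMemory

    mem = NeuralMemory(
        dim = 384,
        chunk_size = 64,
        max_grad_norm = 10.   # None (the default) keeps the previous unclamped behaviour
    )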
@@ -152,6 +168,11 @@ class NeuralMemory(Module):
         self.to_keys_values = LinearNoBias(dim, dim_inner * 2)
         self.store_memory_loss_fn = store_memory_loss_fn
 
+        # empty memory embed
+
+        self.empty_memory_embed = nn.Parameter(torch.zeros(dim))
+        nn.init.normal_(self.empty_memory_embed, std = 0.02)
+
         # learned adaptive learning rate and momentum
         # todo - explore mlp layerwise learned lr / momentum
 
@@ -167,6 +188,10 @@ class NeuralMemory(Module):
             Rearrange('b n h -> (b h) n')
         )
 
+        # allow for softclamp the gradient norms for storing memories
+
+        self.max_grad_norm = max_grad_norm
+
         # weight decay factor
 
         self.to_decay_factor = nn.Sequential(
@@ -187,6 +212,9 @@ class NeuralMemory(Module):
 
         return init_weights, init_momentum
 
+    def init_empty_memory_embed(self, batch, seq_len):
+        return repeat(self.empty_memory_embed, 'd -> b n d', b = batch, n = seq_len)
+
     def store_memories(
         self,
         seq,
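init_empty_memory_embed simply broadcasts the learned (dim,)-shaped parameter across batch and sequence. A tiny standalone illustration of the einops repeat pattern it uses (shapes are arbitrary):

    import torch
    from einops import repeat

    empty_memory_embed = torch.zeros(8)                  # stands in for the learned parameter
    out = repeat(empty_memory_embed, 'd -> b n d', b = 2, n = 5)
    print(out.shape)                                     # torch.Size([2, 5, 8])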
@@ -239,6 +267,11 @@ class NeuralMemory(Module):
 
         grads = TensorDict(grads)
 
+        # maybe softclamp grad norm
+
+        if exists(self.max_grad_norm):
+            grads = grads.apply(lambda t: softclamp_grad_norm(t, self.max_grad_norm))
+
         # restore batch and sequence dimension
 
         grads = grads.apply(lambda t: rearrange(t, '(b n) ... -> b n ...', b = batch))
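Because grads is held in a TensorDict at this point, the clamp is applied uniformly to every parameter's gradients via .apply. A hedged sketch (tensor names and shapes are invented; it assumes the installed tensordict accepts a plain dict without an explicit batch_size, as the code above does):

    import torch
    from tensordict import TensorDict
    from titans_pytorch.titans import softclamp_grad_norm   # the helper added in this release

    grads = TensorDict(dict(
        weight = torch.randn(16, 4, 4) * 50.,
        bias   = torch.randn(16, 4) * 50.,
    ))

    # soft-clamp the per-sample gradient norm of every tensor in the dict
    grads = grads.apply(lambda t: softclamp_grad_norm(t, 10.))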
@@ -372,11 +405,12 @@ class NeuralMemory(Module):
 
         values = self.post_rmsnorm(values)
 
-        # restore
+        # restore, pad with empty memory embed
 
-        values = pad_at_dim(values, (chunk_size - 1, 0), dim = 1, value = 0.) # todo, used a learned null memory embedding instead of 0s for retrieving from empty neural memory
-        values = values[:, :-padding]
+        empty_memory_embeds = self.init_empty_memory_embed(values.shape[0], chunk_size - 1)
+        values = torch.cat((empty_memory_embeds, values), dim = -2)
 
+        values = values[:, :-padding]
         return values
 
     def forward(
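During retrieval the first chunk_size - 1 positions have no stored memory to read from, so they now receive the learned empty-memory embedding instead of zero padding. A shape-only sketch of the concatenation (numbers are arbitrary):

    import torch

    b, n, d, chunk_size = 2, 6, 8, 4
    values = torch.randn(b, n, d)
    empty_memory_embeds = torch.zeros(b, chunk_size - 1, d)  # stands in for init_empty_memory_embed(...)
    values = torch.cat((empty_memory_embeds, values), dim = -2)
    print(values.shape)                                      # torch.Size([2, 9, 8])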
@@ -389,7 +423,7 @@ class NeuralMemory(Module):
         batch, seq_len = seq.shape[:2]
 
         if seq_len < self.chunk_size:
-            return torch.zeros_like(seq)
+            return self.init_empty_memory_embed(batch, seq_len)
 
         if exists(past_state):
             past_state = tuple(TensorDict(d) for d in past_state)
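For sequences shorter than chunk_size, forward now returns that same learned embedding broadcast to the input's batch and length, rather than zeros. A hedged call sketch (constructor values as in the earlier example; it assumes the default forward arguments return just the retrieved tensor):

    import torch
    from titans_pytorch import NeuralMemory

    mem = NeuralMemory(dim = 384, chunk_size = 64)
    seq = torch.randn(2, 3, 384)           # seq_len 3 < chunk_size 64
    retrieved = mem(seq)                   # broadcast empty-memory embed, not zeros
    print(retrieved.shape)                 # torch.Size([2, 3, 384])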
titans_pytorch-0.0.18.dist-info/METADATA → titans_pytorch-0.0.20.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.0.18
+Version: 0.0.20
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
titans_pytorch-0.0.20.dist-info/RECORD ADDED
@@ -0,0 +1,8 @@
+titans_pytorch/__init__.py,sha256=nB0873FZ_OyCda3qFeWTdpO4LbbrXDEvtAefVLzh6g0,71
+titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+titans_pytorch/titans.py,sha256=HpVjFy6jyzLGB_ilqjcYWGE-VtYmUrUwkXzmzqPrCXc,13370
+titans_pytorch/titans_attn_memory.py,sha256=Rwx_-riGeISBefZg5Kjic8jzmmRRys-u93D2Kgb7Mos,12691
+titans_pytorch-0.0.20.dist-info/METADATA,sha256=9qJWG-hwJ8IK9auhQV2XyEs54T0-LMvBAArF-iQ21IE,3811
+titans_pytorch-0.0.20.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.0.20.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.0.20.dist-info/RECORD,,
titans_pytorch-0.0.18.dist-info/RECORD DELETED
@@ -1,8 +0,0 @@
-titans_pytorch/__init__.py,sha256=nB0873FZ_OyCda3qFeWTdpO4LbbrXDEvtAefVLzh6g0,71
-titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/titans.py,sha256=-Xv3ufD2vhprNFliuu1lGx27nx7AvHi6yFG2g9eHaqY,12295
-titans_pytorch/titans_attn_memory.py,sha256=Rwx_-riGeISBefZg5Kjic8jzmmRRys-u93D2Kgb7Mos,12691
-titans_pytorch-0.0.18.dist-info/METADATA,sha256=YX0EPMqVioQjAVxoI3CAKV8nWgwZZ0tw4djgud4bEqs,3811
-titans_pytorch-0.0.18.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.0.18.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.0.18.dist-info/RECORD,,