titans-pytorch 0.0.19__py3-none-any.whl → 0.0.20__py3-none-any.whl

titans_pytorch/titans.py CHANGED
@@ -55,6 +55,21 @@ def pack_one_with_inverse(t, pattern):
 
     return packed, inverse
 
+# softclamping gradients
+
+def softclamp_max(t, max_value):
+    half_max_value = max_value / 2
+    return ((t / half_max_value).tanh() * half_max_value) + half_max_value
+
+def softclamp_grad_norm(t, max_value):
+    t, inverse = pack_one_with_inverse(t, 'bn *')
+
+    norm = t.norm(dim = -1, keepdim = True)
+    clamped_norm = softclamp_max(norm, max_value)
+
+    t = t * (clamped_norm / norm)
+    return inverse(t)
+
 # classes
 
 class MemoryMLP(Module):
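
For intuition, a standalone sketch (not part of the package) of what softclamp_max does: since tanh saturates at ±1, any real input is squashed smoothly into the open interval (0, max_value), crossing max_value / 2 at t = 0. softclamp_grad_norm then rescales a tensor so its flattened norm takes this softly clamped value rather than being hard-clipped.

import torch

def softclamp_max(t, max_value):
    half_max_value = max_value / 2
    return ((t / half_max_value).tanh() * half_max_value) + half_max_value

t = torch.tensor([-100., 0., 5., 100.])
print(softclamp_max(t, 10.))   # ≈ tensor([ 0.0000,  5.0000,  8.8080, 10.0000])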
@@ -96,6 +111,7 @@ class NeuralMemory(Module):
         store_memory_loss_fn: Callable = default_loss_fn,
         pre_rmsnorm = True,
         post_rmsnorm = True,
+        max_grad_norm: float | None = None,
         use_accelerated_scan = False,
         default_mlp_kwargs: dict = dict(
             depth = 4
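
A usage sketch for the new argument; the constructor call follows the project README, the dim and chunk_size values are illustrative, and only max_grad_norm is new in this release.

import torch
from titans_pytorch import NeuralMemory

mem = NeuralMemory(
    dim = 384,
    chunk_size = 64,
    max_grad_norm = 1.    # softly bound per-tensor surprise-gradient norms below 1.0
)

seq = torch.randn(2, 1024, 384)
retrieved = mem(seq)      # same shape as seq

Leaving max_grad_norm at its default of None keeps the previous behavior, with no clamping applied.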
@@ -172,6 +188,10 @@ class NeuralMemory(Module):
             Rearrange('b n h -> (b h) n')
         )
 
+        # allow for softclamping the gradient norms for storing memories
+
+        self.max_grad_norm = max_grad_norm
+
         # weight decay factor
 
         self.to_decay_factor = nn.Sequential(
@@ -247,6 +267,11 @@ class NeuralMemory(Module):
 
         grads = TensorDict(grads)
 
+        # maybe softclamp grad norm
+
+        if exists(self.max_grad_norm):
+            grads = grads.apply(lambda t: softclamp_grad_norm(t, self.max_grad_norm))
+
         # restore batch and sequence dimension
 
         grads = grads.apply(lambda t: rearrange(t, '(b n) ... -> b n ...', b = batch))
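
To see the per-tensor effect of that grads.apply, a self-contained sketch; pack_one_with_inverse is simplified from the package's helper, and the fake gradient dict stands in for the TensorDict of per-chunk gradients.

import torch
from einops import pack, unpack

# simplified stand-in for the helper defined in titans.py
def pack_one_with_inverse(t, pattern):
    packed, packed_shape = pack([t], pattern)
    inverse = lambda out: unpack(out, packed_shape, pattern)[0]
    return packed, inverse

def softclamp_max(t, max_value):
    half_max_value = max_value / 2
    return ((t / half_max_value).tanh() * half_max_value) + half_max_value

def softclamp_grad_norm(t, max_value):
    t, inverse = pack_one_with_inverse(t, 'bn *')      # flatten all but the first dim
    norm = t.norm(dim = -1, keepdim = True)
    t = t * (softclamp_max(norm, max_value) / norm)    # rescale to the clamped norm
    return inverse(t)

# oversized fake gradients for one parameter, shape (batch * chunks, dim, dim)
grads = {'weights.0': torch.randn(4, 64, 64) * 100}
clamped = {k: softclamp_grad_norm(v, 1.) for k, v in grads.items()}
print(clamped['weights.0'].flatten(1).norm(dim = -1))  # norms softly bounded below 1.0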

titans_pytorch-0.0.19.dist-info/METADATA → titans_pytorch-0.0.20.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.0.19
+Version: 0.0.20
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch

titans_pytorch-0.0.20.dist-info/RECORD ADDED
@@ -0,0 +1,8 @@
+titans_pytorch/__init__.py,sha256=nB0873FZ_OyCda3qFeWTdpO4LbbrXDEvtAefVLzh6g0,71
+titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+titans_pytorch/titans.py,sha256=HpVjFy6jyzLGB_ilqjcYWGE-VtYmUrUwkXzmzqPrCXc,13370
+titans_pytorch/titans_attn_memory.py,sha256=Rwx_-riGeISBefZg5Kjic8jzmmRRys-u93D2Kgb7Mos,12691
+titans_pytorch-0.0.20.dist-info/METADATA,sha256=9qJWG-hwJ8IK9auhQV2XyEs54T0-LMvBAArF-iQ21IE,3811
+titans_pytorch-0.0.20.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.0.20.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.0.20.dist-info/RECORD,,

titans_pytorch-0.0.19.dist-info/RECORD REMOVED
@@ -1,8 +0,0 @@
-titans_pytorch/__init__.py,sha256=nB0873FZ_OyCda3qFeWTdpO4LbbrXDEvtAefVLzh6g0,71
-titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/titans.py,sha256=CxbJgNdIS9NbbCDdgotFXAnrV16xmvufUErerKe7qJA,12636
-titans_pytorch/titans_attn_memory.py,sha256=Rwx_-riGeISBefZg5Kjic8jzmmRRys-u93D2Kgb7Mos,12691
-titans_pytorch-0.0.19.dist-info/METADATA,sha256=5Wpk79HYI4z8LeNRV__UaamKppiGcJ2HdIlll1JSZr8,3811
-titans_pytorch-0.0.19.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.0.19.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.0.19.dist-info/RECORD,,