titans-pytorch 0.0.19__py3-none-any.whl → 0.0.20__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
titans_pytorch/titans.py CHANGED
@@ -55,6 +55,21 @@ def pack_one_with_inverse(t, pattern):
55
55
 
56
56
  return packed, inverse
57
57
 
58
+ # softclamping gradients
59
+
60
def softclamp_max(t, max_value):
    """Smoothly squash ``t`` into the open range ``(0, max_value)`` (for positive ``max_value``).

    Evaluates ``half * tanh(t / half) + half`` with ``half = max_value / 2``.
    ``t = 0`` maps to ``max_value / 2``. Large positive ``t`` approaches
    ``max_value``, and large negative ``t`` approaches ``0``.
    """
    half = max_value / 2
    squashed = (t / half).tanh()  # lies in (-1, 1)
    return half + squashed * half
63
+
64
def softclamp_grad_norm(t, max_value, eps = 1e-12):
    """Rescale each row of ``t`` so its L2 norm is softly clamped via ``softclamp_max``.

    ``t`` is packed to ``(bn, features)`` with ``pack_one_with_inverse(t, 'bn *')``.
    Each row's L2 norm is mapped through ``softclamp_max(norm, max_value)``, the
    row is rescaled to that new norm, and the original shape is restored.

    Args:
        t: tensor whose leading dimension indexes the rows to rescale.
        max_value: value the softly clamped norm approaches from below.
        eps: lower bound applied to the norm before dividing. Without it, an
            all-zero row would produce ``0 * inf = NaN``; with it, such rows
            stay zero. The default matches ``F.normalize``.

    Returns:
        Tensor with the same shape as ``t``.
    """
    t, inverse = pack_one_with_inverse(t, 'bn *')

    norm = t.norm(dim = -1, keepdim = True)
    clamped_norm = softclamp_max(norm, max_value)

    # NOTE(review): softclamp_max maps a norm of 0 to max_value / 2, so rows whose
    # norm is below max_value / 2 are scaled *up*, not only large rows down.
    # Confirm this is intended for a "max grad norm" option.

    # guard the division: a zero-norm row would otherwise give clamped_norm / 0 = inf, and 0 * inf = NaN
    t = t * (clamped_norm / norm.clamp(min = eps))
    return inverse(t)
72
+
58
73
  # classes
59
74
 
60
75
  class MemoryMLP(Module):
@@ -96,6 +111,7 @@ class NeuralMemory(Module):
96
111
  store_memory_loss_fn: Callable = default_loss_fn,
97
112
  pre_rmsnorm = True,
98
113
  post_rmsnorm = True,
114
+ max_grad_norm: float | None = None,
99
115
  use_accelerated_scan = False,
100
116
  default_mlp_kwargs: dict = dict(
101
117
  depth = 4
@@ -172,6 +188,10 @@ class NeuralMemory(Module):
172
188
  Rearrange('b n h -> (b h) n')
173
189
  )
174
190
 
191
+ # allow soft-clamping the gradient norms when storing memories
192
+
193
+ self.max_grad_norm = max_grad_norm
194
+
175
195
  # weight decay factor
176
196
 
177
197
  self.to_decay_factor = nn.Sequential(
@@ -247,6 +267,11 @@ class NeuralMemory(Module):
247
267
 
248
268
  grads = TensorDict(grads)
249
269
 
270
+ # maybe softclamp grad norm
271
+
272
+ if exists(self.max_grad_norm):
273
+ grads = grads.apply(lambda t: softclamp_grad_norm(t, self.max_grad_norm))
274
+
250
275
  # restore batch and sequence dimension
251
276
 
252
277
  grads = grads.apply(lambda t: rearrange(t, '(b n) ... -> b n ...', b = batch))
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: titans-pytorch
3
- Version: 0.0.19
3
+ Version: 0.0.20
4
4
  Summary: Titans
5
5
  Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
6
6
  Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -0,0 +1,8 @@
1
+ titans_pytorch/__init__.py,sha256=nB0873FZ_OyCda3qFeWTdpO4LbbrXDEvtAefVLzh6g0,71
2
+ titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
3
+ titans_pytorch/titans.py,sha256=HpVjFy6jyzLGB_ilqjcYWGE-VtYmUrUwkXzmzqPrCXc,13370
4
+ titans_pytorch/titans_attn_memory.py,sha256=Rwx_-riGeISBefZg5Kjic8jzmmRRys-u93D2Kgb7Mos,12691
5
+ titans_pytorch-0.0.20.dist-info/METADATA,sha256=9qJWG-hwJ8IK9auhQV2XyEs54T0-LMvBAArF-iQ21IE,3811
6
+ titans_pytorch-0.0.20.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
7
+ titans_pytorch-0.0.20.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
8
+ titans_pytorch-0.0.20.dist-info/RECORD,,
@@ -1,8 +0,0 @@
1
- titans_pytorch/__init__.py,sha256=nB0873FZ_OyCda3qFeWTdpO4LbbrXDEvtAefVLzh6g0,71
2
- titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
3
- titans_pytorch/titans.py,sha256=CxbJgNdIS9NbbCDdgotFXAnrV16xmvufUErerKe7qJA,12636
4
- titans_pytorch/titans_attn_memory.py,sha256=Rwx_-riGeISBefZg5Kjic8jzmmRRys-u93D2Kgb7Mos,12691
5
- titans_pytorch-0.0.19.dist-info/METADATA,sha256=5Wpk79HYI4z8LeNRV__UaamKppiGcJ2HdIlll1JSZr8,3811
6
- titans_pytorch-0.0.19.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
7
- titans_pytorch-0.0.19.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
8
- titans_pytorch-0.0.19.dist-info/RECORD,,