titans-pytorch 0.0.19__py3-none-any.whl → 0.0.21__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
titans_pytorch/titans.py CHANGED
@@ -40,6 +40,9 @@ def exists(v):
40
40
  def default(v, d):
41
41
  return v if exists(v) else d
42
42
 
43
+ def identity(t):
44
+ return t
45
+
43
46
  def round_down_multiple(seq, mult):
44
47
  return seq // mult * mult
45
48
 
@@ -55,6 +58,21 @@ def pack_one_with_inverse(t, pattern):
55
58
 
56
59
  return packed, inverse
57
60
 
61
+ # softclamping gradients
62
+
63
+ def softclamp_max(t, max_value):
64
+ half_max_value = max_value / 2
65
+ return ((t / half_max_value).tanh() * half_max_value) + half_max_value
66
+
67
+ def softclamp_grad_norm(t, max_value):
68
+ t, inverse = pack_one_with_inverse(t, 'bn *')
69
+
70
+ norm = t.norm(dim = -1, keepdim = True)
71
+ clamped_norm = softclamp_max(norm, max_value)
72
+
73
+ t = t * (clamped_norm / norm)
74
+ return inverse(t)
75
+
58
76
  # classes
59
77
 
60
78
  class MemoryMLP(Module):
@@ -82,6 +100,9 @@ class MemoryMLP(Module):
82
100
 
83
101
  # main neural memory
84
102
 
103
+ def default_adaptive_step_transform(adaptive_step):
104
+ return torch.exp(adaptive_step.sigmoid() * -15) # from 1. - 1e-7
105
+
85
106
  def default_loss_fn(pred, target):
86
107
  return (pred - target).pow(2).mean(dim = -1).sum()
87
108
 
@@ -94,8 +115,10 @@ class NeuralMemory(Module):
94
115
  heads = 1,
95
116
  model: Module | None = None,
96
117
  store_memory_loss_fn: Callable = default_loss_fn,
118
+ adaptive_step_transform: Callable = default_adaptive_step_transform,
97
119
  pre_rmsnorm = True,
98
120
  post_rmsnorm = True,
121
+ max_grad_norm: float | None = None,
99
122
  use_accelerated_scan = False,
100
123
  default_mlp_kwargs: dict = dict(
101
124
  depth = 4
@@ -172,6 +195,12 @@ class NeuralMemory(Module):
172
195
  Rearrange('b n h -> (b h) n')
173
196
  )
174
197
 
198
+ self.adaptive_step_transform = adaptive_step_transform
199
+
200
+ # allow for softclamp the gradient norms for storing memories
201
+
202
+ self.max_grad_norm = max_grad_norm
203
+
175
204
  # weight decay factor
176
205
 
177
206
  self.to_decay_factor = nn.Sequential(
@@ -222,7 +251,8 @@ class NeuralMemory(Module):
222
251
 
223
252
  # pack batch and sequence dimension
224
253
 
225
- adaptive_lr = (self.to_adaptive_step(seq).sigmoid() * -15).exp() # from 1. - 1e-7
254
+ adaptive_lr = self.to_adaptive_step(seq)
255
+ adaptive_lr = self.adaptive_step_transform(adaptive_lr)
226
256
 
227
257
  adaptive_momentum = self.to_momentum(seq).sigmoid()
228
258
  decay_factor = self.to_decay_factor(seq).sigmoid()
@@ -247,6 +277,11 @@ class NeuralMemory(Module):
247
277
 
248
278
  grads = TensorDict(grads)
249
279
 
280
+ # maybe softclamp grad norm
281
+
282
+ if exists(self.max_grad_norm):
283
+ grads = grads.apply(lambda t: softclamp_grad_norm(t, self.max_grad_norm))
284
+
250
285
  # restore batch and sequence dimension
251
286
 
252
287
  grads = grads.apply(lambda t: rearrange(t, '(b n) ... -> b n ...', b = batch))
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: titans-pytorch
3
- Version: 0.0.19
3
+ Version: 0.0.21
4
4
  Summary: Titans
5
5
  Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
6
6
  Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -0,0 +1,8 @@
1
+ titans_pytorch/__init__.py,sha256=nB0873FZ_OyCda3qFeWTdpO4LbbrXDEvtAefVLzh6g0,71
2
+ titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
3
+ titans_pytorch/titans.py,sha256=rsgpGs-nnqvsvF6Mu2GfmfNlRgYfHAV9MCEmd1ohUUI,13687
4
+ titans_pytorch/titans_attn_memory.py,sha256=Rwx_-riGeISBefZg5Kjic8jzmmRRys-u93D2Kgb7Mos,12691
5
+ titans_pytorch-0.0.21.dist-info/METADATA,sha256=3z2LGviHUicQq4_ThlRasOUCDXIe0xX2aIhxtKLWI0Q,3811
6
+ titans_pytorch-0.0.21.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
7
+ titans_pytorch-0.0.21.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
8
+ titans_pytorch-0.0.21.dist-info/RECORD,,
@@ -1,8 +0,0 @@
1
- titans_pytorch/__init__.py,sha256=nB0873FZ_OyCda3qFeWTdpO4LbbrXDEvtAefVLzh6g0,71
2
- titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
3
- titans_pytorch/titans.py,sha256=CxbJgNdIS9NbbCDdgotFXAnrV16xmvufUErerKe7qJA,12636
4
- titans_pytorch/titans_attn_memory.py,sha256=Rwx_-riGeISBefZg5Kjic8jzmmRRys-u93D2Kgb7Mos,12691
5
- titans_pytorch-0.0.19.dist-info/METADATA,sha256=5Wpk79HYI4z8LeNRV__UaamKppiGcJ2HdIlll1JSZr8,3811
6
- titans_pytorch-0.0.19.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
7
- titans_pytorch-0.0.19.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
8
- titans_pytorch-0.0.19.dist-info/RECORD,,