titans-pytorch 0.0.19__py3-none-any.whl → 0.0.21__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
titans_pytorch/titans.py CHANGED
@@ -40,6 +40,9 @@ def exists(v):
 def default(v, d):
     return v if exists(v) else d
 
+def identity(t):
+    return t
+
 def round_down_multiple(seq, mult):
     return seq // mult * mult
 
@@ -55,6 +58,21 @@ def pack_one_with_inverse(t, pattern):
 
     return packed, inverse
 
+# softclamping gradients
+
+def softclamp_max(t, max_value):
+    half_max_value = max_value / 2
+    return ((t / half_max_value).tanh() * half_max_value) + half_max_value
+
+def softclamp_grad_norm(t, max_value):
+    t, inverse = pack_one_with_inverse(t, 'bn *')
+
+    norm = t.norm(dim = -1, keepdim = True)
+    clamped_norm = softclamp_max(norm, max_value)
+
+    t = t * (clamped_norm / norm)
+    return inverse(t)
+
 # classes
 
 class MemoryMLP(Module):
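
A minimal numeric sketch of the new softclamp_max helper (illustrative only, re-declared here so it runs standalone): since tanh saturates at ±1, outputs are bounded in (0, max_value), approaching max_value smoothly for large inputs rather than clipping hard; note that an input of 0 maps to max_value / 2.

    import torch

    # same definition as the helper added in the hunk above
    def softclamp_max(t, max_value):
        half_max_value = max_value / 2
        return ((t / half_max_value).tanh() * half_max_value) + half_max_value

    x = torch.tensor([0., 5., 10., 100.])
    print(softclamp_max(x, 10.))  # ~tensor([ 5.0000,  8.8080,  9.8201, 10.0000])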
@@ -82,6 +100,9 @@ class MemoryMLP(Module):
 
 # main neural memory
 
+def default_adaptive_step_transform(adaptive_step):
+    return torch.exp(adaptive_step.sigmoid() * -15) # from 1. - 1e-7
+
 def default_loss_fn(pred, target):
     return (pred - target).pow(2).mean(dim = -1).sum()
 
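
A rough range check on default_adaptive_step_transform (illustrative): sigmoid maps the raw projection into (0, 1), so exp(sigmoid(x) * -15) lands in roughly (exp(-15) ≈ 3e-7, 1), the span the inline comment is pointing at; note the inversion, where larger raw values yield smaller effective step sizes.

    import torch

    raw = torch.tensor([-20., 0., 20.])
    print(torch.exp(raw.sigmoid() * -15))  # ~tensor([1.0000e+00, 5.5308e-04, 3.0590e-07])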
@@ -94,8 +115,10 @@ class NeuralMemory(Module):
         heads = 1,
         model: Module | None = None,
         store_memory_loss_fn: Callable = default_loss_fn,
+        adaptive_step_transform: Callable = default_adaptive_step_transform,
         pre_rmsnorm = True,
         post_rmsnorm = True,
+        max_grad_norm: float | None = None,
         use_accelerated_scan = False,
         default_mlp_kwargs: dict = dict(
             depth = 4
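
A hedged usage sketch of the two new constructor arguments; the dim and chunk_size values and the call signature are assumptions taken from the project README of this era, not from this diff:

    import torch
    from titans_pytorch import NeuralMemory

    mem = NeuralMemory(
        dim = 384,
        chunk_size = 64,
        max_grad_norm = 1.,                                               # new: softclamp per-parameter grad norms
        adaptive_step_transform = lambda t: torch.exp(t.sigmoid() * -10)  # new: swap in a custom step transform
    )

    seq = torch.randn(2, 1024, 384)
    retrieved = mem(seq)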
@@ -172,6 +195,12 @@ class NeuralMemory(Module):
             Rearrange('b n h -> (b h) n')
         )
 
+        self.adaptive_step_transform = adaptive_step_transform
+
+        # allow for softclamp the gradient norms for storing memories
+
+        self.max_grad_norm = max_grad_norm
+
         # weight decay factor
 
         self.to_decay_factor = nn.Sequential(
@@ -222,7 +251,8 @@ class NeuralMemory(Module):
 
         # pack batch and sequence dimension
 
-        adaptive_lr = (self.to_adaptive_step(seq).sigmoid() * -15).exp() # from 1. - 1e-7
+        adaptive_lr = self.to_adaptive_step(seq)
+        adaptive_lr = self.adaptive_step_transform(adaptive_lr)
 
         adaptive_momentum = self.to_momentum(seq).sigmoid()
         decay_factor = self.to_decay_factor(seq).sigmoid()
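
With the default transform this refactor is behavior-preserving, only moving the hard-coded expression behind the new hook; an equivalence sketch (illustrative):

    import torch

    raw = torch.randn(4)                   # stand-in for self.to_adaptive_step(seq)
    old = (raw.sigmoid() * -15).exp()      # removed inline expression
    new = torch.exp(raw.sigmoid() * -15)   # default_adaptive_step_transform(raw)
    assert torch.allclose(old, new)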
@@ -247,6 +277,11 @@ class NeuralMemory(Module):
 
         grads = TensorDict(grads)
 
+        # maybe softclamp grad norm
+
+        if exists(self.max_grad_norm):
+            grads = grads.apply(lambda t: softclamp_grad_norm(t, self.max_grad_norm))
+
         # restore batch and sequence dimension
 
         grads = grads.apply(lambda t: rearrange(t, '(b n) ... -> b n ...', b = batch))
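
An illustrative check of the clamp itself, assuming softclamp_grad_norm is importable from titans_pytorch.titans as added above: each gradient tensor is flattened past its leading dimension and the per-row norm is softly capped just below max_grad_norm.

    import torch
    from titans_pytorch.titans import softclamp_grad_norm

    g = torch.randn(8, 16, 16) * 100              # deliberately oversized "gradients"
    clamped = softclamp_grad_norm(g, 1.)
    print(clamped.reshape(8, -1).norm(dim = -1))  # every row norm lands just under 1.0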
{titans_pytorch-0.0.19.dist-info → titans_pytorch-0.0.21.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.0.19
+Version: 0.0.21
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
titans_pytorch-0.0.21.dist-info/RECORD ADDED
@@ -0,0 +1,8 @@
+titans_pytorch/__init__.py,sha256=nB0873FZ_OyCda3qFeWTdpO4LbbrXDEvtAefVLzh6g0,71
+titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+titans_pytorch/titans.py,sha256=rsgpGs-nnqvsvF6Mu2GfmfNlRgYfHAV9MCEmd1ohUUI,13687
+titans_pytorch/titans_attn_memory.py,sha256=Rwx_-riGeISBefZg5Kjic8jzmmRRys-u93D2Kgb7Mos,12691
+titans_pytorch-0.0.21.dist-info/METADATA,sha256=3z2LGviHUicQq4_ThlRasOUCDXIe0xX2aIhxtKLWI0Q,3811
+titans_pytorch-0.0.21.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.0.21.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.0.21.dist-info/RECORD,,
titans_pytorch-0.0.19.dist-info/RECORD DELETED
@@ -1,8 +0,0 @@
-titans_pytorch/__init__.py,sha256=nB0873FZ_OyCda3qFeWTdpO4LbbrXDEvtAefVLzh6g0,71
-titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/titans.py,sha256=CxbJgNdIS9NbbCDdgotFXAnrV16xmvufUErerKe7qJA,12636
-titans_pytorch/titans_attn_memory.py,sha256=Rwx_-riGeISBefZg5Kjic8jzmmRRys-u93D2Kgb7Mos,12691
-titans_pytorch-0.0.19.dist-info/METADATA,sha256=5Wpk79HYI4z8LeNRV__UaamKppiGcJ2HdIlll1JSZr8,3811
-titans_pytorch-0.0.19.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.0.19.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.0.19.dist-info/RECORD,,