titans-pytorch 0.0.20-py3-none-any.whl → 0.0.21-py3-none-any.whl

This diff shows the changes between two package versions as published to their public registry, and is provided for informational purposes only.
titans_pytorch/titans.py CHANGED
@@ -40,6 +40,9 @@ def exists(v):
 def default(v, d):
     return v if exists(v) else d
 
+def identity(t):
+    return t
+
 def round_down_multiple(seq, mult):
     return seq // mult * mult
 
@@ -97,6 +100,9 @@ class MemoryMLP(Module):
 
 # main neural memory
 
+def default_adaptive_step_transform(adaptive_step):
+    return torch.exp(adaptive_step.sigmoid() * -15) # from 1. - 1e-7
+
 def default_loss_fn(pred, target):
     return (pred - target).pow(2).mean(dim = -1).sum()
 
@@ -109,6 +115,7 @@ class NeuralMemory(Module):
         heads = 1,
         model: Module | None = None,
         store_memory_loss_fn: Callable = default_loss_fn,
+        adaptive_step_transform: Callable = default_adaptive_step_transform,
         pre_rmsnorm = True,
         post_rmsnorm = True,
         max_grad_norm: float | None = None,
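This turns the previously hard-coded squashing into a pluggable hook on the constructor. A hedged usage sketch — the hyperparameters and the custom transform below are illustrative, not taken from this diff:

import torch
from titans_pytorch import NeuralMemory

# hypothetical alternative: bound the adaptive step linearly in (0, 1e-2)
# instead of the default exp(sigmoid(x) * -15) squashing
def linear_step_transform(adaptive_step):
    return adaptive_step.sigmoid() * 1e-2

mem = NeuralMemory(
    dim = 384,        # illustrative model dimension
    chunk_size = 64,  # illustrative chunk size
    adaptive_step_transform = linear_step_transform  # new keyword in 0.0.21
)

seq = torch.randn(2, 1024, 384)
retrieved = mem(seq)  # (2, 1024, 384)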
@@ -188,6 +195,8 @@ class NeuralMemory(Module):
             Rearrange('b n h -> (b h) n')
         )
 
+        self.adaptive_step_transform = adaptive_step_transform
+
         # allow for softclamp the gradient norms for storing memories
 
         self.max_grad_norm = max_grad_norm
@@ -242,7 +251,8 @@ class NeuralMemory(Module):
 
         # pack batch and sequence dimension
 
-        adaptive_lr = (self.to_adaptive_step(seq).sigmoid() * -15).exp() # from 1. - 1e-7
+        adaptive_lr = self.to_adaptive_step(seq)
+        adaptive_lr = self.adaptive_step_transform(adaptive_lr)
 
         adaptive_momentum = self.to_momentum(seq).sigmoid()
         decay_factor = self.to_decay_factor(seq).sigmoid()
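Note that this refactor is behavior-preserving on the default path: the removed inline expression (x.sigmoid() * -15).exp() and the new default_adaptive_step_transform's torch.exp(x.sigmoid() * -15) compute identical values, so 0.0.20 users who don't pass a custom transform should see no numerical change. A quick sanity check of that equivalence:

import torch

x = torch.randn(4, 128)
old = (x.sigmoid() * -15).exp()      # inlined in 0.0.20
new = torch.exp(x.sigmoid() * -15)   # default_adaptive_step_transform in 0.0.21
assert torch.allclose(old, new)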
{titans_pytorch-0.0.20.dist-info → titans_pytorch-0.0.21.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.0.20
+Version: 0.0.21
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
titans_pytorch-0.0.21.dist-info/RECORD ADDED
@@ -0,0 +1,8 @@
+titans_pytorch/__init__.py,sha256=nB0873FZ_OyCda3qFeWTdpO4LbbrXDEvtAefVLzh6g0,71
+titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+titans_pytorch/titans.py,sha256=rsgpGs-nnqvsvF6Mu2GfmfNlRgYfHAV9MCEmd1ohUUI,13687
+titans_pytorch/titans_attn_memory.py,sha256=Rwx_-riGeISBefZg5Kjic8jzmmRRys-u93D2Kgb7Mos,12691
+titans_pytorch-0.0.21.dist-info/METADATA,sha256=3z2LGviHUicQq4_ThlRasOUCDXIe0xX2aIhxtKLWI0Q,3811
+titans_pytorch-0.0.21.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.0.21.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.0.21.dist-info/RECORD,,
titans_pytorch-0.0.20.dist-info/RECORD DELETED
@@ -1,8 +0,0 @@
-titans_pytorch/__init__.py,sha256=nB0873FZ_OyCda3qFeWTdpO4LbbrXDEvtAefVLzh6g0,71
-titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/titans.py,sha256=HpVjFy6jyzLGB_ilqjcYWGE-VtYmUrUwkXzmzqPrCXc,13370
-titans_pytorch/titans_attn_memory.py,sha256=Rwx_-riGeISBefZg5Kjic8jzmmRRys-u93D2Kgb7Mos,12691
-titans_pytorch-0.0.20.dist-info/METADATA,sha256=9qJWG-hwJ8IK9auhQV2XyEs54T0-LMvBAArF-iQ21IE,3811
-titans_pytorch-0.0.20.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.0.20.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.0.20.dist-info/RECORD,,