titans-pytorch 0.0.7__py3-none-any.whl → 0.0.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
titans_pytorch/titans.py CHANGED
@@ -55,6 +55,10 @@ def pack_one_with_inverse(t, pattern):
 
     return packed, inverse
 
+def softclamp_max(t, max_value):
+    range_value = max_value / 2
+    return ((t / range_value).tanh() * range_value) + range_value
+
 # classes
 
 class MLP(Module):
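The new helper above is a smooth clamp: `tanh` keeps its output in (-1, 1), so scaling and shifting by `max_value / 2` maps any real input into the interval (0, max_value), with an input of 0 landing at `max_value / 2`. A standalone sketch (reproducing the function outside the package) that checks this:

```python
# Standalone sketch of the softclamp_max helper added in this release.
# tanh(x) lies in (-1, 1), so scaling and shifting by max_value / 2 keeps
# the output within [0, max_value]; an input of 0 maps to max_value / 2.
import torch

def softclamp_max(t, max_value):
    range_value = max_value / 2
    return ((t / range_value).tanh() * range_value) + range_value

t = torch.tensor([-3., 0., 3.])
clamped = softclamp_max(t, 1.)

assert ((clamped >= 0.) & (clamped <= 1.)).all()
assert clamped[1].item() == 0.5
```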
@@ -92,7 +96,8 @@ class NeuralMemory(Module):
         chunk_size = 1,
         model: Module | None = None,
         store_memory_loss_fn: Callable = default_loss_fn,
-        pre_rmsnorm = False
+        pre_rmsnorm = False,
+        max_adaptive_step_size = 1e-5
     ):
         super().__init__()
 
@@ -144,6 +149,8 @@ class NeuralMemory(Module):
             Rearrange('... 1 -> ...')
         )
 
+        self.max_adaptive_step_size = max_adaptive_step_size
+
         # weight decay factor
 
         self.to_decay_factor = nn.Sequential(
@@ -188,7 +195,7 @@ class NeuralMemory(Module):
 
         batch = seq.shape[0]
 
-        adaptive_lr = self.to_adaptive_step(seq).tanh() * 0.5 + 0.5
+        adaptive_lr = softclamp_max(self.to_adaptive_step(seq), self.max_adaptive_step_size)
 
         adaptive_momentum = self.to_momentum(seq).sigmoid()
         decay_factor = self.to_decay_factor(seq).sigmoid()
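The effect of the last hunk: the old expression `self.to_adaptive_step(seq).tanh() * 0.5 + 0.5` bounded the per-token adaptive learning rate to [0, 1], while the new `softclamp_max(...)` call caps it at `max_adaptive_step_size` (1e-5 by default). A rough comparison, using random logits in place of `self.to_adaptive_step(seq)` and the `softclamp_max` sketch above:

```python
# Rough comparison of the old and new adaptive learning-rate bounds.
# Random logits stand in for self.to_adaptive_step(seq); softclamp_max is
# the helper reproduced in the earlier sketch.
import torch

def softclamp_max(t, max_value):
    range_value = max_value / 2
    return ((t / range_value).tanh() * range_value) + range_value

logits = torch.randn(2, 1024)              # stand-in for self.to_adaptive_step(seq)

old_lr = logits.tanh() * 0.5 + 0.5         # 0.0.7: bounded to [0, 1]
new_lr = softclamp_max(logits, 1e-5)       # 0.0.8: capped at max_adaptive_step_size

assert old_lr.max().item() <= 1.0
assert new_lr.max().item() <= 1e-5
```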
titans_pytorch-0.0.7.dist-info/METADATA → titans_pytorch-0.0.8.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.0.7
+Version: 0.0.8
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -39,7 +39,8 @@ Requires-Dist: einx>=0.3.0
 Requires-Dist: tensordict>=0.6.2
 Requires-Dist: torch>=2.3
 Provides-Extra: examples
-Requires-Dist: local-attention>=1.9.15; extra == 'examples'
+Requires-Dist: local-attention>=1.10.0; extra == 'examples'
+Requires-Dist: taylor-series-linear-attention; extra == 'examples'
 Provides-Extra: test
 Requires-Dist: pytest; extra == 'test'
 Description-Content-Type: text/markdown
@@ -64,16 +65,16 @@ $ pip install titans-pytorch
 import torch
 from titans_pytorch import NeuralMemory
 
-x = torch.randn(2, 64, 32)
-
 mem = NeuralMemory(
-    dim = 32,
-    chunk_size = 2
-)
+    dim = 384,
+    chunk_size = 64,
+    pre_rmsnorm = True
+).cuda()
 
-out = mem(x)
+seq = torch.randn(2, 1024, 384).cuda()
+retrieved = mem(seq)
 
-assert x.shape == out.shape
+assert seq.shape == retrieved.shape
 ```
 
 ## Citations
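The updated README example does not exercise the new constructor argument introduced in this release. A hedged, CPU-only variant that also passes `max_adaptive_step_size` explicitly (1e-5 is its default per the titans.py hunk above) would presumably look like:

```python
# Hypothetical variant of the README example: same API, run on CPU, with the
# new max_adaptive_step_size argument set explicitly (1e-5 is the default
# shown in the diff above).
import torch
from titans_pytorch import NeuralMemory

mem = NeuralMemory(
    dim = 384,
    chunk_size = 64,
    pre_rmsnorm = True,
    max_adaptive_step_size = 1e-5
)

seq = torch.randn(2, 1024, 384)
retrieved = mem(seq)

assert seq.shape == retrieved.shape
```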
titans_pytorch-0.0.8.dist-info/RECORD ADDED
@@ -0,0 +1,7 @@
+titans_pytorch/__init__.py,sha256=QKuJPCOJCdgtaPeKoHEkYkiQe65_LV9_8-cIMbBPU30,55
+titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+titans_pytorch/titans.py,sha256=I8o_wauKQre87wl645Gy1nchPn5ckfZq4Q6pf5IpToU,9988
+titans_pytorch-0.0.8.dist-info/METADATA,sha256=70ENsVn58zYyJM0UKanESujf0xSQWTcWXBh5HY1frKk,3219
+titans_pytorch-0.0.8.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.0.8.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.0.8.dist-info/RECORD,,
titans_pytorch-0.0.7.dist-info/RECORD REMOVED
@@ -1,7 +0,0 @@
-titans_pytorch/__init__.py,sha256=QKuJPCOJCdgtaPeKoHEkYkiQe65_LV9_8-cIMbBPU30,55
-titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/titans.py,sha256=IuAhvpeEGP_XQfhpWGEwNcmlSUKmtmtxWjjgwdEy0oI,9730
-titans_pytorch-0.0.7.dist-info/METADATA,sha256=NAIbruJrJLaafP5SjcM3a_6qe6yCgiohpC09CzOKsMg,3092
-titans_pytorch-0.0.7.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.0.7.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.0.7.dist-info/RECORD,,