titans-pytorch 0.0.6__py3-none-any.whl → 0.0.8__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
titans_pytorch/titans.py CHANGED
@@ -55,6 +55,10 @@ def pack_one_with_inverse(t, pattern):
 
     return packed, inverse
 
+def softclamp_max(t, max_value):
+    range_value = max_value / 2
+    return ((t / range_value).tanh() * range_value) + range_value
+
 # classes
 
 class MLP(Module):
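Note: the new `softclamp_max` squashes any real input smoothly into `(0, max_value)`: `tanh` produces `(-1, 1)`, and scaling and shifting by `max_value / 2` recenters that onto `(0, max_value)`. A standalone sketch (not part of the package; input values are illustrative):

```python
import torch

def softclamp_max(t, max_value):
    # mirror of the function added in this diff
    range_value = max_value / 2
    return ((t / range_value).tanh() * range_value) + range_value

t = torch.tensor([-3., 0., 2.])
out = softclamp_max(t, 1.0)

# outputs are softly squashed into (0, max_value); at the extremes,
# float saturation of tanh can touch 0 or max_value exactly
assert (out >= 0).all() and (out <= 1.0).all()
```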
@@ -91,10 +95,15 @@ class NeuralMemory(Module):
         dim,
         chunk_size = 1,
         model: Module | None = None,
-        store_memory_loss_fn: Callable = default_loss_fn
+        store_memory_loss_fn: Callable = default_loss_fn,
+        pre_rmsnorm = False,
+        max_adaptive_step_size = 1e-5
     ):
         super().__init__()
 
+        self.retrieve_norm = nn.RMSNorm(dim) if pre_rmsnorm else nn.Identity()
+        self.store_norm = nn.RMSNorm(dim) if pre_rmsnorm else nn.Identity()
+
         if not exists(model):
             model = MLP(dim, depth = 4)
 
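For context, the two new constructor arguments might be used like this (a hypothetical configuration; the dimension and chunk size are illustrative, not defaults):

```python
from titans_pytorch import NeuralMemory

mem = NeuralMemory(
    dim = 384,
    chunk_size = 64,
    pre_rmsnorm = True,            # RMSNorm the input on both the store and retrieve paths
    max_adaptive_step_size = 1e-5  # ceiling for the learned per-token step size
)
```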
@@ -140,6 +149,8 @@ class NeuralMemory(Module):
             Rearrange('... 1 -> ...')
         )
 
+        self.max_adaptive_step_size = max_adaptive_step_size
+
         # weight decay factor
 
         self.to_decay_factor = nn.Sequential(
@@ -161,6 +172,8 @@ class NeuralMemory(Module):
         past_state: tuple[dict[str, Tensor], dict[str, Tensor]]
     ):
 
+        seq = self.store_norm(seq)
+
         # curtail sequence by multiple of the chunk size
         # only a complete chunk of the sequence provides the memory for the next chunk
 
@@ -182,7 +195,7 @@ class NeuralMemory(Module):
 
         batch = seq.shape[0]
 
-        adaptive_lr = self.to_adaptive_step(seq).tanh() * 0.5 + 0.5
+        adaptive_lr = softclamp_max(self.to_adaptive_step(seq), self.max_adaptive_step_size)
 
         adaptive_momentum = self.to_momentum(seq).sigmoid()
         decay_factor = self.to_decay_factor(seq).sigmoid()
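The replaced line bounded the learned step size to `(0, 1)` via `tanh(x) * 0.5 + 0.5`; the new line routes it through `softclamp_max`, so the ceiling becomes the configurable `max_adaptive_step_size`. A standalone comparison of the two mappings (the raw values stand in for `self.to_adaptive_step(seq)` and are illustrative):

```python
import torch

raw = torch.tensor([-2., 0., 3.])   # stand-in for self.to_adaptive_step(seq)

old_lr = raw.tanh() * 0.5 + 0.5     # old bound: (0, 1)

max_step = 1e-5                     # default max_adaptive_step_size in this diff
range_value = max_step / 2
new_lr = ((raw / range_value).tanh() * range_value) + range_value  # new bound: (0, max_step)

assert (old_lr > 0).all() and (old_lr < 1).all()
# tanh saturates for |raw| >> max_step, so the bounds can be hit exactly in float32
assert (new_lr >= 0).all() and (new_lr <= max_step).all()
```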
@@ -244,6 +257,8 @@ class NeuralMemory(Module):
         chunk_size = self.chunk_size
         batch, seq_len = seq.shape[:2]
 
+        seq = self.retrieve_norm(seq)
+
         assert seq_len >= chunk_size
 
         seq = seq[:, (chunk_size - 1):]
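Both forward paths now share the same toggle: with `pre_rmsnorm = False`, each norm is an `nn.Identity()`, so storing and retrieving behave exactly as in 0.0.6. A quick standalone check of that pattern (assumes only `torch`; `nn.RMSNorm` itself needs a recent PyTorch but is never constructed when the flag is off):

```python
import torch
from torch import nn

dim, pre_rmsnorm = 16, False

# same construction as in the diff; Identity is the no-op fallback
norm = nn.RMSNorm(dim) if pre_rmsnorm else nn.Identity()

x = torch.randn(2, 8, dim)
assert torch.equal(norm(x), x)
```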
{titans_pytorch-0.0.6.dist-info → titans_pytorch-0.0.8.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.0.6
+Version: 0.0.8
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -39,7 +39,8 @@ Requires-Dist: einx>=0.3.0
 Requires-Dist: tensordict>=0.6.2
 Requires-Dist: torch>=2.3
 Provides-Extra: examples
-Requires-Dist: local-attention>=1.9.15; extra == 'examples'
+Requires-Dist: local-attention>=1.10.0; extra == 'examples'
+Requires-Dist: taylor-series-linear-attention; extra == 'examples'
 Provides-Extra: test
 Requires-Dist: pytest; extra == 'test'
 Description-Content-Type: text/markdown
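These dependencies apply only to the optional extra: installing with `pip install 'titans-pytorch[examples]'` pulls in the bumped `local-attention` pin and the newly added `taylor-series-linear-attention`; a plain install is unaffected.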
@@ -64,16 +65,16 @@ $ pip install titans-pytorch
 import torch
 from titans_pytorch import NeuralMemory
 
-x = torch.randn(2, 64, 32)
-
 mem = NeuralMemory(
-    dim = 32,
-    chunk_size = 2
-)
+    dim = 384,
+    chunk_size = 64,
+    pre_rmsnorm = True
+).cuda()
 
-out = mem(x)
+seq = torch.randn(2, 1024, 384).cuda()
+retrieved = mem(seq)
 
-assert x.shape == out.shape
+assert seq.shape == retrieved.shape
 ```
 
 ## Citations
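The updated quick-start exercises the new `pre_rmsnorm` flag at more realistic shapes and moves tensors to GPU. On a machine without CUDA, the same example should work with the `.cuda()` calls dropped (a sketch under that assumption):

```python
import torch
from titans_pytorch import NeuralMemory

mem = NeuralMemory(
    dim = 384,
    chunk_size = 64,
    pre_rmsnorm = True
)

seq = torch.randn(2, 1024, 384)
retrieved = mem(seq)

# retrieval preserves the input shape
assert seq.shape == retrieved.shape
```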
titans_pytorch-0.0.8.dist-info/RECORD ADDED
@@ -0,0 +1,7 @@
+titans_pytorch/__init__.py,sha256=QKuJPCOJCdgtaPeKoHEkYkiQe65_LV9_8-cIMbBPU30,55
+titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+titans_pytorch/titans.py,sha256=I8o_wauKQre87wl645Gy1nchPn5ckfZq4Q6pf5IpToU,9988
+titans_pytorch-0.0.8.dist-info/METADATA,sha256=70ENsVn58zYyJM0UKanESujf0xSQWTcWXBh5HY1frKk,3219
+titans_pytorch-0.0.8.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.0.8.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.0.8.dist-info/RECORD,,
titans_pytorch-0.0.6.dist-info/RECORD DELETED
@@ -1,7 +0,0 @@
-titans_pytorch/__init__.py,sha256=QKuJPCOJCdgtaPeKoHEkYkiQe65_LV9_8-cIMbBPU30,55
-titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/titans.py,sha256=S8J8B9o7Rlnj2hU3FZgpn28GTmis3ZbenLqjB_uny54,9470
-titans_pytorch-0.0.6.dist-info/METADATA,sha256=t4HXD6sZT7_pgcwD8TBY6ojYHUHiZ05J6t19wRKtHNc,3092
-titans_pytorch-0.0.6.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.0.6.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.0.6.dist-info/RECORD,,