titans-pytorch 0.1.5__py3-none-any.whl → 0.1.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -2,7 +2,8 @@ from titans_pytorch.titans import (
2
2
  NeuralMemory,
3
3
  MemoryMLP,
4
4
  MemoryAttention,
5
- FactorizedMemoryMLP
5
+ FactorizedMemoryMLP,
6
+ GatedResidualMemoryMLP
6
7
  )
7
8
 
8
9
  from titans_pytorch.mac_transformer import (
titans_pytorch/titans.py CHANGED
@@ -1,10 +1,11 @@
1
1
  from __future__ import annotations
2
2
  from typing import Callable
3
+
3
4
  import math
4
5
  from functools import partial
5
6
 
6
7
  import torch
7
- from torch import nn, Tensor
8
+ from torch import nn, cat, Tensor
8
9
  import torch.nn.functional as F
9
10
  from torch.nn import Linear, Module, Parameter, ParameterList
10
11
  from torch.func import functional_call, vmap, grad
@@ -154,6 +155,46 @@ class MemoryMLP(Module):
154
155
 
155
156
  return x
156
157
 
158
+ # memory mlp, but with gated residual + final projection
159
+
160
class GatedResidualMemoryMLP(Module):
    """Memory MLP whose layers are merged into the residual stream by learned gates.

    For each of ``depth`` layers, with ``x`` the current hidden state:

        branch = silu(x @ W)                           W: (dim, dim)
        gate   = sigmoid(concat(branch, x) @ G)        G: (2 * dim, dim)
        x      = lerp(x, branch, gate)                 # x + gate * (branch - x)

    After all layers, a final ``(dim, dim)`` projection is applied.

    Args:
        dim: size of the last axis of the input; the output keeps the input's shape.
        depth: number of gated residual layers (0 leaves only the final projection).
    """

    def __init__(
        self,
        dim,
        depth
    ):
        super().__init__()
        self.depth = depth

        # one [branch weight (dim, dim), gate weight (2 * dim, dim)] pair per layer,
        # stored as nested ParameterLists so parameter names are `weights.<layer>.<0|1>`
        layers = []
        for _ in range(depth):
            branch_weight = Parameter(torch.randn(dim, dim))
            gate_weight = Parameter(torch.randn(dim * 2, dim))
            layers.append(ParameterList([branch_weight, gate_weight]))

        self.weights = ParameterList(layers)

        self.final_proj = Parameter(torch.randn(dim, dim))

        # the randn draws above only fix the shapes; every parameter is re-initialized here
        for param in self.parameters():
            nn.init.xavier_uniform_(param)

    def forward(
        self,
        x
    ):
        hidden = x

        for branch_weight, gate_weight in self.weights:
            residual = hidden
            branch = F.silu(residual @ branch_weight)

            # per-feature gate computed from both the branch output and the residual;
            # it interpolates from the residual (gate = 0) toward the branch (gate = 1)
            gate = (cat((branch, residual), dim = -1) @ gate_weight).sigmoid()
            hidden = residual.lerp(branch, gate)

        return hidden @ self.final_proj
197
+
157
198
  # memory mlp with factorized weights
158
199
  # so can tradeoff capacity for smaller chunk sizes
159
200
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: titans-pytorch
3
- Version: 0.1.5
3
+ Version: 0.1.6
4
4
  Summary: Titans
5
5
  Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
6
6
  Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -46,6 +46,7 @@ Requires-Dist: torch>=2.2
46
46
  Requires-Dist: tqdm
47
47
  Requires-Dist: x-transformers
48
48
  Provides-Extra: examples
49
+ Requires-Dist: adam-atan2-pytorch>=0.1.18; extra == 'examples'
49
50
  Requires-Dist: wandb; extra == 'examples'
50
51
  Provides-Extra: test
51
52
  Requires-Dist: pytest; extra == 'test'
@@ -0,0 +1,8 @@
1
+ titans_pytorch/__init__.py,sha256=u0tta_KqhOdfzCEDWT9P4_jejJEK2q1XxhsEzB5MnQU,223
2
+ titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
3
+ titans_pytorch/mac_transformer.py,sha256=YowJzQC2p3fWgzFCe9jXrw1b3wWPKN-dcLs5sX-3r8Y,19123
4
+ titans_pytorch/titans.py,sha256=VMcPcKsoR3G13Um62Aa1HbdwrrV60ljPhP-yF40x90I,20555
5
+ titans_pytorch-0.1.6.dist-info/METADATA,sha256=LJW26WfT9WB-0NfokLLHhcRpWnt76jwkXMt_FSTI3SM,4747
6
+ titans_pytorch-0.1.6.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
7
+ titans_pytorch-0.1.6.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
8
+ titans_pytorch-0.1.6.dist-info/RECORD,,
@@ -1,8 +0,0 @@
1
- titans_pytorch/__init__.py,sha256=I-4oF1JPEmUvhLszEocM5cDgHYNFcNp0Q9nbDYSPFqU,195
2
- titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
3
- titans_pytorch/mac_transformer.py,sha256=YowJzQC2p3fWgzFCe9jXrw1b3wWPKN-dcLs5sX-3r8Y,19123
4
- titans_pytorch/titans.py,sha256=iF0tTTyLs3hPhJDvGVKD2PdXgpWo9xOggD_42szPwjg,19632
5
- titans_pytorch-0.1.5.dist-info/METADATA,sha256=GrCMbvIDT9gdL8JJ-U55oxFeB8TVRI2PTuvFK2QQjbk,4684
6
- titans_pytorch-0.1.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
7
- titans_pytorch-0.1.5.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
8
- titans_pytorch-0.1.5.dist-info/RECORD,,