titans-pytorch 0.1.5__py3-none-any.whl → 0.1.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -2,7 +2,8 @@ from titans_pytorch.titans import (
2
2
  NeuralMemory,
3
3
  MemoryMLP,
4
4
  MemoryAttention,
5
- FactorizedMemoryMLP
5
+ FactorizedMemoryMLP,
6
+ GatedResidualMemoryMLP
6
7
  )
7
8
 
8
9
  from titans_pytorch.mac_transformer import (
titans_pytorch/titans.py CHANGED
@@ -1,10 +1,11 @@
1
1
  from __future__ import annotations
2
2
  from typing import Callable
3
+
3
4
  import math
4
5
  from functools import partial
5
6
 
6
7
  import torch
7
- from torch import nn, Tensor
8
+ from torch import nn, cat, Tensor
8
9
  import torch.nn.functional as F
9
10
  from torch.nn import Linear, Module, Parameter, ParameterList
10
11
  from torch.func import functional_call, vmap, grad
@@ -154,6 +155,49 @@ class MemoryMLP(Module):
154
155
 
155
156
  return x
156
157
 
158
+ # memory mlp, but with gated residual + final projection
159
+
160
class GatedResidualMemoryMLP(Module):
    """Memory MLP whose layers are combined through a learned gated residual.

    Each of the `depth` layers owns three weight matrices:
    an up projection `(dim, dim_hidden)`, a down projection `(dim_hidden, dim)`
    and a gate projection `(dim * 2, dim)`, where
    `dim_hidden = int(dim * expansion_factor)`. A single `(dim, dim)` matrix
    is applied after the last layer.
    """

    def __init__(
        self,
        dim,
        depth,
        expansion_factor = 2.
    ):
        """Create the per-layer weights and the final projection.

        Args:
            dim: size of the last axis of the input and of the output.
            depth: number of gated residual layers.
            expansion_factor: multiplier giving the hidden width of each layer.
        """
        super().__init__()
        inner_dim = int(dim * expansion_factor)

        # the creation order (up, down, gate per layer, then the final projection)
        # fixes the order in which random numbers are drawn
        layers = []

        for _ in range(depth):
            up_proj = Parameter(torch.randn(dim, inner_dim))
            down_proj = Parameter(torch.randn(inner_dim, dim))
            gate_proj = Parameter(torch.randn(dim * 2, dim))

            layers.append(ParameterList([up_proj, down_proj, gate_proj]))

        self.weights = ParameterList(layers)

        self.final_proj = Parameter(torch.randn(dim, dim))

        # every matrix is overwritten in place with xavier uniform values

        for weight in self.parameters():
            nn.init.xavier_uniform_(weight)

    def forward(
        self,
        x
    ):
        """Run `x` through every gated residual layer, then the final projection.

        Args:
            x: tensor whose last axis has size `dim`.

        Returns:
            Tensor with the same shape as `x`.
        """
        out = x

        for layer in self.weights:
            up_proj, down_proj, gate_proj = layer

            skip = out

            # feedforward branch: project up, silu, project back down

            branch = F.silu(skip @ up_proj) @ down_proj

            # gate is computed from the branch output and the layer input together,
            # then used to interpolate from the input (gate 0) to the branch (gate 1)

            mix = torch.sigmoid(cat((branch, skip), dim = -1) @ gate_proj)
            out = torch.lerp(skip, branch, mix)

        return out @ self.final_proj
200
+
157
201
  # memory mlp with factorized weights
158
202
  # so can tradeoff capacity for smaller chunk sizes
159
203
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: titans-pytorch
3
- Version: 0.1.5
3
+ Version: 0.1.7
4
4
  Summary: Titans
5
5
  Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
6
6
  Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -46,6 +46,7 @@ Requires-Dist: torch>=2.2
46
46
  Requires-Dist: tqdm
47
47
  Requires-Dist: x-transformers
48
48
  Provides-Extra: examples
49
+ Requires-Dist: adam-atan2-pytorch>=0.1.18; extra == 'examples'
49
50
  Requires-Dist: wandb; extra == 'examples'
50
51
  Provides-Extra: test
51
52
  Requires-Dist: pytest; extra == 'test'
@@ -0,0 +1,8 @@
1
+ titans_pytorch/__init__.py,sha256=u0tta_KqhOdfzCEDWT9P4_jejJEK2q1XxhsEzB5MnQU,223
2
+ titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
3
+ titans_pytorch/mac_transformer.py,sha256=YowJzQC2p3fWgzFCe9jXrw1b3wWPKN-dcLs5sX-3r8Y,19123
4
+ titans_pytorch/titans.py,sha256=GFnORd9WRQBBavEUWz0DV0lxnHtLsrDEghX0RgWhZaQ,20758
5
+ titans_pytorch-0.1.7.dist-info/METADATA,sha256=ltCnRZ1MXbqxJ5-L41Z1pvfTDpLJNnk_XZ9hpfX6wrQ,4747
6
+ titans_pytorch-0.1.7.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
7
+ titans_pytorch-0.1.7.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
8
+ titans_pytorch-0.1.7.dist-info/RECORD,,
@@ -1,8 +0,0 @@
1
- titans_pytorch/__init__.py,sha256=I-4oF1JPEmUvhLszEocM5cDgHYNFcNp0Q9nbDYSPFqU,195
2
- titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
3
- titans_pytorch/mac_transformer.py,sha256=YowJzQC2p3fWgzFCe9jXrw1b3wWPKN-dcLs5sX-3r8Y,19123
4
- titans_pytorch/titans.py,sha256=iF0tTTyLs3hPhJDvGVKD2PdXgpWo9xOggD_42szPwjg,19632
5
- titans_pytorch-0.1.5.dist-info/METADATA,sha256=GrCMbvIDT9gdL8JJ-U55oxFeB8TVRI2PTuvFK2QQjbk,4684
6
- titans_pytorch-0.1.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
7
- titans_pytorch-0.1.5.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
8
- titans_pytorch-0.1.5.dist-info/RECORD,,