titans-pytorch 0.0.61-py3-none-any.whl → 0.0.62-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
titans_pytorch/__init__.py CHANGED
@@ -1,7 +1,8 @@
 from titans_pytorch.titans import (
     NeuralMemory,
     MemoryMLP,
-    MemoryAttention
+    MemoryAttention,
+    FactorizedMemoryMLP
 )
 
 from titans_pytorch.mac_transformer import (
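
FactorizedMemoryMLP joins the public exports in 0.0.62. A minimal usage sketch of the new class; the dim, depth, and k values below are illustrative rather than taken from the diff:

import torch
from titans_pytorch import FactorizedMemoryMLP

# new in 0.0.62: each layer is a low-rank product of a (dim x k) and a (k x dim) weight
mem = FactorizedMemoryMLP(dim = 384, depth = 2, k = 32)

x = torch.randn(2, 64, 384)   # (batch, seq, dim)
out = mem(x)                  # output keeps the input shape: (2, 64, 384)
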
titans_pytorch/titans.py CHANGED
@@ -6,7 +6,7 @@ from functools import partial
 import torch
 from torch import nn, Tensor
 import torch.nn.functional as F
-from torch.nn import Linear, Module
+from torch.nn import Linear, Module, Parameter, ParameterList
 from torch.func import functional_call, vmap, grad
 
 from tensordict import TensorDict
@@ -88,7 +88,7 @@ class MultiheadRMSNorm(Module):
     def __init__(self, dim, heads):
         super().__init__()
         self.rmsnorm = nn.RMSNorm(dim, elementwise_affine = False)
-        self.gamma = nn.Parameter(torch.zeros(heads, 1, dim))
+        self.gamma = Parameter(torch.zeros(heads, 1, dim))
 
     def forward(self, x):
         return self.rmsnorm(x) * (self.gamma + 1.)
@@ -102,7 +102,10 @@ class MemoryMLP(Module):
         depth
     ):
         super().__init__()
-        self.weights = nn.ParameterList([nn.Parameter(torch.randn(dim, dim)) for _ in range(depth)])
+        self.weights = ParameterList([Parameter(torch.randn(dim, dim)) for _ in range(depth)])
+
+        for weight in self.weights:
+            nn.init.xavier_uniform_(weight)
 
     def forward(
         self,
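
MemoryMLP's weights, still created with torch.randn, are now re-initialized with Xavier uniform. A small illustration (dim chosen arbitrarily) of the bound nn.init.xavier_uniform_ enforces on a square dim x dim weight:

import math
import torch
from torch import nn

dim = 384                                # arbitrary example size
w = nn.Parameter(torch.randn(dim, dim))
nn.init.xavier_uniform_(w)               # overwrites the randn values in place

bound = math.sqrt(6 / (dim + dim))       # Xavier-uniform bound with fan_in = fan_out = dim
print(bool(w.abs().max() <= bound))      # True
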
@@ -118,13 +121,50 @@ class MemoryMLP(Module):
 
         return x
 
+# memory mlp with factorized weights
+# so can tradeoff capacity for smaller chunk sizes
+
+class FactorizedMemoryMLP(Module):
+    def __init__(
+        self,
+        dim,
+        depth,
+        k = 32
+    ):
+        super().__init__()
+        self.weights = ParameterList([
+            ParameterList([
+                Parameter(torch.randn(dim, k)),
+                Parameter(torch.randn(k, dim)),
+            ]) for _ in range(depth)
+        ])
+
+        for weight1, weight2 in self.weights:
+            nn.init.xavier_uniform_(weight1)
+            nn.init.xavier_uniform_(weight2)
+
+    def forward(
+        self,
+        x
+    ):
+        for ind, (weight1, weight2) in enumerate(self.weights):
+            is_first = ind == 0
+
+            if not is_first:
+                x = F.silu(x)
+
+            x = x @ weight1 @ weight2
+
+        return x
+
 # improvised attention as memory module
 
 class MemoryAttention(Module):
     def __init__(
         self,
         dim,
-        scale = 8.
+        scale = 8.,
+        expansion_factor = 2.
     ):
         super().__init__()
         self.scale = scale
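
The inline comment frames the factorization as trading capacity for smaller chunk sizes. A rough parameter count (example numbers, not from the diff) shows the per-module saving:

dim, depth, k = 384, 2, 32   # illustrative; k = 32 is the new class's default rank

full_params       = depth * dim * dim     # MemoryMLP: one dim x dim weight per layer
factorized_params = depth * 2 * dim * k   # FactorizedMemoryMLP: a dim x k and a k x dim weight per layer

print(full_params)        # 294912
print(factorized_params)  # 49152, roughly 6x fewer parameters at this rank
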
@@ -133,10 +173,13 @@ class MemoryAttention(Module):
             nn.Parameter(torch.randn(dim, dim)), # queries
             nn.Parameter(torch.randn(dim, dim)), # keys
             nn.Parameter(torch.randn(dim, dim)), # values
-            nn.Parameter(torch.randn(dim, dim * 2)), # ff w1
-            nn.Parameter(torch.randn(dim * 2, dim)), # ff w2
+            nn.Parameter(torch.randn(dim, dim * expansion_factor)), # ff w1
+            nn.Parameter(torch.randn(dim * expansion_factor, dim)), # ff w2
         ])
 
+        for weight in self.weights:
+            nn.init.xavier_uniform_(weight)
+
     def forward(self, x):
         wq, wk, wv, ffw1, ffw2 = self.weights
 
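
MemoryAttention's feedforward width, previously hard-coded to dim * 2, is now controlled by the new expansion_factor argument, and its weights are Xavier-initialized as well. A construction sketch with illustrative values; an integer expansion factor is passed here because the resulting sizes feed torch.randn, which expects integers:

import torch
from titans_pytorch import MemoryAttention

attn_mem = MemoryAttention(dim = 384, expansion_factor = 2)

wq, wk, wv, ffw1, ffw2 = attn_mem.weights
print(ffw1.shape, ffw2.shape)   # torch.Size([384, 768]) torch.Size([768, 384])
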
@@ -536,6 +579,7 @@ class NeuralMemory(Module):
 
         past_weights, _ = past_state
 
+
         retrieved = self.retrieve_memories(seq, past_weights + updates)
 
         if not return_aux_kv_loss:
{titans_pytorch-0.0.61.dist-info → titans_pytorch-0.0.62.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.0.61
+Version: 0.0.62
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
titans_pytorch-0.0.62.dist-info/RECORD ADDED
@@ -0,0 +1,8 @@
+titans_pytorch/__init__.py,sha256=I-4oF1JPEmUvhLszEocM5cDgHYNFcNp0Q9nbDYSPFqU,195
+titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+titans_pytorch/mac_transformer.py,sha256=R9Xu-BjqAe9ZY60IGk4aNXBx_L8THsjJ4QrkbTnRNHo,15346
+titans_pytorch/titans.py,sha256=95J6UL44lOrdZSXdm7p36xC9tDeSmRBwdjig9T82PzI,17452
+titans_pytorch-0.0.62.dist-info/METADATA,sha256=08Blaa9Ehyv09rSA9uWguxbhKpbrd7C53Ya13E1VbpU,4457
+titans_pytorch-0.0.62.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.0.62.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.0.62.dist-info/RECORD,,
titans_pytorch-0.0.61.dist-info/RECORD DELETED
@@ -1,8 +0,0 @@
-titans_pytorch/__init__.py,sha256=Q0MQA3RS8vqzs-KzSGZkvLR7afQ6ZW9uMOq1MeNuFoY,170
-titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/mac_transformer.py,sha256=R9Xu-BjqAe9ZY60IGk4aNXBx_L8THsjJ4QrkbTnRNHo,15346
-titans_pytorch/titans.py,sha256=5wuAoDULbgXTM8Nbq8bXrW3Fd2nsn22kpERRfJOwZiU,16367
-titans_pytorch-0.0.61.dist-info/METADATA,sha256=Cfhqnse_9nnFNqVGo9p_kxO_LVawwv4uuZOx4anqhf0,4457
-titans_pytorch-0.0.61.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.0.61.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.0.61.dist-info/RECORD,,