titans-pytorch 0.0.61__tar.gz → 0.0.63__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: titans-pytorch
- Version: 0.0.61
+ Version: 0.0.63
  Summary: Titans
  Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
  Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -0,0 +1,16 @@
+ import torch
+ from titans_pytorch.mac_transformer import SegmentedAttention
+
+ attn = SegmentedAttention(
+     dim = 512,
+     segment_len = 32,
+     num_persist_mem_tokens = 1,
+     use_flex_attn = True
+ ).cuda()
+
+ seq = torch.randn(1, 1019, 512).cuda()
+
+ out_flex, _ = attn(seq)
+ out_non_flex, _ = attn(seq, disable_flex_attn = True)
+
+ assert torch.allclose(out_flex, out_non_flex, atol = 1e-6)
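The test added above runs SegmentedAttention with use_flex_attn = True on a 1019-token sequence (not a multiple of segment_len = 32) and asserts that the flex-attention path matches the regular path. A minimal sketch of the same check wrapped so it skips on machines without CUDA, assuming pytest is available; the test name is hypothetical:

    import pytest
    import torch

    from titans_pytorch.mac_transformer import SegmentedAttention

    # skip rather than fail where the flex attention kernel cannot run
    @pytest.mark.skipif(not torch.cuda.is_available(), reason = 'flex attention path requires CUDA')
    def test_flex_matches_non_flex():
        attn = SegmentedAttention(
            dim = 512,
            segment_len = 32,
            num_persist_mem_tokens = 1,
            use_flex_attn = True
        ).cuda()

        seq = torch.randn(1, 1019, 512).cuda()

        out_flex, _ = attn(seq)
        out_non_flex, _ = attn(seq, disable_flex_attn = True)

        assert torch.allclose(out_flex, out_non_flex, atol = 1e-6)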
@@ -1,6 +1,6 @@
  [project]
  name = "titans-pytorch"
- version = "0.0.61"
+ version = "0.0.63"
  description = "Titans"
  authors = [
      { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -1,7 +1,8 @@
  from titans_pytorch.titans import (
      NeuralMemory,
      MemoryMLP,
-     MemoryAttention
+     MemoryAttention,
+     FactorizedMemoryMLP
  )

  from titans_pytorch.mac_transformer import (
@@ -6,7 +6,7 @@ from functools import partial
  import torch
  from torch import nn, Tensor
  import torch.nn.functional as F
- from torch.nn import Linear, Module
+ from torch.nn import Linear, Module, Parameter, ParameterList
  from torch.func import functional_call, vmap, grad

  from tensordict import TensorDict
@@ -88,7 +88,7 @@ class MultiheadRMSNorm(Module):
      def __init__(self, dim, heads):
          super().__init__()
          self.rmsnorm = nn.RMSNorm(dim, elementwise_affine = False)
-         self.gamma = nn.Parameter(torch.zeros(heads, 1, dim))
+         self.gamma = Parameter(torch.zeros(heads, 1, dim))

      def forward(self, x):
          return self.rmsnorm(x) * (self.gamma + 1.)
@@ -102,7 +102,10 @@ class MemoryMLP(Module):
          depth
      ):
          super().__init__()
-         self.weights = nn.ParameterList([nn.Parameter(torch.randn(dim, dim)) for _ in range(depth)])
+         self.weights = ParameterList([Parameter(torch.randn(dim, dim)) for _ in range(depth)])
+
+         for weight in self.weights:
+             nn.init.xavier_uniform_(weight)

      def forward(
          self,
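This hunk and the ones below switch the raw torch.randn weights of the memory modules to Xavier uniform initialization, which scales each matrix by its fan-in and fan-out rather than leaving unit-variance entries. A small illustration of the difference, using plain torch; the dim value is illustrative:

    import torch
    from torch import nn

    dim = 512

    w = torch.randn(dim, dim)      # before: entries drawn from N(0, 1)
    nn.init.xavier_uniform_(w)     # after: refilled uniformly in [-bound, bound]

    # xavier uniform bound = sqrt(6 / (fan_in + fan_out)), about 0.077 for dim = 512
    bound = (6 / (dim + dim)) ** 0.5
    assert w.abs().max() <= bound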
@@ -118,25 +121,66 @@ class MemoryMLP(Module):

          return x

+ # memory mlp with factorized weights
+ # so can tradeoff capacity for smaller chunk sizes
+
+ class FactorizedMemoryMLP(Module):
+     def __init__(
+         self,
+         dim,
+         depth,
+         k = 32
+     ):
+         super().__init__()
+         self.weights = ParameterList([
+             ParameterList([
+                 Parameter(torch.randn(dim, k)),
+                 Parameter(torch.randn(k, dim)),
+             ]) for _ in range(depth)
+         ])
+
+         for weight1, weight2 in self.weights:
+             nn.init.xavier_uniform_(weight1)
+             nn.init.xavier_uniform_(weight2)
+
+     def forward(
+         self,
+         x
+     ):
+         for ind, (weight1, weight2) in enumerate(self.weights):
+             is_first = ind == 0
+
+             if not is_first:
+                 x = F.silu(x)
+
+             x = x @ weight1 @ weight2
+
+         return x
+
  # improvised attention as memory module

  class MemoryAttention(Module):
      def __init__(
          self,
          dim,
-         scale = 8.
+         scale = 8.,
+         expansion_factor = 2.
      ):
          super().__init__()
          self.scale = scale
+         dim_ff_hidden = int(dim * expansion_factor)

          self.weights = nn.ParameterList([
              nn.Parameter(torch.randn(dim, dim)), # queries
              nn.Parameter(torch.randn(dim, dim)), # keys
              nn.Parameter(torch.randn(dim, dim)), # values
-             nn.Parameter(torch.randn(dim, dim * 2)), # ff w1
-             nn.Parameter(torch.randn(dim * 2, dim)), # ff w2
+             nn.Parameter(torch.randn(dim, dim_ff_hidden)), # ff w1
+             nn.Parameter(torch.randn(dim_ff_hidden, dim)), # ff w2
          ])

+         for weight in self.weights:
+             nn.init.xavier_uniform_(weight)
+
      def forward(self, x):
          wq, wk, wv, ffw1, ffw2 = self.weights

@@ -536,6 +580,7 @@ class NeuralMemory(Module):

          past_weights, _ = past_state

+
          retrieved = self.retrieve_memories(seq, past_weights + updates)

          if not return_aux_kv_loss:
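The comment on the new FactorizedMemoryMLP spells out the intent: each dim x dim memory weight is replaced by a dim x k and k x dim pair, trading per-layer capacity for far fewer parameters so that smaller chunk sizes remain affordable, and the same hunk makes MemoryAttention's feedforward width configurable via expansion_factor instead of a hardcoded 2x. A minimal sketch of the parameter tradeoff, assuming this release is installed; the dim / depth / k values are illustrative:

    import torch
    from titans_pytorch import MemoryMLP, FactorizedMemoryMLP

    dim, depth, k = 512, 2, 32

    dense = MemoryMLP(dim = dim, depth = depth)                        # dim * dim weights per layer
    factorized = FactorizedMemoryMLP(dim = dim, depth = depth, k = k)  # dim * k + k * dim per layer

    dense_params = sum(p.numel() for p in dense.parameters())              # 524288
    factorized_params = sum(p.numel() for p in factorized.parameters())    # 65536
    assert factorized_params < dense_params

    # both map (batch, seq, dim) -> (batch, seq, dim), so either can serve as the memory model
    x = torch.randn(1, 64, dim)
    assert dense(x).shape == factorized(x).shape == x.shape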