titans-pytorch 0.1.36__py3-none-any.whl → 0.1.37__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported public registry. It is provided for informational purposes only.
@@ -1,5 +1,8 @@
  from titans_pytorch.neural_memory import (
      NeuralMemory,
+ )
+
+ from titans_pytorch.memory_models import (
      MemoryMLP,
      MemoryAttention,
      FactorizedMemoryMLP,
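This `__init__.py` hunk splits the package exports: `NeuralMemory` stays in `titans_pytorch.neural_memory`, while the memory model classes move to a new `titans_pytorch.memory_models` module but remain re-exported at the package root. A minimal sketch of the two import paths this implies, assuming the re-exports shown above and an installed 0.1.37 wheel:

```python
# a minimal sketch, assuming the re-exports shown in the __init__.py hunk above
from titans_pytorch import NeuralMemory, MemoryMLP          # package-root imports keep working
from titans_pytorch.memory_models import MemoryMLP as MLP   # new home of the class

assert MemoryMLP is MLP  # same class object, just re-exported at the root
```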
@@ -0,0 +1,153 @@
+ import torch
+ from torch import nn
+ import torch.nn.functional as F
+ from torch.nn import Module, ModuleList, Parameter, ParameterList
+
+ class MemoryMLP(Module):
+     def __init__(
+         self,
+         dim,
+         depth
+     ):
+         super().__init__()
+         self.weights = ParameterList([Parameter(torch.randn(dim, dim)) for _ in range(depth)])
+
+         for weight in self.weights:
+             nn.init.xavier_uniform_(weight)
+
+     def forward(
+         self,
+         x
+     ):
+         for ind, weight in enumerate(self.weights):
+             is_first = ind == 0
+
+             if not is_first:
+                 x = F.silu(x)
+
+             x = x @ weight
+
+         return x
+
+ # memory mlp, but with gated residual + final projection
+
+ class GatedResidualMemoryMLP(Module):
+     def __init__(
+         self,
+         dim,
+         depth,
+         expansion_factor = 2.
+     ):
+         super().__init__()
+         dim_hidden = int(dim * expansion_factor)
+
+         self.weights = ParameterList([
+             ParameterList([
+                 Parameter(torch.randn(dim, dim_hidden)),
+                 Parameter(torch.randn(dim_hidden, dim)),
+                 Parameter(torch.randn(dim * 2, dim)),
+             ]) for _ in range(depth)
+         ])
+
+         self.final_proj = Parameter(torch.randn(dim, dim))
+
+         for param in self.parameters():
+             nn.init.xavier_uniform_(param)
+
+     def forward(
+         self,
+         x
+     ):
+         for weight1, weight2, to_gates in self.weights:
+             res = x
+
+             hidden = x @ weight1
+             hidden = F.silu(hidden)
+             branch_out = hidden @ weight2
+
+             # gated residual
+
+             gates = cat((branch_out, res), dim = -1) @ to_gates
+             x = res.lerp(branch_out, gates.sigmoid())
+
+         return x @ self.final_proj
+
+ # memory mlp with factorized weights
+ # so can tradeoff capacity for smaller chunk sizes
+
+ class FactorizedMemoryMLP(Module):
+     def __init__(
+         self,
+         dim,
+         depth,
+         k = 32
+     ):
+         super().__init__()
+         self.weights = ParameterList([
+             ParameterList([
+                 Parameter(torch.randn(dim, k)),
+                 Parameter(torch.randn(k, dim)),
+             ]) for _ in range(depth)
+         ])
+
+         for weight1, weight2 in self.weights:
+             nn.init.xavier_uniform_(weight1)
+             nn.init.xavier_uniform_(weight2)
+
+     def forward(
+         self,
+         x
+     ):
+         for ind, (weight1, weight2) in enumerate(self.weights):
+             is_first = ind == 0
+
+             if not is_first:
+                 x = F.silu(x)
+
+             x = x @ weight1 @ weight2
+
+         return x
+
+ # improvised attention as memory module
+
+ class MemoryAttention(Module):
+     def __init__(
+         self,
+         dim,
+         scale = 8.,
+         expansion_factor = 2.
+     ):
+         super().__init__()
+         self.scale = scale
+         dim_ff_hidden = int(dim * expansion_factor)
+
+         self.weights = nn.ParameterList([
+             nn.Parameter(torch.randn(dim, dim)), # queries
+             nn.Parameter(torch.randn(dim, dim)), # keys
+             nn.Parameter(torch.randn(dim, dim)), # values
+             nn.Parameter(torch.randn(dim, dim_ff_hidden)), # ff w1
+             nn.Parameter(torch.randn(dim_ff_hidden, dim)), # ff w2
+         ])
+
+         for weight in self.weights:
+             nn.init.xavier_uniform_(weight)
+
+     def forward(self, x):
+         wq, wk, wv, ffw1, ffw2 = self.weights
+
+         q = F.normalize(x @ wq, dim = -1)
+         k = F.normalize(x @ wk, dim = -1)
+         v = x @ wv
+
+         attn_out = F.scaled_dot_product_attention(
+             q, k, v,
+             scale = self.scale,
+             is_causal = True
+         )
+
+         x = x + attn_out
+
+         h = F.silu(x @ ffw1)
+         out = h @ ffw2
+
+         return out
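The new `memory_models.py` gathers the per-token memory modules (`MemoryMLP`, `GatedResidualMemoryMLP`, `FactorizedMemoryMLP`, `MemoryAttention`) into one file; each maps a tensor whose last dimension is `dim` to a tensor of the same shape. Note that, as added here, `GatedResidualMemoryMLP` references a bare `cat` that the imports shown above do not provide, so it would need `torch.cat` (or `from torch import cat`) to run. A minimal usage sketch; the dim/depth/shape values below are illustrative assumptions, not from the package:

```python
# a minimal sketch; dim, depth and tensor shapes are illustrative assumptions
import torch
from titans_pytorch.memory_models import MemoryMLP, FactorizedMemoryMLP, MemoryAttention

x = torch.randn(2, 64, 384)                                # (batch, seq, dim)

mlp  = MemoryMLP(dim = 384, depth = 2)                     # stack of dim x dim weights with SiLU between
fact = FactorizedMemoryMLP(dim = 384, depth = 2, k = 32)   # each layer factorized as (dim x k)(k x dim)
attn = MemoryAttention(dim = 384)                          # causal attention plus a SiLU feedforward

assert mlp(x).shape == fact(x).shape == attn(x).shape == x.shape
```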
@@ -19,6 +19,10 @@ from titans_pytorch.associative_scan import (
      pad_at_dim
  )

+ from titans_pytorch.memory_models import(
+     MemoryMLP
+ )
+
  import einx
  from einops import rearrange, repeat, reduce, pack, unpack
  from einops.layers.torch import Rearrange, Reduce
@@ -169,157 +173,6 @@ class AttentionPool(Module):

          return reduce(x * attn, 'b n c d -> b n d', 'sum')

- # classes
-
- class MemoryMLP(Module):
-     def __init__(
-         self,
-         dim,
-         depth
-     ):
-         super().__init__()
-         self.weights = ParameterList([Parameter(torch.randn(dim, dim)) for _ in range(depth)])
-
-         for weight in self.weights:
-             nn.init.xavier_uniform_(weight)
-
-     def forward(
-         self,
-         x
-     ):
-         for ind, weight in enumerate(self.weights):
-             is_first = ind == 0
-
-             if not is_first:
-                 x = F.silu(x)
-
-             x = x @ weight
-
-         return x
-
- # memory mlp, but with gated residual + final projection
-
- class GatedResidualMemoryMLP(Module):
-     def __init__(
-         self,
-         dim,
-         depth,
-         expansion_factor = 2.
-     ):
-         super().__init__()
-         dim_hidden = int(dim * expansion_factor)
-
-         self.weights = ParameterList([
-             ParameterList([
-                 Parameter(torch.randn(dim, dim_hidden)),
-                 Parameter(torch.randn(dim_hidden, dim)),
-                 Parameter(torch.randn(dim * 2, dim)),
-             ]) for _ in range(depth)
-         ])
-
-         self.final_proj = Parameter(torch.randn(dim, dim))
-
-         for param in self.parameters():
-             nn.init.xavier_uniform_(param)
-
-     def forward(
-         self,
-         x
-     ):
-         for weight1, weight2, to_gates in self.weights:
-             res = x
-
-             hidden = x @ weight1
-             hidden = F.silu(hidden)
-             branch_out = hidden @ weight2
-
-             # gated residual
-
-             gates = cat((branch_out, res), dim = -1) @ to_gates
-             x = res.lerp(branch_out, gates.sigmoid())
-
-         return x @ self.final_proj
-
- # memory mlp with factorized weights
- # so can tradeoff capacity for smaller chunk sizes
-
- class FactorizedMemoryMLP(Module):
-     def __init__(
-         self,
-         dim,
-         depth,
-         k = 32
-     ):
-         super().__init__()
-         self.weights = ParameterList([
-             ParameterList([
-                 Parameter(torch.randn(dim, k)),
-                 Parameter(torch.randn(k, dim)),
-             ]) for _ in range(depth)
-         ])
-
-         for weight1, weight2 in self.weights:
-             nn.init.xavier_uniform_(weight1)
-             nn.init.xavier_uniform_(weight2)
-
-     def forward(
-         self,
-         x
-     ):
-         for ind, (weight1, weight2) in enumerate(self.weights):
-             is_first = ind == 0
-
-             if not is_first:
-                 x = F.silu(x)
-
-             x = x @ weight1 @ weight2
-
-         return x
-
- # improvised attention as memory module
-
- class MemoryAttention(Module):
-     def __init__(
-         self,
-         dim,
-         scale = 8.,
-         expansion_factor = 2.
-     ):
-         super().__init__()
-         self.scale = scale
-         dim_ff_hidden = int(dim * expansion_factor)
-
-         self.weights = nn.ParameterList([
-             nn.Parameter(torch.randn(dim, dim)), # queries
-             nn.Parameter(torch.randn(dim, dim)), # keys
-             nn.Parameter(torch.randn(dim, dim)), # values
-             nn.Parameter(torch.randn(dim, dim_ff_hidden)), # ff w1
-             nn.Parameter(torch.randn(dim_ff_hidden, dim)), # ff w2
-         ])
-
-         for weight in self.weights:
-             nn.init.xavier_uniform_(weight)
-
-     def forward(self, x):
-         wq, wk, wv, ffw1, ffw2 = self.weights
-
-         q = F.normalize(x @ wq, dim = -1)
-         k = F.normalize(x @ wk, dim = -1)
-         v = x @ wv
-
-         attn_out = F.scaled_dot_product_attention(
-             q, k, v,
-             scale = self.scale,
-             is_causal = True
-         )
-
-         x = x + attn_out
-
-         h = F.silu(x @ ffw1)
-         out = h @ ffw2
-
-         return out
-
  # associative scan wrapper

  class AssocScan(Module):
@@ -870,6 +723,8 @@ class NeuralMemory(Module):
  prev_layer_updates = TensorDict(prev_layer_updates)
  prev_layer_updates = prev_layer_updates.apply(lambda t: t[:, -1:])

+ values = None
+
  if store_seq_cache_len == self.chunk_size:

      next_updates, next_states, values = self.store_memories(
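The added `values = None` pre-initializes the variable before the conditional store, so code after the `if` block can reference `values` even when `store_seq_cache_len != self.chunk_size`. A stripped-down illustration of the pattern being guarded against, with hypothetical names rather than the library's actual code:

```python
# hypothetical illustration of the pattern, not the library's actual code
def before(condition):
    if condition:
        values = 42            # only bound when the branch runs
    return values              # UnboundLocalError when condition is False

def after(condition):
    values = None              # pre-initialize, as the added line does
    if condition:
        values = 42
    return values              # always defined; None signals "nothing stored"

assert after(False) is None
```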
@@ -917,7 +772,7 @@
  if seq_len < self.retrieve_chunk_size:
      out = self.init_empty_memory_embed(batch, seq_len)

-     next_store_state = (seq_len, seq, None, None)
+     next_store_state = NeuralMemCache(seq_len, seq, None, None)

      out = (out, next_store_state)

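This hunk replaces a plain 4-tuple with a `NeuralMemCache` built from the same four values, so the early-return path produces the same cache type as the main path. The diff does not show `NeuralMemCache`'s definition; the sketch below uses a hypothetical named tuple with placeholder field names only to illustrate why the named form helps: fields become self-describing while positional unpacking keeps working.

```python
# hypothetical sketch; NeuralMemCache's real definition and field names are not shown in this diff
from typing import Any, NamedTuple, Optional

class NeuralMemCache(NamedTuple):
    seq_index: int
    cache_store_seq: Any
    states: Optional[Any] = None
    updates: Optional[Any] = None

cache = NeuralMemCache(16, None, None, None)
seq_index, *_ = cache          # positional unpacking still works like the old plain tuple
assert cache.states is None    # but fields can now be accessed by name
```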
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: titans-pytorch
- Version: 0.1.36
+ Version: 0.1.37
  Summary: Titans
  Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
  Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -0,0 +1,9 @@
+ titans_pytorch/__init__.py,sha256=Y3m_ZlpEqYwp-Md1ARhNGJxq8bQp8ty1o039nZOOJo0,276
+ titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+ titans_pytorch/mac_transformer.py,sha256=KBwo-Fr_fDzVaAa7xg1ggEpNlE4vRUoGMEjB-I2ZWTU,26463
+ titans_pytorch/memory_models.py,sha256=LI9T36XB6YXIvvGWRw0ZMDlGpRC6KIv03OPzME2VAaU,3772
+ titans_pytorch/neural_memory.py,sha256=B6nTdsq7Tp6lfmUwQakOtODImwZzMCDkDfIv5CIKlbQ,25453
+ titans_pytorch-0.1.37.dist-info/METADATA,sha256=BdkDj71kq320M_vbZOGPrRBR91ycgqBJaUNm9m7tA5U,6826
+ titans_pytorch-0.1.37.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+ titans_pytorch-0.1.37.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+ titans_pytorch-0.1.37.dist-info/RECORD,,
@@ -1,8 +0,0 @@
- titans_pytorch/__init__.py,sha256=rMT99CPQFH4Gudp0FmVPWGKfhBf6xksGEaFEcOVdqjs,230
- titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
- titans_pytorch/mac_transformer.py,sha256=KBwo-Fr_fDzVaAa7xg1ggEpNlE4vRUoGMEjB-I2ZWTU,26463
- titans_pytorch/neural_memory.py,sha256=U4qJvwN3otGPifQJfRBIFilduWYI22DAsTtdc3kOYFY,29009
- titans_pytorch-0.1.36.dist-info/METADATA,sha256=dx1D5t2njrP-1zG4JzLIrimS6rvaN1UzFISm0wrf-l8,6826
- titans_pytorch-0.1.36.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
- titans_pytorch-0.1.36.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
- titans_pytorch-0.1.36.dist-info/RECORD,,