titans-pytorch 0.0.31__py3-none-any.whl → 0.0.32__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to a supported public registry. It is provided for informational purposes only.
--- titans_pytorch/mac_transformer.py (0.0.31)
+++ titans_pytorch/mac_transformer.py (0.0.32)
@@ -17,6 +17,10 @@ from hyper_connections import get_init_and_expand_reduce_stream_functions
 from axial_positional_embedding import ContinuousAxialPositionalEmbedding
 from rotary_embedding_torch import RotaryEmbedding
 
+# proposed neural memory
+
+from titans_pytorch.titans import NeuralMemory
+
 # constants
 
 LinearNoBias = partial(Linear, bias = False)
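The newly imported NeuralMemory can also be exercised on its own. A minimal sketch, assuming the constructor call mirrors the one added in __init__ further down and that the forward pass returns a tensor shaped like its input (both consistent with how the module is used elsewhere in this diff):

    import torch
    from titans_pytorch.titans import NeuralMemory

    # dim and chunk_size mirror the NeuralMemory(dim = dim, chunk_size = num_longterm_mem_tokens)
    # call added below; the sequence length is chosen divisible by chunk_size
    mem = NeuralMemory(dim = 384, chunk_size = 4)

    seq = torch.randn(1, 64, 384)
    retrieved = mem(seq)   # assumed: returns a tensor of the same shape as seq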
@@ -161,7 +165,9 @@ class MemoryAsContextTransformer(Module):
         dim_head = 64,
         heads = 8,
         ff_mult = 4,
-        num_residual_streams = 4
+        num_residual_streams = 4,
+        neural_memory_kwargs: dict = dict(),
+        neural_memory_layers: tuple[int, ...] | None = None,
     ):
         super().__init__()
 
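A hedged sketch of the widened constructor surface. Only the three keyword arguments touched by this hunk are confirmed by the diff; num_tokens is assumed, and dim, depth, segment_len and num_longterm_mem_tokens are inferred from the surrounding code. Note that neural_memory_kwargs is accepted but not yet consumed in any hunk shown here.

    from titans_pytorch.mac_transformer import MemoryAsContextTransformer

    transformer = MemoryAsContextTransformer(
        num_tokens = 256,                # assumed parameter, not visible in this diff
        dim = 384,
        depth = 4,
        segment_len = 32,
        num_longterm_mem_tokens = 4,     # must be > 0 for any neural memory to be attached
        num_residual_streams = 4,
        neural_memory_layers = (2, 4),   # new: attach memory only at these 1-based layers
        neural_memory_kwargs = dict(),   # new: accepted, but unused in the hunks shown
    )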
@@ -181,8 +187,25 @@ class MemoryAsContextTransformer(Module):
         init_hyper_conn, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(num_residual_streams, disable = num_residual_streams == 1)
 
         self.layers = ModuleList([])
+        self.neural_mem_layers = ModuleList([])
+
+        layers = tuple(range(1, depth + 1))
+        neural_memory_layers = set(default(neural_memory_layers, layers))
+
+        for layer in layers:
+
+            # neural memory
+
+            mem = None
+
+            if num_longterm_mem_tokens > 0 and layer in neural_memory_layers:
+                mem = NeuralMemory(dim = dim, chunk_size = num_longterm_mem_tokens)
+                mem = init_hyper_conn(dim = dim, branch = mem)
+
+            self.neural_mem_layers.append(mem)
+
+            # attention and feedforward
 
-        for _ in range(depth):
             attn = SegmentedAttention(
                 dim = dim,
                 dim_head = dim_head,
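The selection rule above is easy to misread because the layer indices are 1-based. A standalone sketch of its behaviour, assuming the repo's default helper returns its second argument when the first is None:

    def default(v, d):
        return v if v is not None else d   # assumed semantics of the repo's helper

    depth = 4
    layers = tuple(range(1, depth + 1))    # layer indices are 1-based: (1, 2, 3, 4)

    assert set(default(None, layers)) == {1, 2, 3, 4}   # None -> memory on every layer
    assert set(default((2, 4), layers)) == {2, 4}       # explicit tuple -> only those layers

Even for a selected layer, the memory module is only created when num_longterm_mem_tokens > 0, since it doubles as the NeuralMemory chunk size.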
@@ -221,7 +244,7 @@ class MemoryAsContextTransformer(Module):
         x, inverse_segment = pad_and_segment_with_inverse(x, segment_len)
 
         mems = repeat(self.longterm_mems, 'n d -> b n d', b = x.shape[0])
-        x = torch.cat((mems, x), dim = -2)
+        x = cat((mems, x), dim = -2)
 
         x = inverse_segment(x)
 
@@ -235,7 +258,24 @@ class MemoryAsContextTransformer(Module):
 
         x = self.expand_streams(x)
 
-        for attn, ff in self.layers:
+        for (attn, ff), maybe_neural_mem in zip(self.layers, self.neural_mem_layers):
+
+            if exists(maybe_neural_mem):
+                batch_streams = x.shape[0]
+                x, inverse_segment = pad_and_segment_with_inverse(x, total_segment_len)
+
+                longterm_mems, x = x[:, :num_longterm_mem_tokens], x[:, num_longterm_mem_tokens:]
+
+                longterm_mems = rearrange(longterm_mems, '(b w) n d -> b (w n) d', b = batch_streams)
+
+                longterm_mems = maybe_neural_mem(longterm_mems)
+
+                longterm_mems = rearrange(longterm_mems, 'b (w n) d -> (b w) n d', n = num_longterm_mem_tokens)
+
+                x = cat((longterm_mems, x), dim = -2)
+
+                x = inverse_segment(x)
+
             x = attn(x)
             x = ff(x)
 
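The pair of rearrange calls in this hunk is a round-trip: the memory tokens from every window of a batch stream are flattened into one contiguous sequence so the neural memory sees them in order, then restored to the segmented layout before being concatenated back onto the window. A self-contained sketch of just that reshape, with made-up sizes:

    import torch
    from einops import rearrange

    b, w, n, d = 2, 3, 4, 8                  # batch*streams, windows, mem tokens per window, dim
    tokens = torch.randn(b * w, n, d)        # segmented layout: (b w) n d

    # gather every window's memory tokens into one sequence per batch stream ...
    flat = rearrange(tokens, '(b w) n d -> b (w n) d', b = b)   # -> (2, 12, 8)

    # ... and split them back into per-window blocks after the memory pass
    back = rearrange(flat, 'b (w n) d -> (b w) n d', n = n)

    assert torch.equal(tokens, back)         # lossless round-trip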
@@ -245,7 +285,7 @@ class MemoryAsContextTransformer(Module):
 
         x, inverse_segment = pad_and_segment_with_inverse(x, total_segment_len)
 
-        x = x[:, self.num_longterm_mem_tokens:]
+        x = x[:, num_longterm_mem_tokens:]
 
         x = inverse_segment(x)
 
--- titans_pytorch-0.0.31.dist-info/METADATA
+++ titans_pytorch-0.0.32.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.0.31
+Version: 0.0.32
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
--- titans_pytorch-0.0.31.dist-info/RECORD
+++ titans_pytorch-0.0.32.dist-info/RECORD
@@ -1,9 +1,9 @@
 titans_pytorch/__init__.py,sha256=nB0873FZ_OyCda3qFeWTdpO4LbbrXDEvtAefVLzh6g0,71
 titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/mac_transformer.py,sha256=bq5RbCgA0GWLFHTrDTIKUSQhkkuCkdjEykOwjfHDs0M,6747
+titans_pytorch/mac_transformer.py,sha256=ohZWVhMBtpm0Iz3w5g7pD3WXSXrvhwzZvfRplwhe1Qo,8149
 titans_pytorch/titans.py,sha256=Kx_tl_QkeDccvkMwPZ0xQ_saYjZfbKzDNPTTSHNWYcA,14304
 titans_pytorch/titans_attn_memory.py,sha256=Rwx_-riGeISBefZg5Kjic8jzmmRRys-u93D2Kgb7Mos,12691
-titans_pytorch-0.0.31.dist-info/METADATA,sha256=bN1fVL2S_vML1oqLIA92tvBhkVvnpQN11fU4e1QVI4s,3938
-titans_pytorch-0.0.31.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.0.31.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.0.31.dist-info/RECORD,,
+titans_pytorch-0.0.32.dist-info/METADATA,sha256=9X9nWfgIVS-9XIeLHQY53HXSMA6rMemPfyVC2bRrJOQ,3938
+titans_pytorch-0.0.32.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.0.32.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.0.32.dist-info/RECORD,,