titans-pytorch: titans_pytorch-0.0.36-py3-none-any.whl → titans_pytorch-0.0.37-py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -7,7 +7,7 @@ from torch import nn, cat
7
7
  import torch.nn.functional as F
8
8
  from torch.nn import Module, ModuleList, Linear
9
9
 
10
- from einops import repeat, rearrange
10
+ from einops import repeat, rearrange, pack, unpack
11
11
  from einops.layers.torch import Rearrange
12
12
 
13
13
  from hyper_connections import get_init_and_expand_reduce_stream_functions
@@ -214,7 +214,12 @@ class MemoryAsContextTransformer(Module):
214
214
  if layer in neural_memory_layers:
215
215
  assert has_longterm_mems, '`num_longterm_mem_tokens` must be greater than 0'
216
216
 
217
- mem = NeuralMemory(dim = dim, chunk_size = num_longterm_mem_tokens)
217
+ mem = NeuralMemory(
218
+ dim = dim,
219
+ chunk_size = num_longterm_mem_tokens + segment_len,
220
+ **neural_memory_kwargs
221
+ )
222
+
218
223
  mem = init_hyper_conn(dim = dim, branch = mem)
219
224
 
220
225
  self.neural_mem_layers.append(mem)
@@ -266,7 +271,7 @@ class MemoryAsContextTransformer(Module):
266
271
  x, inverse_segment = pad_and_segment_with_inverse(x, segment_len)
267
272
 
268
273
  mems = repeat(self.longterm_mems, 'n d -> b n d', b = x.shape[0])
269
- x = cat((mems, x), dim = -2)
274
+ x, mem_ps = pack((x, mems), 'b * d')
270
275
 
271
276
  x = inverse_segment(x)
272
277
 
@@ -283,21 +288,7 @@ class MemoryAsContextTransformer(Module):
283
288
  for (attn, ff), maybe_neural_mem in zip(self.layers, self.neural_mem_layers):
284
289
 
285
290
  if exists(maybe_neural_mem):
286
- batch_streams = x.shape[0]
287
-
288
- x, inverse_segment = pad_and_segment_with_inverse(x, total_segment_len)
289
-
290
- longterm_mems, x = x[:, :num_longterm_mem_tokens], x[:, num_longterm_mem_tokens:]
291
-
292
- longterm_mems = rearrange(longterm_mems, '(b w) n d -> b (w n) d', b = batch_streams)
293
-
294
- longterm_mems = maybe_neural_mem(longterm_mems)
295
-
296
- longterm_mems = rearrange(longterm_mems, 'b (w n) d -> (b w) n d', n = num_longterm_mem_tokens)
297
-
298
- x = cat((longterm_mems, x), dim = -2)
299
-
300
- x = inverse_segment(x)
291
+ mems = maybe_neural_mem(mems)
301
292
 
302
293
  x = attn(x)
303
294
 
@@ -309,7 +300,7 @@ class MemoryAsContextTransformer(Module):
309
300
 
310
301
  x, inverse_segment = pad_and_segment_with_inverse(x, total_segment_len)
311
302
 
312
- x = x[:, num_longterm_mem_tokens:]
303
+ x, mem = unpack(x, mem_ps, 'b * d')
313
304
 
314
305
  x = inverse_segment(x)
315
306
 
titans_pytorch/titans.py CHANGED
@@ -132,7 +132,7 @@ class NeuralMemory(Module):
132
132
  max_grad_norm: float | None = None,
133
133
  use_accelerated_scan = False,
134
134
  default_mlp_kwargs: dict = dict(
135
- depth = 4
135
+ depth = 2
136
136
  )
137
137
  ):
138
138
  super().__init__()
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: titans-pytorch
3
- Version: 0.0.36
3
+ Version: 0.0.37
4
4
  Summary: Titans
5
5
  Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
6
6
  Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -0,0 +1,9 @@
1
+ titans_pytorch/__init__.py,sha256=F6pV8BamKCsbJFVo5x2hw69vzfJNLy54SwIKIueMdp4,142
2
+ titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
3
+ titans_pytorch/mac_transformer.py,sha256=JjKGEMBit_SvyAsxq5v08614YBcLVx3OkM6pf0rADsA,8400
4
+ titans_pytorch/titans.py,sha256=ALICGfc6p2bD2QkaibyIceVLvBIRKXmDm-w7RjnVOe4,14304
5
+ titans_pytorch/titans_attn_memory.py,sha256=Rwx_-riGeISBefZg5Kjic8jzmmRRys-u93D2Kgb7Mos,12691
6
+ titans_pytorch-0.0.37.dist-info/METADATA,sha256=RNokG8101_tlR0BiF-AxqYLZpXqafMSiN1Rg_pZe2-o,3938
7
+ titans_pytorch-0.0.37.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
8
+ titans_pytorch-0.0.37.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
9
+ titans_pytorch-0.0.37.dist-info/RECORD,,
@@ -1,9 +0,0 @@
1
- titans_pytorch/__init__.py,sha256=F6pV8BamKCsbJFVo5x2hw69vzfJNLy54SwIKIueMdp4,142
2
- titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
3
- titans_pytorch/mac_transformer.py,sha256=xXQ9GvtvNArYidV1OOhUhCJ0pIxniElTLnL0_eIZtEE,8821
4
- titans_pytorch/titans.py,sha256=Kx_tl_QkeDccvkMwPZ0xQ_saYjZfbKzDNPTTSHNWYcA,14304
5
- titans_pytorch/titans_attn_memory.py,sha256=Rwx_-riGeISBefZg5Kjic8jzmmRRys-u93D2Kgb7Mos,12691
6
- titans_pytorch-0.0.36.dist-info/METADATA,sha256=7Bum0wO6e6BsB7TShLBZWALyStcs1LLaEv5vvnVlQ9c,3938
7
- titans_pytorch-0.0.36.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
8
- titans_pytorch-0.0.36.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
9
- titans_pytorch-0.0.36.dist-info/RECORD,,