titans-pytorch 0.0.37__py3-none-any.whl → 0.0.38__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.

Potentially problematic release.


This version of titans-pytorch might be problematic. Click here for more details.

@@ -288,7 +288,8 @@ class MemoryAsContextTransformer(Module):
288
288
  for (attn, ff), maybe_neural_mem in zip(self.layers, self.neural_mem_layers):
289
289
 
290
290
  if exists(maybe_neural_mem):
291
- mems = maybe_neural_mem(mems)
291
+ x = maybe_neural_mem(x)
292
+
292
293
 
293
294
  x = attn(x)
294
295
 
@@ -300,7 +301,7 @@ class MemoryAsContextTransformer(Module):
300
301
 
301
302
  x, inverse_segment = pad_and_segment_with_inverse(x, total_segment_len)
302
303
 
303
- x, mem = unpack(x, mem_ps, 'b * d')
304
+ x, _ = unpack(x, mem_ps, 'b * d')
304
305
 
305
306
  x = inverse_segment(x)
306
307
 
titans_pytorch/titans.py CHANGED
@@ -27,9 +27,7 @@ n - sequence
27
27
  d - feature dimension
28
28
  c - intra-chunk
29
29
  """
30
-
31
- # constants
32
-
30
+ 7
33
31
  LinearNoBias = partial(Linear, bias = False)
34
32
 
35
33
  # functions
@@ -390,7 +388,10 @@ class NeuralMemory(Module):
390
388
 
391
389
  padding = next_seq_len - curtailed_seq_len
392
390
 
393
- seq = pad_at_dim(seq, (0, padding), dim = 1)
391
+ needs_pad = padding > 0
392
+
393
+ if needs_pad:
394
+ seq = pad_at_dim(seq, (0, padding), dim = 1)
394
395
 
395
396
  # the parameters of the memory model stores the memories of the key / values
396
397
  # when the MLP has only 1 weight matrix, it is equivalent to `kv` fast weight memories from linear attention literature (recall fetching of memories is q @ (kv)) / schmidhuber's paper
@@ -442,7 +443,9 @@ class NeuralMemory(Module):
442
443
  empty_memory_embeds = self.init_empty_memory_embed(values.shape[0], chunk_size - 1)
443
444
  values = torch.cat((empty_memory_embeds, values), dim = -2)
444
445
 
445
- values = values[:, :-padding]
446
+ if needs_pad:
447
+ values = values[:, :-padding]
448
+
446
449
  return values
447
450
 
448
451
  def forward(
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: titans-pytorch
3
- Version: 0.0.37
3
+ Version: 0.0.38
4
4
  Summary: Titans
5
5
  Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
6
6
  Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -0,0 +1,9 @@
1
+ titans_pytorch/__init__.py,sha256=F6pV8BamKCsbJFVo5x2hw69vzfJNLy54SwIKIueMdp4,142
2
+ titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
3
+ titans_pytorch/mac_transformer.py,sha256=5koIfEulJ841FNrs6URZfW2dp9LMuHzMkaySDrlbuP0,8393
4
+ titans_pytorch/titans.py,sha256=bv2Ceq-_4nNb5FNx4hLd2inC93m5MmJxO2-Mbf6PKK4,14378
5
+ titans_pytorch/titans_attn_memory.py,sha256=Rwx_-riGeISBefZg5Kjic8jzmmRRys-u93D2Kgb7Mos,12691
6
+ titans_pytorch-0.0.38.dist-info/METADATA,sha256=L6tEQTEOXCeAU_BuRLbwUO0-gmnbJE-WQNAZ83BNCWA,3938
7
+ titans_pytorch-0.0.38.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
8
+ titans_pytorch-0.0.38.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
9
+ titans_pytorch-0.0.38.dist-info/RECORD,,
@@ -1,9 +0,0 @@
1
- titans_pytorch/__init__.py,sha256=F6pV8BamKCsbJFVo5x2hw69vzfJNLy54SwIKIueMdp4,142
2
- titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
3
- titans_pytorch/mac_transformer.py,sha256=JjKGEMBit_SvyAsxq5v08614YBcLVx3OkM6pf0rADsA,8400
4
- titans_pytorch/titans.py,sha256=ALICGfc6p2bD2QkaibyIceVLvBIRKXmDm-w7RjnVOe4,14304
5
- titans_pytorch/titans_attn_memory.py,sha256=Rwx_-riGeISBefZg5Kjic8jzmmRRys-u93D2Kgb7Mos,12691
6
- titans_pytorch-0.0.37.dist-info/METADATA,sha256=RNokG8101_tlR0BiF-AxqYLZpXqafMSiN1Rg_pZe2-o,3938
7
- titans_pytorch-0.0.37.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
8
- titans_pytorch-0.0.37.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
9
- titans_pytorch-0.0.37.dist-info/RECORD,,