titans-pytorch 0.3.11__py3-none-any.whl → 0.3.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -493,12 +493,20 @@ class MemoryAsContextTransformer(Module):
493
493
  use_flex_attn = False,
494
494
  sliding_window_attn = False,
495
495
  neural_mem_weight_residual = False,
496
+ token_emb: Module | None = None,
497
+ abs_pos_emb: Module | None = None
496
498
  ):
497
499
  super().__init__()
498
500
 
499
- self.token_emb = nn.Embedding(num_tokens, dim)
501
+ if not exists(token_emb):
502
+ token_emb = nn.Embedding(num_tokens, dim)
500
503
 
501
- self.axial_pos_emb = ContinuousAxialPositionalEmbedding(dim = dim, num_axial_dims = 2)
504
+ self.token_emb = token_emb
505
+
506
+ if not exists(abs_pos_emb):
507
+ abs_pos_emb = ContinuousAxialPositionalEmbedding(dim = dim, num_axial_dims = 2)
508
+
509
+ self.abs_pos_emb = abs_pos_emb
502
510
 
503
511
  # long term mem tokens
504
512
 
@@ -7,7 +7,7 @@ from itertools import zip_longest
7
7
  from collections import namedtuple
8
8
 
9
9
  import torch
10
- from torch import nn, cat, tensor, Tensor
10
+ from torch import nn, stack, cat, tensor, Tensor
11
11
  import torch.nn.functional as F
12
12
  from torch.nn import Linear, Module, Parameter, ParameterList, ParameterDict
13
13
  from torch.func import functional_call, vmap, grad
@@ -230,6 +230,7 @@ class NeuralMemory(Module):
230
230
  momentum = True,
231
231
  momentum_order = 1,
232
232
  learned_momentum_combine = False,
233
+ learned_combine_include_zeroth = False,
233
234
  pre_rmsnorm = True,
234
235
  post_rmsnorm = False,
235
236
  qk_rmsnorm = False,
@@ -399,12 +400,17 @@ class NeuralMemory(Module):
399
400
  assert momentum
400
401
  assert momentum_order > 1, 'only second order momentum allowed for now, but may allow learned combination of zeroth'
401
402
 
403
+ if learned_combine_include_zeroth:
404
+ momentum_order += 1
405
+
402
406
  self.to_learned_momentum_combine = Sequential(
403
407
  nn.Linear(dim, heads * momentum_order),
404
408
  nn.Softmax(dim = -1),
405
409
  Rearrange('b n (h o) -> o (b h) n', h = heads)
406
410
  )
407
411
 
412
+ self.learned_combine_include_zeroth = learned_combine_include_zeroth
413
+
408
414
  # per layer learning rate modulation
409
415
 
410
416
  self.to_layer_modulation = Sequential(
@@ -662,10 +668,14 @@ class NeuralMemory(Module):
662
668
 
663
669
  momentums.append(momentum)
664
670
 
665
- momentums = torch.stack(momentums)
671
+ momentums = stack(momentums)
666
672
 
667
673
  next_last_momentum[param_name] = momentums[:, :, -1] # momentums shape is Float['o bh n 1']
668
674
 
675
+ if learned_combine and self.learned_combine_include_zeroth:
676
+ # prepend the original surprise as the zeroth-order term, so the learned combination can weight it alongside the momentums
677
+ momentums = cat((rearrange(surprise, '... -> 1 ...'), momentums), dim = 0)
678
+
669
679
  if not learned_combine:
670
680
  update = momentums[-1]
671
681
  else:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: titans-pytorch
3
- Version: 0.3.11
3
+ Version: 0.3.14
4
4
  Summary: Titans
5
5
  Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
6
6
  Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -0,0 +1,9 @@
1
+ titans_pytorch/__init__.py,sha256=AyEUlcXWpnqrvyeihRAXWIfQlzLA4NhBjOqQU4edL-4,297
2
+ titans_pytorch/associative_scan.py,sha256=esaLbukFlgvy2aqopsqBy6KEcZ64B3rsNhG8moKdPSc,5159
3
+ titans_pytorch/mac_transformer.py,sha256=F04B88GaH0wHseUIWaX6VFhOSsk_3XDQ1E8e6pvqKgQ,25170
4
+ titans_pytorch/memory_models.py,sha256=wnH9i9kUSoVZhEWUlj8LpBSbB400L9kLt1zP8CO45QQ,5835
5
+ titans_pytorch/neural_memory.py,sha256=VmUAS1xOM0ZfearWIzQrX_P7HI69viuwrg9M7BQByeE,29349
6
+ titans_pytorch-0.3.14.dist-info/METADATA,sha256=1reoUZhvKaFPR6U0QXqJOziyss0HwHhwfJUf7oU7t-s,6817
7
+ titans_pytorch-0.3.14.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
8
+ titans_pytorch-0.3.14.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
9
+ titans_pytorch-0.3.14.dist-info/RECORD,,
@@ -1,9 +0,0 @@
1
- titans_pytorch/__init__.py,sha256=AyEUlcXWpnqrvyeihRAXWIfQlzLA4NhBjOqQU4edL-4,297
2
- titans_pytorch/associative_scan.py,sha256=esaLbukFlgvy2aqopsqBy6KEcZ64B3rsNhG8moKdPSc,5159
3
- titans_pytorch/mac_transformer.py,sha256=EyqA53HBqvAr4UNZUs37LR6IltyEfA7FKEV54YzVYlg,24945
4
- titans_pytorch/memory_models.py,sha256=wnH9i9kUSoVZhEWUlj8LpBSbB400L9kLt1zP8CO45QQ,5835
5
- titans_pytorch/neural_memory.py,sha256=7YglrQaDpKS2hbpBBwx7PmqhJdjyvFEPZDt_QXmnUMM,28878
6
- titans_pytorch-0.3.11.dist-info/METADATA,sha256=xAEvavDiCj__5Bl_5UXaG__BycdUB2DzHOud-nwsn1c,6817
7
- titans_pytorch-0.3.11.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
8
- titans_pytorch-0.3.11.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
9
- titans_pytorch-0.3.11.dist-info/RECORD,,