titans-pytorch 0.2.20__py3-none-any.whl → 0.2.22__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -303,6 +303,7 @@ class NeuralMemory(Module):
303
303
  init_momentum_bias = None,
304
304
  init_decay_bias = None,
305
305
  accept_weight_residual = False,
306
+ gated_transition = False,
306
307
  default_model_kwargs: dict = dict(
307
308
  depth = 2
308
309
  )
@@ -464,6 +465,11 @@ class NeuralMemory(Module):
464
465
  Rearrange('b n h -> (b h) n 1')
465
466
  )
466
467
 
468
+ # learned transition, as seeing instability when decreasing neural mem batch size
469
+ # perhaps it can slowly learn to adjust from early residual to fully transitioning to new weights every batch size
470
+
471
+ self.transition_gate = nn.Parameter(tensor(-5.)) if gated_transition else None
472
+
467
473
  # inits
468
474
 
469
475
  if exists(init_adaptive_step_bias):
@@ -587,7 +593,7 @@ class NeuralMemory(Module):
587
593
 
588
594
  prev_weights = prev_weights.apply(lambda t: t[:, start_index:end_index])
589
595
 
590
- if exists(self.to_learned_weight_residual_mix):
596
+ if exists(self.to_learned_weight_residual_mix) and num_chunks > 0:
591
597
  mix = self.to_learned_weight_residual_mix(chunked_seq)
592
598
  mix = rearrange(mix, 'b h n -> (b h) n')
593
599
  prev_weights = prev_weights.apply(lambda t: einx.multiply('bh n, bh n ... -> bh n ...', mix, t))
@@ -832,6 +838,13 @@ class NeuralMemory(Module):
832
838
 
833
839
  store_seqs = store_seq.split(split_sizes, dim = -2)
834
840
 
841
+ # whether to allow network to slowly adjust from initial weight throughout (residual path) to fully updating weights every batch
842
+
843
+ gate = None
844
+
845
+ if exists(self.transition_gate):
846
+ gate = self.transition_gate.sigmoid()
847
+
835
848
  for ind, store_seq_chunk in enumerate(store_seqs):
836
849
  is_last = ind == (len(store_seqs) - 1)
837
850
 
@@ -857,7 +870,10 @@ class NeuralMemory(Module):
857
870
 
858
871
  last_update, _ = past_state
859
872
 
860
- weights = last_update
873
+ if exists(gate):
874
+ weights = TensorDict({param_name: v1.lerp(v2, gate) for (param_name, v1), (_, v2) in zip(weights.items(), last_update.items())})
875
+ else:
876
+ weights = last_update
861
877
 
862
878
  next_neural_mem_state = tuple_index_set(next_neural_mem_state, 1, last_update)
863
879
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: titans-pytorch
3
- Version: 0.2.20
3
+ Version: 0.2.22
4
4
  Summary: Titans
5
5
  Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
6
6
  Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -2,8 +2,8 @@ titans_pytorch/__init__.py,sha256=Y3m_ZlpEqYwp-Md1ARhNGJxq8bQp8ty1o039nZOOJo0,27
2
2
  titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
3
3
  titans_pytorch/mac_transformer.py,sha256=5rO4GQxSyFWWEc3pc3xNyG0sK5EXE7MmxKI-_kEMl2M,24941
4
4
  titans_pytorch/memory_models.py,sha256=Q9SAIyAbStF5Tz0EhvRbn3yAdE3nk3xKc1ndieIe714,4671
5
- titans_pytorch/neural_memory.py,sha256=e1TENsK0IptpapFRtQux-Uii2MFvisoiQLzTQboHc50,26709
6
- titans_pytorch-0.2.20.dist-info/METADATA,sha256=8v5puM3CwbPaYDt2aZaqWruPj8whCt_moiQUVE1vk5U,6816
7
- titans_pytorch-0.2.20.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
8
- titans_pytorch-0.2.20.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
9
- titans_pytorch-0.2.20.dist-info/RECORD,,
5
+ titans_pytorch/neural_memory.py,sha256=1xKafkiJa5HxNGeD4Aw86Et_o-tePViBsiyKw-l7LR0,27511
6
+ titans_pytorch-0.2.22.dist-info/METADATA,sha256=G-iinXLii0iXtbc8-sPDZqTNBukHX4fy66CJ9l-h7hU,6816
7
+ titans_pytorch-0.2.22.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
8
+ titans_pytorch-0.2.22.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
9
+ titans_pytorch-0.2.22.dist-info/RECORD,,