titans-pytorch 0.2.21__tar.gz → 0.2.23__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: titans-pytorch
3
- Version: 0.2.21
3
+ Version: 0.2.23
4
4
  Summary: Titans
5
5
  Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
6
6
  Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "titans-pytorch"
3
- version = "0.2.21"
3
+ version = "0.2.23"
4
4
  description = "Titans"
5
5
  authors = [
6
6
  { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -303,6 +303,7 @@ class NeuralMemory(Module):
303
303
  init_momentum_bias = None,
304
304
  init_decay_bias = None,
305
305
  accept_weight_residual = False,
306
+ gated_transition = False,
306
307
  default_model_kwargs: dict = dict(
307
308
  depth = 2
308
309
  )
@@ -464,6 +465,11 @@ class NeuralMemory(Module):
464
465
  Rearrange('b n h -> (b h) n 1')
465
466
  )
466
467
 
468
+ # learned transition, as seeing instability when decreasing neural mem batch size
469
+ # perhaps it can slowly learn to adjust from early residual to fully transitioning to new weights every batch size
470
+
471
+ self.transition_gate = nn.Parameter(tensor(-5.)) if gated_transition else None
472
+
467
473
  # inits
468
474
 
469
475
  if exists(init_adaptive_step_bias):
@@ -832,6 +838,13 @@ class NeuralMemory(Module):
832
838
 
833
839
  store_seqs = store_seq.split(split_sizes, dim = -2)
834
840
 
841
+ # whether to allow network to slowly adjust from initial weight throughout (residual path) to fully updating weights every batch
842
+
843
+ gate = None
844
+
845
+ if exists(self.transition_gate):
846
+ gate = self.transition_gate.sigmoid()
847
+
835
848
  for ind, store_seq_chunk in enumerate(store_seqs):
836
849
  is_last = ind == (len(store_seqs) - 1)
837
850
 
@@ -855,9 +868,14 @@ class NeuralMemory(Module):
855
868
 
856
869
  # update weights once batch size is fulfilled
857
870
 
871
+ weights = next_neural_mem_state.weights
872
+
858
873
  last_update, _ = past_state
859
874
 
860
- weights = last_update
875
+ if exists(gate):
876
+ weights = TensorDict({param_name: v1.lerp(v2, gate) for (param_name, v1), (_, v2) in zip(weights.items(), last_update.items())})
877
+ else:
878
+ weights = last_update
861
879
 
862
880
  next_neural_mem_state = tuple_index_set(next_neural_mem_state, 1, last_update)
863
881
 
File without changes