titans-pytorch 0.3.0__tar.gz → 0.3.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: titans-pytorch
3
- Version: 0.3.0
3
+ Version: 0.3.1
4
4
  Summary: Titans
5
5
  Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
6
6
  Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "titans-pytorch"
3
- version = "0.3.0"
3
+ version = "0.3.1"
4
4
  description = "Titans"
5
5
  authors = [
6
6
  { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -156,6 +156,7 @@ def test_neural_mem_chaining_with_batch_size():
156
156
  @pytest.mark.parametrize('neural_mem_segment_len', (8, 16))
157
157
  @pytest.mark.parametrize('neural_mem_weight_residual', (False, True))
158
158
  @pytest.mark.parametrize('neural_mem_batch_size', (None, 64))
159
+ @pytest.mark.parametrize('neural_mem_momentum', (False, True))
159
160
  def test_mac(
160
161
  seq_len,
161
162
  num_persist_mem_tokens,
@@ -164,6 +165,7 @@ def test_mac(
164
165
  neural_mem_segment_len,
165
166
  neural_mem_weight_residual,
166
167
  neural_mem_batch_size,
168
+ neural_mem_momentum
167
169
  ):
168
170
  transformer = MemoryAsContextTransformer(
169
171
  num_tokens = 256,
@@ -175,7 +177,10 @@ def test_mac(
175
177
  neural_mem_gate_attn_output = neural_mem_gate_attn_output,
176
178
  neural_memory_segment_len = neural_mem_segment_len,
177
179
  neural_memory_batch_size = neural_mem_batch_size,
178
- neural_mem_weight_residual = neural_mem_weight_residual
180
+ neural_mem_weight_residual = neural_mem_weight_residual,
181
+ neural_memory_kwargs = dict(
182
+ momentum = neural_mem_momentum
183
+ )
179
184
  )
180
185
 
181
186
  x = torch.randint(0, 256, (1, seq_len))
@@ -652,13 +652,14 @@ class NeuralMemory(Module):
652
652
  next_last_update = TensorDict()
653
653
  next_last_momentum = TensorDict()
654
654
 
655
- for (param_name, surprise), (_, last_update), (_, last_momentum) in zip(surprises.items(), past_last_update.items(), past_last_momentum.items()):
655
+ for (param_name, surprise), (_, last_update) in zip(surprises.items(), past_last_update.items()):
656
656
 
657
657
  update = surprise
658
658
 
659
659
  # derive momentum with associative scan - eq (10)
660
660
 
661
661
  if has_momentum:
662
+ last_momentum = past_last_momentum[param_name]
662
663
  update = self.assoc_scan(adaptive_momentum, surprise, prev = last_momentum) # momentum is S / surprise in the paper
663
664
  momentum = update
664
665
  next_last_momentum[param_name] = momentum[:, -1]
File without changes
File without changes
File without changes
File without changes