titans-pytorch 0.2.16__tar.gz → 0.2.17__tar.gz

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.2.16
+Version: 0.2.17
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -1,6 +1,6 @@
 [project]
 name = "titans-pytorch"
-version = "0.2.16"
+version = "0.2.17"
 description = "Titans"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -93,6 +93,34 @@ def test_neural_mem_chaining_chunks():
 
     assert torch.allclose(parallel_retrieved, torch.cat((first_retrieved, second_retrieved, third_retrieved), dim = 1), atol = 1e-5)
 
+def test_neural_mem_chaining_with_weight_residual():
+    mem = NeuralMemory(
+        dim = 384,
+        dim_head = 64,
+        heads = 2,
+        chunk_size = 64
+    )
+
+    mem2 = NeuralMemory(
+        dim = 384,
+        dim_head = 64,
+        heads = 2,
+        chunk_size = 64
+    )
+
+    seq = torch.randn(2, 256, 384)
+
+    seq, state = mem(seq)
+
+    parallel_retrieved, _ = mem2(seq, prev_weights = state.updates)
+
+    seq_first, seq_second = seq[:, :128], seq[:, 128:]
+
+    first_retrieved, state1 = mem2(seq_first, prev_weights = state.updates)
+    second_retrieved, state2 = mem2(seq_second, state = state1, prev_weights = state.updates)
+
+    assert torch.allclose(parallel_retrieved, torch.cat((first_retrieved, second_retrieved), dim = 1), atol = 1e-6)
+
 def test_neural_mem_chaining_with_batch_size():
     mem = NeuralMemory(
         dim = 384,
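The added test checks that the previous-layer weight residual path is chunk-consistent: a second NeuralMemory fed the first memory's weight updates (state.updates) returns the same output whether the sequence is processed in one parallel call or in two chunked calls that carry the returned state. Below is a minimal sketch of the layer-to-layer pattern this enables, with illustrative dimensions and an illustrative residual on the retrieved output (neither taken from this diff):

import torch
from titans_pytorch import NeuralMemory

# three stacked neural memories; each one receives the previous layer's weight updates
layers = [NeuralMemory(dim = 384, dim_head = 64, heads = 2, chunk_size = 64) for _ in range(3)]

x = torch.randn(2, 256, 384)

prev_updates = None

for mem in layers:
    retrieved, state = mem(x, prev_weights = prev_updates)
    x = x + retrieved              # illustrative residual on the retrieved memories
    prev_updates = state.updates   # chained into the next layer, as in the test above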
@@ -564,7 +564,12 @@ class NeuralMemory(Module):
         # maybe add previous layer weight
 
         if exists(prev_weights):
-            prev_weights = prev_weights.apply(lambda t: t[:, -1:])
+
+            start_index = math.ceil(seq_index / chunk_size)
+            end_index = start_index + num_chunks
+
+            prev_weights = prev_weights.apply(lambda t: t[:, start_index:end_index])
+
             weights_for_surprise = weights_for_surprise + prev_weights
 
         # flatten batch and time if surprise depends on previous layer memory model
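The hunk above replaces the old behaviour of always keeping only the last chunk of the previous layer's weight updates (t[:, -1:]) with a slice aligned to the caller's position in the sequence, so chained chunked calls see the same per-chunk weights as a single parallel call. A small sketch of that index arithmetic, assuming seq_index counts tokens already consumed by earlier calls and the sequence lengths divide evenly into chunks:

import math

chunk_size = 64

def prev_weight_slice(seq_index, seq_len):
    # which chunks of the previous layer's weight updates this call should see
    num_chunks = seq_len // chunk_size
    start_index = math.ceil(seq_index / chunk_size)
    end_index = start_index + num_chunks
    return start_index, end_index

print(prev_weight_slice(0, 256))    # (0, 4)  one parallel call over 256 tokens
print(prev_weight_slice(0, 128))    # (0, 2)  first chained call
print(prev_weight_slice(128, 128))  # (2, 4)  second chained call lines up with the parallel slice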
@@ -38,7 +38,7 @@ NEURAL_MEM_MAX_LR = 1e-1
 WINDOW_SIZE = 32
 NEURAL_MEM_SEGMENT_LEN = 4 # set smaller for more granularity for learning rate / momentum etc
 NEURAL_MEM_BATCH_SIZE = 128 # set smaller to update the neural memory weights more often as it traverses the sequence
-NEURAL_MEM_WEIGHT_RESIDUAL = True
+NEURAL_MEM_WEIGHT_RESIDUAL = False
 SLIDING_WINDOWS = True
 STORE_ATTN_POOL_CHUNKS = True # whether to use attention pooling for chunk derived momentum, per-layer lr mod, decay
 MEMORY_MODEL_PER_LAYER_LEARNED_LR = True
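For reference, these constants configure the MAC training script. A hedged sketch of how the changed flag might be threaded into the model constructor follows; the keyword names (segment_len, neural_memory_segment_len, neural_memory_batch_size, neural_mem_weight_residual, sliding_window_attn) are assumptions for illustration rather than values read from this diff, so check them against the installed version:

# Hypothetical wiring of the training constants into the model constructor.
# All keyword names below are assumptions for illustration only.
from titans_pytorch import MemoryAsContextTransformer

WINDOW_SIZE = 32
NEURAL_MEM_SEGMENT_LEN = 4
NEURAL_MEM_BATCH_SIZE = 128
NEURAL_MEM_WEIGHT_RESIDUAL = False
SLIDING_WINDOWS = True

model = MemoryAsContextTransformer(
    num_tokens = 256,
    dim = 384,
    depth = 8,
    segment_len = WINDOW_SIZE,                                # attention window size
    neural_memory_segment_len = NEURAL_MEM_SEGMENT_LEN,       # assumed kwarg
    neural_memory_batch_size = NEURAL_MEM_BATCH_SIZE,         # assumed kwarg
    neural_mem_weight_residual = NEURAL_MEM_WEIGHT_RESIDUAL,  # assumed kwarg for the flag flipped above
    sliding_window_attn = SLIDING_WINDOWS,                    # assumed kwarg
)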