titans-pytorch 0.2.16__tar.gz → 0.2.18__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.2.16
+Version: 0.2.18
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -1,6 +1,6 @@
 [project]
 name = "titans-pytorch"
-version = "0.2.16"
+version = "0.2.18"
 description = "Titans"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -93,6 +93,34 @@ def test_neural_mem_chaining_chunks():

     assert torch.allclose(parallel_retrieved, torch.cat((first_retrieved, second_retrieved, third_retrieved), dim = 1), atol = 1e-5)

+def test_neural_mem_chaining_with_weight_residual():
+    mem = NeuralMemory(
+        dim = 384,
+        dim_head = 64,
+        heads = 2,
+        chunk_size = 64
+    )
+
+    mem2 = NeuralMemory(
+        dim = 384,
+        dim_head = 64,
+        heads = 2,
+        chunk_size = 64
+    )
+
+    seq = torch.randn(2, 256, 384)
+
+    seq, state = mem(seq)
+
+    parallel_retrieved, _ = mem2(seq, prev_weights = state.updates)
+
+    seq_first, seq_second = seq[:, :128], seq[:, 128:]
+
+    first_retrieved, state1 = mem2(seq_first, prev_weights = state.updates)
+    second_retrieved, state2 = mem2(seq_second, state = state1, prev_weights = state.updates)
+
+    assert torch.allclose(parallel_retrieved, torch.cat((first_retrieved, second_retrieved), dim = 1), atol = 1e-6)
+
 def test_neural_mem_chaining_with_batch_size():
     mem = NeuralMemory(
         dim = 384,
@@ -121,6 +149,7 @@ def test_neural_mem_chaining_with_batch_size():
 @pytest.mark.parametrize('num_longterm_mem_tokens', (0, 16))
 @pytest.mark.parametrize('neural_mem_gate_attn_output', (False, True))
 @pytest.mark.parametrize('neural_mem_segment_len', (8, 16))
+@pytest.mark.parametrize('neural_mem_weight_residual', (False, True))
 @pytest.mark.parametrize('neural_mem_batch_size', (None, 64))
 def test_mac(
     seq_len,
@@ -128,7 +157,8 @@ def test_mac(
     num_longterm_mem_tokens,
     neural_mem_gate_attn_output,
     neural_mem_segment_len,
-    neural_mem_batch_size
+    neural_mem_weight_residual,
+    neural_mem_batch_size,
 ):
     transformer = MemoryAsContextTransformer(
         num_tokens = 256,
@@ -140,6 +170,7 @@ def test_mac(
         neural_mem_gate_attn_output = neural_mem_gate_attn_output,
         neural_memory_segment_len = neural_mem_segment_len,
         neural_memory_batch_size = neural_mem_batch_size,
+        neural_mem_weight_residual = neural_mem_weight_residual
     )

     x = torch.randint(0, 256, (1, seq_len))
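
The test hunk above chains two NeuralMemory instances: the first layer's accumulated per-chunk weight updates (state.updates) are handed to the second as prev_weights, and the chunked chained calls must match the full parallel pass. Below is a minimal standalone sketch based only on the calls visible in the test; enabling learned_weight_residual on the receiving memory (the kwarg added in this release) turns on the sigmoid mix introduced further down in neural_memory.py.

    import torch
    from titans_pytorch import NeuralMemory

    # first memory layer, mirroring the constructor arguments in the test above
    mem1 = NeuralMemory(dim = 384, dim_head = 64, heads = 2, chunk_size = 64)

    # second memory layer, with the learned mix for the incoming weight residual (new in 0.2.18)
    mem2 = NeuralMemory(dim = 384, dim_head = 64, heads = 2, chunk_size = 64, learned_weight_residual = True)

    seq = torch.randn(2, 256, 384)

    # the first layer returns its retrieved sequence plus a state whose .updates holds the weight updates
    retrieved, state = mem1(seq)

    # the second layer treats the previous layer's updates as a residual on its own memory weights
    retrieved2, _ = mem2(retrieved, prev_weights = state.updates)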
@@ -491,7 +491,7 @@ class MemoryAsContextTransformer(Module):
         neural_memory_layers: tuple[int, ...] | None = None,
         use_flex_attn = False,
         sliding_window_attn = False,
-        neural_mem_weight_residual = False
+        neural_mem_weight_residual = False,
     ):
         super().__init__()

@@ -525,7 +525,10 @@

         neural_memory_layers = default(neural_memory_layers, layers)

+        # weight residual related
+
         self.neural_mem_weight_residual = neural_mem_weight_residual
+        is_first_neural_mem = True

         # mem, attn, and feedforward layers

@@ -557,9 +560,12 @@
                    chunk_size = self.neural_memory_segment_len,
                    batch_size = neural_memory_batch_size,
                    model = deepcopy(neural_memory_model),
+                   learned_weight_residual = neural_mem_weight_residual and not is_first_neural_mem,
                    **neural_memory_kwargs
                )

+               is_first_neural_mem = False
+
             ff = FeedForward(dim = dim, mult = ff_mult)

             self.layers.append(ModuleList([
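
At the transformer level the new flag simply threads down: when neural_mem_weight_residual is set, every neural memory after the first is constructed with learned_weight_residual = True, so it can mix in the preceding memory layer's weight updates. A hedged usage sketch follows; num_tokens and the neural_mem_weight_residual flag come from this diff, while dim, depth, segment_len and the persistent/long-term memory token counts are illustrative values assumed from the test setup rather than taken verbatim from it.

    import torch
    from titans_pytorch import MemoryAsContextTransformer

    transformer = MemoryAsContextTransformer(
        num_tokens = 256,                   # vocab size, as in test_mac above
        dim = 384,                          # illustrative
        depth = 4,                          # illustrative
        segment_len = 32,                   # illustrative attention window
        num_persist_mem_tokens = 4,         # illustrative
        num_longterm_mem_tokens = 16,       # illustrative
        neural_mem_weight_residual = True   # option added in this release
    )

    ids = torch.randint(0, 256, (1, 512))
    logits = transformer(ids)               # token ids in, logits out, as exercised by test_mac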
@@ -296,6 +296,7 @@ class NeuralMemory(Module):
         init_adaptive_step_bias = None,
         init_momentum_bias = None,
         init_decay_bias = None,
+        learned_weight_residual = False,
         default_model_kwargs: dict = dict(
             depth = 2
         )
@@ -438,6 +439,14 @@

         self.max_mem_layer_modulation = max_mem_layer_modulation

+        # learned weight residual
+
+        self.to_learned_weight_residual_mix = Sequential(
+            nn.Linear(dim, heads),
+            Rearrange('b n h -> b h n'),
+            nn.Sigmoid()
+        ) if learned_weight_residual else None
+
         # allow for softclamp the gradient norms for storing memories

         self.max_grad_norm = max_grad_norm
@@ -563,8 +572,21 @@

         # maybe add previous layer weight

+        assert xnor(exists(self.to_learned_weight_residual_mix), exists(prev_weights))
+
         if exists(prev_weights):
-            prev_weights = prev_weights.apply(lambda t: t[:, -1:])
+
+            start_index = math.ceil(seq_index / chunk_size)
+            end_index = start_index + num_chunks
+
+            prev_weights = prev_weights.apply(lambda t: t[:, start_index:end_index])
+
+            if exists(self.to_learned_weight_residual_mix):
+                mix = self.to_learned_weight_residual_mix(chunked_seq)
+                mix = rearrange(mix, 'b h n -> (b h) n')
+
+                prev_weights = prev_weights.apply(lambda t: einx.multiply('bh n, bh n ... -> bh n ...', mix, t))
+
             weights_for_surprise = weights_for_surprise + prev_weights

         # flatten batch and time if surprise depends on previous layer memory model
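
The learned residual mix itself is a small gate: Linear(dim, heads) over the pooled per-chunk representation, rearranged to (batch, heads, chunks) and passed through a sigmoid, giving one scalar in [0, 1] per head and per chunk that scales the previous layer's weight slice before it is added to weights_for_surprise. Here is a standalone shape sketch of that gating step; the pooled chunk tensor and the 64 x 64 weight shape are illustrative stand-ins, not the library's internals verbatim.

    import torch
    from torch import nn
    from einops import rearrange
    from einops.layers.torch import Rearrange

    dim, heads, chunks = 384, 2, 4

    chunked_seq = torch.randn(1, chunks, dim)              # one pooled token per chunk (illustrative)
    prev_weights = torch.randn(1 * heads, chunks, 64, 64)  # hypothetical per-chunk weight from the previous memory

    to_mix = nn.Sequential(
        nn.Linear(dim, heads),
        Rearrange('b n h -> b h n'),
        nn.Sigmoid()
    )

    mix = to_mix(chunked_seq)                              # (b, h, n), values in [0, 1]
    mix = rearrange(mix, 'b h n -> (b h) n')               # merge batch and heads to match prev_weights' leading dim

    # equivalent to einx.multiply('bh n, bh n ... -> bh n ...', mix, prev_weights)
    gated = prev_weights * rearrange(mix, 'bh n -> bh n 1 1')

    print(gated.shape)  # torch.Size([2, 4, 64, 64])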
@@ -38,7 +38,7 @@ NEURAL_MEM_MAX_LR = 1e-1
 WINDOW_SIZE = 32
 NEURAL_MEM_SEGMENT_LEN = 4 # set smaller for more granularity for learning rate / momentum etc
 NEURAL_MEM_BATCH_SIZE = 128 # set smaller to update the neural memory weights more often as it traverses the sequence
-NEURAL_MEM_WEIGHT_RESIDUAL = True
+NEURAL_MEM_WEIGHT_RESIDUAL = False
 SLIDING_WINDOWS = True
 STORE_ATTN_POOL_CHUNKS = True # whether to use attention pooling for chunk derived momentum, per-layer lr mod, decay
 MEMORY_MODEL_PER_LAYER_LEARNED_LR = True