titans-pytorch 0.2.17__tar.gz → 0.2.18__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.2.17
+Version: 0.2.18
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -1,6 +1,6 @@
 [project]
 name = "titans-pytorch"
-version = "0.2.17"
+version = "0.2.18"
 description = "Titans"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -149,6 +149,7 @@ def test_neural_mem_chaining_with_batch_size():
 @pytest.mark.parametrize('num_longterm_mem_tokens', (0, 16))
 @pytest.mark.parametrize('neural_mem_gate_attn_output', (False, True))
 @pytest.mark.parametrize('neural_mem_segment_len', (8, 16))
+@pytest.mark.parametrize('neural_mem_weight_residual', (False, True))
 @pytest.mark.parametrize('neural_mem_batch_size', (None, 64))
 def test_mac(
     seq_len,
@@ -156,7 +157,8 @@ def test_mac(
     num_longterm_mem_tokens,
     neural_mem_gate_attn_output,
     neural_mem_segment_len,
-    neural_mem_batch_size
+    neural_mem_weight_residual,
+    neural_mem_batch_size,
 ):
     transformer = MemoryAsContextTransformer(
         num_tokens = 256,
@@ -168,6 +170,7 @@ def test_mac(
         neural_mem_gate_attn_output = neural_mem_gate_attn_output,
         neural_memory_segment_len = neural_mem_segment_len,
         neural_memory_batch_size = neural_mem_batch_size,
+        neural_mem_weight_residual = neural_mem_weight_residual
     )
 
     x = torch.randint(0, 256, (1, seq_len))
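
For context, a minimal sketch of exercising the new neural_mem_weight_residual flag outside the test suite (not part of the diff; constructor arguments other than the ones visible above are assumed from the library and may differ between versions):

import torch
from titans_pytorch import MemoryAsContextTransformer

# illustrative values only
transformer = MemoryAsContextTransformer(
    num_tokens = 256,
    dim = 64,
    depth = 2,
    segment_len = 32,
    num_longterm_mem_tokens = 16,
    neural_mem_weight_residual = True   # behaviour introduced in 0.2.18
)

x = torch.randint(0, 256, (1, 1023))
logits = transformer(x)   # logits over the 256-token vocabulary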
@@ -491,7 +491,7 @@ class MemoryAsContextTransformer(Module):
         neural_memory_layers: tuple[int, ...] | None = None,
         use_flex_attn = False,
         sliding_window_attn = False,
-        neural_mem_weight_residual = False
+        neural_mem_weight_residual = False,
     ):
         super().__init__()
 
@@ -525,7 +525,10 @@ class MemoryAsContextTransformer(Module):
 
         neural_memory_layers = default(neural_memory_layers, layers)
 
+        # weight residual related
+
         self.neural_mem_weight_residual = neural_mem_weight_residual
+        is_first_neural_mem = True
 
         # mem, attn, and feedforward layers
 
@@ -557,9 +560,12 @@ class MemoryAsContextTransformer(Module):
                     chunk_size = self.neural_memory_segment_len,
                     batch_size = neural_memory_batch_size,
                     model = deepcopy(neural_memory_model),
+                    learned_weight_residual = neural_mem_weight_residual and not is_first_neural_mem,
                     **neural_memory_kwargs
                 )
 
+                is_first_neural_mem = False
+
             ff = FeedForward(dim = dim, mult = ff_mult)
 
             self.layers.append(ModuleList([
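
Only neural memory layers after the first receive learned_weight_residual = True, since the first layer has no earlier memory weights to mix in. A stripped-down sketch of that wiring (not part of the diff; dim and chunk_size values are placeholders):

from torch.nn import ModuleList
from titans_pytorch import NeuralMemory

depth = 4
neural_mem_weight_residual = True

mem_layers = ModuleList([])
is_first_neural_mem = True

for _ in range(depth):
    mem = NeuralMemory(
        dim = 64,
        chunk_size = 8,
        # the first memory layer never gets the residual mix
        learned_weight_residual = neural_mem_weight_residual and not is_first_neural_mem
    )
    is_first_neural_mem = False
    mem_layers.append(mem)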
@@ -296,6 +296,7 @@ class NeuralMemory(Module):
         init_adaptive_step_bias = None,
         init_momentum_bias = None,
         init_decay_bias = None,
+        learned_weight_residual = False,
         default_model_kwargs: dict = dict(
             depth = 2
         )
@@ -438,6 +439,14 @@ class NeuralMemory(Module):
 
         self.max_mem_layer_modulation = max_mem_layer_modulation
 
+        # learned weight residual
+
+        self.to_learned_weight_residual_mix = Sequential(
+            nn.Linear(dim, heads),
+            Rearrange('b n h -> b h n'),
+            nn.Sigmoid()
+        ) if learned_weight_residual else None
+
         # allow for softclamp the gradient norms for storing memories
 
         self.max_grad_norm = max_grad_norm
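
The new to_learned_weight_residual_mix module produces one sigmoid gate per head for every token, which later scales how much of the previous layer's memory weights is carried over. A shape-only sketch of the same stack built from plain torch modules (not part of the diff; dim and heads are arbitrary):

import torch
from torch import nn
from einops.layers.torch import Rearrange

dim, heads = 64, 4

to_mix = nn.Sequential(
    nn.Linear(dim, heads),        # one logit per head for each token
    Rearrange('b n h -> b h n'),  # move heads in front of the sequence axis
    nn.Sigmoid()                  # squash to gates in (0, 1)
)

tokens = torch.randn(2, 8, dim)   # (batch, seq, dim)
mix = to_mix(tokens)              # (batch, heads, seq)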
@@ -563,6 +572,8 @@ class NeuralMemory(Module):
 
         # maybe add previous layer weight
 
+        assert xnor(exists(self.to_learned_weight_residual_mix), exists(prev_weights))
+
         if exists(prev_weights):
 
             start_index = math.ceil(seq_index / chunk_size)
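
The assertion enforces that the mix projection and the incoming prev_weights appear together or not at all. Assuming xnor is the usual both-or-neither boolean helper, its behaviour is:

# assumed definition of the xnor helper referenced above
def xnor(x, y):
    return not (x ^ y)  # True when both inputs are True or both are False

assert xnor(True, True) and xnor(False, False)
assert not xnor(True, False)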
@@ -570,6 +581,12 @@ class NeuralMemory(Module):
 
             prev_weights = prev_weights.apply(lambda t: t[:, start_index:end_index])
 
+            if exists(self.to_learned_weight_residual_mix):
+                mix = self.to_learned_weight_residual_mix(chunked_seq)
+                mix = rearrange(mix, 'b h n -> (b h) n')
+
+                prev_weights = prev_weights.apply(lambda t: einx.multiply('bh n, bh n ... -> bh n ...', mix, t))
+
             weights_for_surprise = weights_for_surprise + prev_weights
 
         # flatten batch and time if surprise depends on previous layer memory model
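
The einx.multiply('bh n, bh n ... -> bh n ...', mix, t) call scales each previous-weight tensor by the per-(batch·head, chunk) gate, broadcasting over the trailing weight dimensions. A rough torch-only equivalent (not part of the diff; shapes are made up):

import torch

bh, n = 8, 4                                # flattened batch * heads, chunks
mix = torch.rand(bh, n)                     # sigmoid gates
prev_weight = torch.randn(bh, n, 16, 16)    # one tensor of previous memory weights

# unsqueeze the gate so it broadcasts over the trailing weight dims
gate = mix.reshape(bh, n, *((1,) * (prev_weight.ndim - 2)))
mixed = gate * prev_weight                  # same result as the einx.multiply call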