titans-pytorch 0.2.17.tar.gz → 0.2.19.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.2.17
+Version: 0.2.19
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -1,6 +1,6 @@
 [project]
 name = "titans-pytorch"
-version = "0.2.17"
+version = "0.2.19"
 description = "Titans"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -105,7 +105,8 @@ def test_neural_mem_chaining_with_weight_residual():
         dim = 384,
         dim_head = 64,
         heads = 2,
-        chunk_size = 64
+        chunk_size = 64,
+        accept_weight_residual = True
     )
 
     seq = torch.randn(2, 256, 384)
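The test change above exercises the new `accept_weight_residual` flag on `NeuralMemory`. Below is a minimal construction sketch, not taken from the package itself: the top-level import and the `(retrieved, state)` forward convention follow the project README, while the constructor arguments mirror the test. The chaining test then feeds the first module's stored weight updates into the second module, which is the `prev_weights` path modified in the `store_memories` hunks further down.

    import torch
    from titans_pytorch import NeuralMemory

    # first memory module: its stored weight updates can be chained into the next one
    mem1 = NeuralMemory(
        dim = 384,
        dim_head = 64,
        heads = 2,
        chunk_size = 64
    )

    # second memory module: accept_weight_residual = True instantiates the
    # to_learned_weight_residual_mix gate added later in this diff
    mem2 = NeuralMemory(
        dim = 384,
        dim_head = 64,
        heads = 2,
        chunk_size = 64,
        accept_weight_residual = True
    )

    seq = torch.randn(2, 256, 384)
    retrieved, state = mem1(seq)  # README-style forward call; state carries the updated weights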
@@ -149,6 +150,7 @@ def test_neural_mem_chaining_with_batch_size():
 @pytest.mark.parametrize('num_longterm_mem_tokens', (0, 16))
 @pytest.mark.parametrize('neural_mem_gate_attn_output', (False, True))
 @pytest.mark.parametrize('neural_mem_segment_len', (8, 16))
+@pytest.mark.parametrize('neural_mem_weight_residual', (False, True))
 @pytest.mark.parametrize('neural_mem_batch_size', (None, 64))
 def test_mac(
     seq_len,
@@ -156,7 +158,8 @@ def test_mac(
     num_longterm_mem_tokens,
     neural_mem_gate_attn_output,
     neural_mem_segment_len,
-    neural_mem_batch_size
+    neural_mem_weight_residual,
+    neural_mem_batch_size,
 ):
     transformer = MemoryAsContextTransformer(
         num_tokens = 256,
@@ -168,6 +171,7 @@ def test_mac(
         neural_mem_gate_attn_output = neural_mem_gate_attn_output,
         neural_memory_segment_len = neural_mem_segment_len,
         neural_memory_batch_size = neural_mem_batch_size,
+        neural_mem_weight_residual = neural_mem_weight_residual
     )
 
     x = torch.randint(0, 256, (1, seq_len))
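On the MAC transformer side, the same feature is exposed through the `neural_mem_weight_residual` constructor flag changed below. A rough usage sketch, assuming the README-style arguments (`dim`, `depth`, `segment_len`, `num_persist_mem_tokens`, `num_longterm_mem_tokens`) and the `return_loss` training call; only `num_tokens = 256` and the new flag come from this diff.

    import torch
    from titans_pytorch import MemoryAsContextTransformer

    transformer = MemoryAsContextTransformer(
        num_tokens = 256,
        dim = 256,
        depth = 2,
        segment_len = 32,
        num_persist_mem_tokens = 4,
        num_longterm_mem_tokens = 16,
        neural_mem_weight_residual = True   # new: later memory layers learn to mix in earlier layers' weights
    )

    token_ids = torch.randint(0, 256, (1, 1023))

    loss = transformer(token_ids, return_loss = True)
    loss.backward()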
@@ -491,7 +491,7 @@ class MemoryAsContextTransformer(Module):
         neural_memory_layers: tuple[int, ...] | None = None,
         use_flex_attn = False,
         sliding_window_attn = False,
-        neural_mem_weight_residual = False
+        neural_mem_weight_residual = False,
     ):
         super().__init__()
 
@@ -525,7 +525,10 @@ class MemoryAsContextTransformer(Module):
 
         neural_memory_layers = default(neural_memory_layers, layers)
 
+        # weight residual related
+
         self.neural_mem_weight_residual = neural_mem_weight_residual
+        is_first_neural_mem = True
 
         # mem, attn, and feedforward layers
 
@@ -557,9 +560,12 @@ class MemoryAsContextTransformer(Module):
                     chunk_size = self.neural_memory_segment_len,
                     batch_size = neural_memory_batch_size,
                     model = deepcopy(neural_memory_model),
+                    learned_weight_residual = neural_mem_weight_residual and not is_first_neural_mem,
                     **neural_memory_kwargs
                 )
 
+                is_first_neural_mem = False
+
             ff = FeedForward(dim = dim, mult = ff_mult)
 
             self.layers.append(ModuleList([
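Note the bookkeeping above: the weight residual is only enabled for neural-memory layers after the first one, since the first layer has no previous weights to borrow. A small illustration of how the flag evaluates (hypothetical layer indices, not the library's code):

    neural_mem_weight_residual = True

    is_first_neural_mem = True
    for layer in (1, 2, 4):  # hypothetical neural-memory layer indices
        learned_weight_residual = neural_mem_weight_residual and not is_first_neural_mem
        print(layer, learned_weight_residual)  # 1 False, 2 True, 4 True
        is_first_neural_mem = False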
@@ -66,6 +66,12 @@ def xnor(x, y):
 def divisible_by(num, den):
     return (num % den) == 0
 
+def tuple_index_set(t: tuple, index, value):
+    klass = type(t)
+    t = list(t)
+    t[index] = value
+    return klass(*t)
+
 def safe_cat(inputs, dim = -2):
     inputs = tuple(filter(exists, inputs))
 
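The new `tuple_index_set` helper returns a copy of a tuple subclass with one field replaced, which the last hunk in this diff uses on the memory cache. A self-contained sketch, with a hypothetical namedtuple standing in for the real `NeuralMemCache`:

    from collections import namedtuple

    def tuple_index_set(t: tuple, index, value):
        # works for any tuple subclass whose constructor takes its fields positionally
        klass = type(t)
        t = list(t)
        t[index] = value
        return klass(*t)

    # hypothetical cache type for illustration only; the real NeuralMemCache has different fields
    Cache = namedtuple('Cache', ['seq_index', 'weights', 'updates'])

    cache = Cache(seq_index = 0, weights = 'w0', updates = 'u0')
    cache = tuple_index_set(cache, 1, 'w1')    # replace a field by positive index
    cache = tuple_index_set(cache, -1, 'u1')   # negative indices work too
    print(cache)  # Cache(seq_index=0, weights='w1', updates='u1')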
@@ -296,6 +302,7 @@ class NeuralMemory(Module):
         init_adaptive_step_bias = None,
         init_momentum_bias = None,
         init_decay_bias = None,
+        accept_weight_residual = False,
         default_model_kwargs: dict = dict(
             depth = 2
         )
@@ -438,6 +445,14 @@ class NeuralMemory(Module):
 
         self.max_mem_layer_modulation = max_mem_layer_modulation
 
+        # learned weight residual
+
+        self.to_learned_weight_residual_mix = Sequential(
+            nn.Linear(dim, heads),
+            Rearrange('b n h -> b h n'),
+            nn.Sigmoid()
+        ) if accept_weight_residual else None
+
         # allow for softclamp the gradient norms for storing memories
 
         self.max_grad_norm = max_grad_norm
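The gate built above maps each token to one sigmoid value per head. A shape sketch using plain `nn.Sequential` (assumed to behave like the module's own `Sequential` alias) and illustrative sizes:

    import torch
    from torch import nn
    from einops.layers.torch import Rearrange

    dim, heads = 384, 2  # sizes matching the tests above

    to_learned_weight_residual_mix = nn.Sequential(
        nn.Linear(dim, heads),         # per-token logits, one per head
        Rearrange('b n h -> b h n'),   # move heads before the sequence axis
        nn.Sigmoid()                   # gate values in (0, 1)
    )

    chunked_seq = torch.randn(2, 4, dim)   # (batch, num_chunks, dim), illustrative shape
    mix = to_learned_weight_residual_mix(chunked_seq)
    print(mix.shape)                       # torch.Size([2, 2, 4]) -> (batch, heads, num_chunks)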
@@ -563,6 +578,8 @@ class NeuralMemory(Module):
 
         # maybe add previous layer weight
 
+        assert xnor(exists(self.to_learned_weight_residual_mix), exists(prev_weights))
+
         if exists(prev_weights):
 
             start_index = math.ceil(seq_index / chunk_size)
@@ -570,6 +587,11 @@ class NeuralMemory(Module):
 
             prev_weights = prev_weights.apply(lambda t: t[:, start_index:end_index])
 
+            if exists(self.to_learned_weight_residual_mix):
+                mix = self.to_learned_weight_residual_mix(chunked_seq)
+                mix = rearrange(mix, 'b h n -> (b h) n')
+                prev_weights = prev_weights.apply(lambda t: einx.multiply('bh n, bh n ... -> bh n ...', mix, t))
+
             weights_for_surprise = weights_for_surprise + prev_weights
 
         # flatten batch and time if surprise depends on previous layer memory model
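In `store_memories`, those gates are flattened over batch and heads and broadcast over the previous layer's weight tensors, so each chunk's residual weights are scaled by a learned value in (0, 1). A standalone sketch of the `einx.multiply` broadcast with made-up weight dimensions:

    import torch
    import einx
    from einops import rearrange

    batch, heads, chunks = 2, 2, 4

    mix = torch.rand(batch, heads, chunks)    # gates from to_learned_weight_residual_mix
    mix = rearrange(mix, 'b h n -> (b h) n')  # flatten batch and heads, as in the diff

    prev_weight = torch.randn(batch * heads, chunks, 16, 16)  # one previous-layer weight tensor, illustrative dims

    gated = einx.multiply('bh n, bh n ... -> bh n ...', mix, prev_weight)
    print(gated.shape)  # torch.Size([4, 4, 16, 16]); gates broadcast over the trailing weight dims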
@@ -837,9 +859,9 @@ class NeuralMemory(Module):
 
         weights = last_update
 
-        next_neural_mem_state = list(next_neural_mem_state)
-        next_neural_mem_state[1] = last_update
-        next_neural_mem_state = NeuralMemCache(*next_neural_mem_state)
+        next_neural_mem_state = tuple_index_set(next_neural_mem_state, 1, last_update)
+
+        next_neural_mem_state = tuple_index_set(next_neural_mem_state, -1, updates)
 
         # retrieve
 