titans-pytorch 0.2.17__py3-none-any.whl → 0.2.19__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
--- a/titans_pytorch/mac_transformer.py
+++ b/titans_pytorch/mac_transformer.py
@@ -491,7 +491,7 @@ class MemoryAsContextTransformer(Module):
         neural_memory_layers: tuple[int, ...] | None = None,
         use_flex_attn = False,
         sliding_window_attn = False,
-        neural_mem_weight_residual = False
+        neural_mem_weight_residual = False,
     ):
         super().__init__()

@@ -525,7 +525,10 @@ class MemoryAsContextTransformer(Module):

         neural_memory_layers = default(neural_memory_layers, layers)

+        # weight residual related
+
         self.neural_mem_weight_residual = neural_mem_weight_residual
+        is_first_neural_mem = True

         # mem, attn, and feedforward layers

@@ -557,9 +560,12 @@ class MemoryAsContextTransformer(Module):
                     chunk_size = self.neural_memory_segment_len,
                     batch_size = neural_memory_batch_size,
                     model = deepcopy(neural_memory_model),
+                    learned_weight_residual = neural_mem_weight_residual and not is_first_neural_mem,
                     **neural_memory_kwargs
                 )

+                is_first_neural_mem = False
+
             ff = FeedForward(dim = dim, mult = ff_mult)

             self.layers.append(ModuleList([
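
Read together, these three mac_transformer.py hunks thread one boolean through construction: the `neural_mem_weight_residual` flag (whose missing trailing comma the first hunk also fixes) is stored on the module, and the transient `is_first_neural_mem` flag ensures only neural memory layers after the first are built with a learned residual, since the first has no earlier layer's weights to mix in. A minimal sketch of this first-occurrence pattern; `NeuralMemoryStub` and the layer indices are illustrative, not the library's API:

    # stand-in for the real NeuralMemory, just to show the flag wiring
    class NeuralMemoryStub:
        def __init__(self, learned_weight_residual = False):
            self.learned_weight_residual = learned_weight_residual

    neural_mem_weight_residual = True
    neural_memory_layers = (2, 4, 6)   # hypothetical layer indices
    is_first_neural_mem = True

    mems = {}
    for layer in neural_memory_layers:
        # every neural memory after the first may mix in the previous
        # layer's memory weights; the first has none to receive
        mems[layer] = NeuralMemoryStub(
            learned_weight_residual = neural_mem_weight_residual and not is_first_neural_mem
        )
        is_first_neural_mem = False

    assert not mems[2].learned_weight_residual
    assert mems[4].learned_weight_residual and mems[6].learned_weight_residual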
--- a/titans_pytorch/neural_memory.py
+++ b/titans_pytorch/neural_memory.py
@@ -66,6 +66,12 @@ def xnor(x, y):
 def divisible_by(num, den):
     return (num % den) == 0

+def tuple_index_set(t: tuple, index, value):
+    klass = type(t)
+    t = list(t)
+    t[index] = value
+    return klass(*t)
+
 def safe_cat(inputs, dim = -2):
     inputs = tuple(filter(exists, inputs))

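
`tuple_index_set` is a small functional helper: it copies a tuple, swaps one slot, and rebuilds the original type via `klass(*t)`. That reconstruction assumes the type's constructor takes its fields positionally, which holds for NamedTuples such as the `NeuralMemCache` it is applied to further down (a bare `tuple` would reject `tuple(*t)`). A quick check with a hypothetical stand-in cache:

    from typing import NamedTuple

    def tuple_index_set(t: tuple, index, value):
        klass = type(t)
        t = list(t)
        t[index] = value
        return klass(*t)

    class Cache(NamedTuple):   # hypothetical stand-in for NeuralMemCache
        seq_index: int
        weights: str
        updates: str

    cache = Cache(0, 'w0', 'u0')
    cache = tuple_index_set(cache, 1, 'w1')    # positional slot, as in the cache update below
    cache = tuple_index_set(cache, -1, 'u1')   # negative indices work via the list round-trip
    assert cache == Cache(0, 'w1', 'u1')

Unlike NamedTuple's `_replace`, which addresses fields by name, this helper accepts negative indices, which the cache update at the end of this diff relies on.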
@@ -296,6 +302,7 @@ class NeuralMemory(Module):
         init_adaptive_step_bias = None,
         init_momentum_bias = None,
         init_decay_bias = None,
+        accept_weight_residual = False,
         default_model_kwargs: dict = dict(
             depth = 2
         )
@@ -438,6 +445,14 @@ class NeuralMemory(Module):

         self.max_mem_layer_modulation = max_mem_layer_modulation

+        # learned weight residual
+
+        self.to_learned_weight_residual_mix = Sequential(
+            nn.Linear(dim, heads),
+            Rearrange('b n h -> b h n'),
+            nn.Sigmoid()
+        ) if accept_weight_residual else None
+
         # allow for softclamp the gradient norms for storing memories

         self.max_grad_norm = max_grad_norm
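
The gate head built here maps every position of the chunked sequence to one scalar per memory head: `nn.Linear(dim, heads)` turns `(b, n, dim)` into `(b, n, h)`, `Rearrange('b n h -> b h n')` moves heads in front of the sequence axis, and the sigmoid squashes each gate into (0, 1). A shape walkthrough with illustrative sizes (the real `dim` and `heads` come from the constructor):

    import torch
    from torch import nn
    from einops.layers.torch import Rearrange

    dim, heads = 384, 4   # illustrative sizes
    to_mix = nn.Sequential(
        nn.Linear(dim, heads),          # (b, n, dim) -> (b, n, heads)
        Rearrange('b n h -> b h n'),    # -> (b, heads, n)
        nn.Sigmoid()                    # gates bounded to (0, 1)
    )

    chunked_seq = torch.randn(2, 16, dim)
    mix = to_mix(chunked_seq)
    assert mix.shape == (2, heads, 16)  # one gate per head per position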
@@ -563,6 +578,8 @@ class NeuralMemory(Module):

         # maybe add previous layer weight

+        assert xnor(exists(self.to_learned_weight_residual_mix), exists(prev_weights))
+
         if exists(prev_weights):

             start_index = math.ceil(seq_index / chunk_size)
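
The `xnor` assert couples configuration to call-time usage: a memory built with `accept_weight_residual` must be handed `prev_weights`, and one built without it must not be. A sketch of the invariant, assuming the helper's usual definition:

    def xnor(x, y):
        # true exactly when x and y agree; presumed equivalent to the library helper
        return not (x ^ y)

    assert xnor(True, True)       # mix head exists and prev_weights supplied
    assert xnor(False, False)     # neither configured nor supplied
    assert not xnor(True, False)  # a mismatch would trip the assert above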
@@ -570,6 +587,11 @@ class NeuralMemory(Module):

             prev_weights = prev_weights.apply(lambda t: t[:, start_index:end_index])

+            if exists(self.to_learned_weight_residual_mix):
+                mix = self.to_learned_weight_residual_mix(chunked_seq)
+                mix = rearrange(mix, 'b h n -> (b h) n')
+                prev_weights = prev_weights.apply(lambda t: einx.multiply('bh n, bh n ... -> bh n ...', mix, t))
+
             weights_for_surprise = weights_for_surprise + prev_weights

         # flatten batch and time if surprise depends on previous layer memory model
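
By this point batch and heads have been flattened into one axis, so `mix` is `(b*h, n)` while each entry of `prev_weights` carries extra parameter dimensions, for example `(b*h, n, d_in, d_out)` for a linear memory layer (an assumed shape, for illustration only). The `einx.multiply` pattern `'bh n, bh n ... -> bh n ...'` broadcasts the per-position gate over those trailing dimensions; with a known rank it reduces to ordinary broadcasting:

    import torch

    bh, n, d_in, d_out = 8, 16, 32, 32        # illustrative sizes
    mix = torch.rand(bh, n)                   # sigmoid gates per (batch*head, chunk position)
    weight = torch.randn(bh, n, d_in, d_out)  # assumed per-chunk memory weight tensor

    # equivalent to einx.multiply('bh n, bh n ... -> bh n ...', mix, weight)
    gated = mix[:, :, None, None] * weight
    assert gated.shape == weight.shape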
@@ -837,9 +859,9 @@ class NeuralMemory(Module):

         weights = last_update

-        next_neural_mem_state = list(next_neural_mem_state)
-        next_neural_mem_state[1] = last_update
-        next_neural_mem_state = NeuralMemCache(*next_neural_mem_state)
+        next_neural_mem_state = tuple_index_set(next_neural_mem_state, 1, last_update)
+
+        next_neural_mem_state = tuple_index_set(next_neural_mem_state, -1, updates)

         # retrieve

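
The rewrite swaps the list round-trip for two `tuple_index_set` calls and, unlike the old code, also writes `updates` into the cache's final slot. Reusing the hypothetical `Cache` and the helper from the sketch further up:

    cache = Cache(3, 'w_old', 'u_old')
    last_update, updates = 'w_new', 'u_new'

    # old path: mutate a list copy, touching only slot 1
    state = list(cache)
    state[1] = last_update
    old_style = Cache(*state)

    # new path: same slot-1 write, plus the final slot now carries the updates
    new_style = tuple_index_set(tuple_index_set(cache, 1, last_update), -1, updates)

    assert old_style.weights == new_style.weights == 'w_new'
    assert new_style.updates == 'u_new'   # behavior change: updates are now cached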
--- a/titans_pytorch-0.2.17.dist-info/METADATA
+++ b/titans_pytorch-0.2.19.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.2.17
+Version: 0.2.19
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
--- /dev/null
+++ b/titans_pytorch-0.2.19.dist-info/RECORD
@@ -0,0 +1,9 @@
+titans_pytorch/__init__.py,sha256=Y3m_ZlpEqYwp-Md1ARhNGJxq8bQp8ty1o039nZOOJo0,276
+titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+titans_pytorch/mac_transformer.py,sha256=J9MHFViWWehBOFOKx7ry_X2k8nAXaAZFFeCwGtudZyk,24942
+titans_pytorch/memory_models.py,sha256=Q9SAIyAbStF5Tz0EhvRbn3yAdE3nk3xKc1ndieIe714,4671
+titans_pytorch/neural_memory.py,sha256=e1TENsK0IptpapFRtQux-Uii2MFvisoiQLzTQboHc50,26709
+titans_pytorch-0.2.19.dist-info/METADATA,sha256=6frwjwq4CX7VcdCMh8Op9WeDHLRD-oeo8DlTde538UI,6812
+titans_pytorch-0.2.19.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.2.19.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.2.19.dist-info/RECORD,,
--- a/titans_pytorch-0.2.17.dist-info/RECORD
+++ /dev/null
@@ -1,9 +0,0 @@
-titans_pytorch/__init__.py,sha256=Y3m_ZlpEqYwp-Md1ARhNGJxq8bQp8ty1o039nZOOJo0,276
-titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/mac_transformer.py,sha256=QhdSPgWntfWILMJ1t0xLKgvZfPZWu9vhzZWaesftaPg,24724
-titans_pytorch/memory_models.py,sha256=Q9SAIyAbStF5Tz0EhvRbn3yAdE3nk3xKc1ndieIe714,4671
-titans_pytorch/neural_memory.py,sha256=M7vB6AgrphY3_cdELyWXF8qS9cjYUvv1NVCUqQOlA8M,25928
-titans_pytorch-0.2.17.dist-info/METADATA,sha256=8oBe_7SkPmHvHVGXpoRvGWMTldfl2pzXIztOaM_qGrI,6812
-titans_pytorch-0.2.17.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.2.17.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.2.17.dist-info/RECORD,,