titans-pytorch 0.2.1__py3-none-any.whl → 0.2.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- titans_pytorch/mac_transformer.py +3 -20
- titans_pytorch/memory_models.py +1 -1
- titans_pytorch/neural_memory.py +56 -90
- {titans_pytorch-0.2.1.dist-info → titans_pytorch-0.2.5.dist-info}/METADATA +1 -1
- titans_pytorch-0.2.5.dist-info/RECORD +9 -0
- titans_pytorch-0.2.1.dist-info/RECORD +0 -9
- {titans_pytorch-0.2.1.dist-info → titans_pytorch-0.2.5.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.2.1.dist-info → titans_pytorch-0.2.5.dist-info}/licenses/LICENSE +0 -0
titans_pytorch/mac_transformer.py CHANGED
@@ -488,7 +488,7 @@ class MemoryAsContextTransformer(Module):
 neural_memory_model: Module | None = None,
 neural_memory_kwargs: dict = dict(),
 neural_memory_layers: tuple[int, ...] | None = None,
-aux_kv_recon_loss_weight =
+aux_kv_recon_loss_weight = 1.,
 use_flex_attn = False,
 sliding_window_attn = False,
 weight_tie_memory_model = False,
@@ -536,11 +536,6 @@ class MemoryAsContextTransformer(Module):
 self.weight_tie_memory_model = weight_tie_memory_model
 self.prev_neural_mem_update_for_weights = default(prev_neural_mem_update_for_weights, weight_tie_memory_model)

-# value residual learning for neural memory
-
-is_first_mem = True
-self.mem_add_value_residual = neural_memory_add_value_residual
-
 # mem, attn, and feedforward layers

 for layer in layers:
@@ -570,12 +565,9 @@ class MemoryAsContextTransformer(Module):
 dim = dim,
 chunk_size = self.neural_memory_segment_len,
 model = maybe_copy(neural_memory_model),
-accept_value_residual = not is_first_mem and neural_memory_add_value_residual,
 **neural_memory_kwargs
 )

-is_first_mem = False
-
 ff = FeedForward(dim = dim, mult = ff_mult)

 self.layers.append(ModuleList([
@@ -765,8 +757,6 @@ class MemoryAsContextTransformer(Module):

 value_residual = None

-mem_value_residual = None
-
 # aux losses

 kv_recon_losses = self.zero
@@ -794,28 +784,21 @@ class MemoryAsContextTransformer(Module):
 mem_input, add_residual = mem_hyper_conn(x)

 if not is_inferencing:
-(retrieved, next_neural_mem_cache
+(retrieved, next_neural_mem_cache), mem_kv_aux_loss = mem(
 mem_input,
 return_aux_kv_loss = True,
-return_values = True,
-value_residual = mem_value_residual,
 prev_layer_updates = neural_memory_updates
 )

 kv_recon_losses = kv_recon_losses + mem_kv_aux_loss

 else:
-(retrieved, next_neural_mem_cache
+(retrieved, next_neural_mem_cache) = mem.forward_inference(
 mem_input,
 state = next(neural_mem_caches, None),
-return_values = True,
-value_residual = mem_value_residual,
 prev_layer_updates = neural_memory_updates
 )

-if self.mem_add_value_residual:
-mem_value_residual = next_mem_value_residual
-
 if prev_neural_mem_update_for_weights:
 neural_memory_updates = next_neural_mem_cache.updates

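The two call sites above show the slimmed-down return contract in 0.2.5: the training path unpacks `(retrieved, cache), aux_loss` and the inference path unpacks `(retrieved, cache)`, with the value-residual arguments gone. A minimal sketch of the same calling convention on a standalone NeuralMemory; the sizes (dim = 384, chunk_size = 64) and tensor shapes here are hypothetical, only the unpacking pattern is taken from the hunk above.

import torch
from titans_pytorch import NeuralMemory

# hypothetical sizes; only the return/unpack pattern mirrors the diff above
mem = NeuralMemory(dim = 384, chunk_size = 64)

seq = torch.randn(2, 1024, 384)

# training-style call: aux kv-reconstruction loss requested, no value residual passed
(retrieved, next_mem_cache), aux_kv_loss = mem(seq, return_aux_kv_loss = True)

assert retrieved.shape == seq.shape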
titans_pytorch/memory_models.py CHANGED
titans_pytorch/neural_memory.py CHANGED
@@ -67,6 +67,9 @@ def safe_cat(inputs, dim = -2):
 def identity(t):
 return t

+def dict_get_shape(td):
+return {k: v.shape for k, v in td.items()}
+
 def pair(v):
 return (v, v) if not isinstance(v, tuple) else v

@@ -258,7 +261,6 @@ class NeuralMemory(Module):
 pre_rmsnorm = True,
 post_rmsnorm = True,
 qk_rmsnorm = False,
-accept_value_residual = False,
 max_grad_norm: float | None = None,
 use_accelerated_scan = False,
 activation: Module | None = None,
@@ -302,19 +304,34 @@ class NeuralMemory(Module):
 nn.Sigmoid()
 ) if heads > 1 else None

-# memory
+# memory model

 if not exists(model):
 model = MemoryMLP(dim_head, **default_model_kwargs)

+# validate memory model
+
 assert not exists(next(model.buffers(), None)), 'model cannot have buffers for now'

+test_shape = (3, 2, dim_head)
+
+with torch.no_grad():
+try:
+test_input = torch.randn(test_shape)
+mem_model_output = model(test_input)
+except:
+raise RuntimeError(f'memory model unable to accept a tensor of shape {test_shape}')
+
+assert mem_model_output.shape == test_shape, 'output of memory model needs to be same shape as input'
+
 # the memory is the weights of the model

 self.memory_model = model

 self.num_memory_parameter_tensors = len(set(model.parameters()))

+self.init_weight_shape = dict_get_shape(dict(model.named_parameters()))
+
 # the chunk size within the paper where adaptive step, momentum, weight decay are shared

 self.chunk_size = chunk_size
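The constructor now probes the supplied memory model with a small (3, 2, dim_head) tensor and requires the output to keep that shape (and, as before, the model may not register buffers). A custom memory model therefore only needs to be a buffer-free nn.Module that preserves the last dimension; the class below is a hypothetical example (not part of the package) that would pass the check added above.

import torch
from torch import nn

# hypothetical custom memory model: no buffers, shape-preserving over the last dimension
class GatedResidualMemory(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.proj = nn.Linear(dim, dim)
        self.gate = nn.Linear(dim, dim)

    def forward(self, x):
        return x + self.proj(x) * self.gate(x).sigmoid()

model = GatedResidualMemory(64)

out = model(torch.randn(3, 2, 64))
assert out.shape == (3, 2, 64)  # satisfies the shape assertion added in the hunk above

Such a module would be handed to NeuralMemory through its `model` argument, which this validation runs against before the model's weights are adopted as the memory.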
@@ -343,19 +360,6 @@ class NeuralMemory(Module):
 self.to_keys_values = Sequential(LinearNoBias(dim, dim_inner * 2), activation)
 self.store_memory_loss_fn = store_memory_loss_fn

-# value residual learning
-
-self.learned_value_residual = Sequential(
-LinearNoBias(dim, heads),
-Rearrange('b n h -> b h n 1'),
-nn.Sigmoid()
-) if accept_value_residual else None
-
-# empty memory embed
-
-self.empty_memory_embed = nn.Parameter(torch.zeros(dim))
-nn.init.normal_(self.empty_memory_embed, std = 0.02)
-
 # `chunk_size` refers to chunk size used for storing to memory model weights

 chunk_size = self.store_chunk_size
@@ -417,9 +421,6 @@ class NeuralMemory(Module):
 weights = TensorDict(dict(self.memory_model.named_parameters()))
 return weights

-def init_empty_memory_embed(self, batch, seq_len):
-return repeat(self.empty_memory_embed, 'd -> b n d', b = batch, n = seq_len)
-
 def store_memories(
 self,
 seq,
@@ -428,10 +429,7 @@ class NeuralMemory(Module):
 prev_layer_updates: dict[str, Tensor] | None = None,
 return_aux_kv_loss = False,
 chunk_size = None,
-value_residual = None
 ):
-assert xnor(exists(value_residual), exists(self.learned_value_residual))
-
 seq_len, heads, chunk_size = seq.shape[-2], self.heads, default(chunk_size, self.store_chunk_size)

 # handle edge case
@@ -446,7 +444,7 @@ class NeuralMemory(Module):

 round_down_seq_len = round_down_multiple(seq_len, chunk_size)

-seq = seq[:, :round_down_seq_len]
+seq, remainder = seq[:, :round_down_seq_len], seq[:, round_down_seq_len:]

 # per sample grad function

@@ -499,14 +497,6 @@ class NeuralMemory(Module):

 keys = self.k_norm(keys)

-# maybe value residual learning
-
-orig_values = values
-
-if exists(self.learned_value_residual):
-mix = self.learned_value_residual(seq)
-values = values.lerp(value_residual, mix)
-
 # take care of chunking

 keys, values = tuple(rearrange(t, 'b h (n c) d -> (b h n) c d', c = chunk_size) for t in (keys, values))
@@ -581,13 +571,15 @@ class NeuralMemory(Module):
 if has_momentum:
 next_momentum[param_name] = inverse_pack(momentum)

-#
+# determine next state for the storing of memories

 next_state = (next_last_update, next_last_momentum)

+next_store_state = NeuralMemCache(seq_len, remainder, next_state, updates)
+
 # returns

-output = (updates,
+output = (updates, next_store_state)

 if not return_aux_kv_loss:
 return output
@@ -606,16 +598,18 @@ class NeuralMemory(Module):

 seq = self.retrieve_norm(seq)

-
-
+assert seq_len >= chunk_size, 'must be handled outside of retrieve'
+
+needs_pad = chunk_size > 1

-
-
+if needs_pad:
+seq = pad_at_dim(seq, (1, 0), dim = 1)
+seq_len_plus_one = seq.shape[-2]

-
+next_seq_len = round_up_multiple(seq_len_plus_one, chunk_size)

-
-
+padding = next_seq_len - seq_len_plus_one
+seq = pad_at_dim(seq, (0, padding), dim = 1)

 # the parameters of the memory model stores the memories of the key / values
 # when the MLP has only 1 weight matrix, it is equivalent to `kv` fast weight memories from linear attention literature (recall fetching of memories is q @ (kv)) / schmidhuber's paper
@@ -639,7 +633,9 @@ class NeuralMemory(Module):

 # fetch values from memory model

-curr_weights
+if dict_get_shape(curr_weights) != self.init_weight_shape:
+curr_weights = curr_weights.apply(lambda t: rearrange(t, 'b n ... -> (b n) ...'))
+
 queries = rearrange(queries, 'b h (n c) d -> (b h n) c d', c = chunk_size)

 # forward functional call
@@ -665,10 +661,10 @@ class NeuralMemory(Module):

 # restore, pad with empty memory embed

-
-
+if needs_pad:
+values = values[:, 1:(seq_len + 1)]

-return values
+return values

 @torch.no_grad()
 def forward_inference(
@@ -676,8 +672,6 @@ class NeuralMemory(Module):
 token: Tensor,
 state = None,
 prev_layer_updates: dict[str, Tensor] | None = None,
-return_values = False,
-value_residual = None,
 ):

 # unpack previous state
@@ -704,12 +698,9 @@ class NeuralMemory(Module):
 # early return empty memory, when no memories are stored for steps < first chunk size

 if curr_seq_len < self.chunk_size:
-
-
-output = empty_mem, NeuralMemCache(curr_seq_len, cache_store_seq, past_states, updates)
+retrieve = self.retrieve_memories(token, weights, chunk_size = 1)

-
-output = (*output, self.zero)
+output = retrieve, NeuralMemCache(curr_seq_len, cache_store_seq, past_states, updates)

 return output

@@ -728,20 +719,18 @@ class NeuralMemory(Module):
 prev_layer_updates = TensorDict(prev_layer_updates)
 prev_layer_updates = prev_layer_updates.apply(lambda t: t[:, -1:])

-values = None
-
 if store_seq_cache_len == self.chunk_size:

-next_updates,
+next_updates, store_state = self.store_memories(
 cache_store_seq,
 weights,
 past_state = past_states,
 prev_layer_updates = prev_layer_updates,
-value_residual = value_residual
 )

 updates = next_updates
 cache_store_seq = None
+next_states = store_state.states

 # retrieve

@@ -749,14 +738,9 @@ class NeuralMemory(Module):

 # next state tuple

-
+next_store_state = NeuralMemCache(curr_seq_len, cache_store_seq, next_states, updates)

-
-
-if return_values:
-output = (*output, values)
-
-return output
+return retrieved, next_store_state

 def forward(
 self,
@@ -767,50 +751,45 @@ class NeuralMemory(Module):
 return_aux_kv_loss = False,
 chunk_size = None,
 store_chunk_size = None,
-return_values = False,
-value_residual = None,
 return_next_state = False,
 prev_layer_updates: dict[str, Tensor] | None = None
 ):
 batch, seq_len = seq.shape[:2]

+if not exists(mem_model_weights):
+mem_model_weights = self.init_weights()
+
 if seq_len < self.retrieve_chunk_size:
-
+retrieved = self.retrieve_memories(seq, mem_model_weights, chunk_size = 1)

 next_store_state = NeuralMemCache(seq_len, seq, None, None)

-out = (
-
-if return_values:
-out = (*out, self.zero)
+out = (retrieved, next_store_state)

 if not return_aux_kv_loss:
 return out

 return out, self.zero

-if not exists(mem_model_weights):
-mem_model_weights = self.init_weights()
-
 # store

 store_seq = default(store_seq, seq)

-
-store_chunk_size = default(store_chunk_size, chunk_size, self.store_chunk_size)
-remainder = store_seq_len % store_chunk_size
-
-(updates, next_state, values), aux_kv_recon_loss = self.store_memories(
+(updates, next_store_state), aux_kv_recon_loss = self.store_memories(
 store_seq,
 mem_model_weights,
 chunk_size = store_chunk_size,
 prev_layer_updates = prev_layer_updates,
-value_residual = value_residual,
 return_aux_kv_loss = True
 )

 # retrieve

+if exists(prev_layer_updates):
+prev_layer_updates = prev_layer_updates.apply(lambda t: pad_at_dim(t, (1, 0), dim = 1))
+
+updates = updates.apply(lambda t: pad_at_dim(t, (1, 0), dim = 1))
+
 retrieved = self.retrieve_memories(
 seq,
 mem_model_weights + updates,
@@ -818,21 +797,8 @@ class NeuralMemory(Module):
 prev_layer_updates = prev_layer_updates
 )

-# determine state for the storing of memories
-# for transformer-xl like training with neural memory as well as inferencing with initial prompt
-
-cache_store_seq = None
-
-if remainder > 0:
-cache_store_seq = store_seq[:, -remainder:]
-
-next_store_state = NeuralMemCache(seq_len, cache_store_seq, next_state, updates)
-
 output = (retrieved, next_store_state)

-if return_values:
-output = (*output, values)
-
 if not return_aux_kv_loss:
 return output

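With return_values and value_residual removed, forward_inference now takes only the token, the previous cache, and optional previous-layer updates, and returns (retrieved, next_store_state). A minimal step-wise decoding sketch of that loop, again with hypothetical sizes and shapes and assuming a None initial state is accepted as the empty cache.

import torch
from titans_pytorch import NeuralMemory

mem = NeuralMemory(dim = 384, chunk_size = 64)  # hypothetical sizes

state = None

for _ in range(8):
    token = torch.randn(2, 1, 384)  # one new token per step, batch of 2
    # cache from the previous step is threaded back in as `state`
    retrieved, state = mem.forward_inference(token, state = state)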
titans_pytorch-0.2.5.dist-info/RECORD ADDED
@@ -0,0 +1,9 @@
+titans_pytorch/__init__.py,sha256=Y3m_ZlpEqYwp-Md1ARhNGJxq8bQp8ty1o039nZOOJo0,276
+titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+titans_pytorch/mac_transformer.py,sha256=UOJAMv7nTgkefBB7M7K3U0NnFkz75tFRG5WLXRdfnLw,26039
+titans_pytorch/memory_models.py,sha256=Ew28waD9gf1wn-5Nkdc676u1I92IqzaOAw-tv0JXMwc,3777
+titans_pytorch/neural_memory.py,sha256=YiBsMiqYn-Hva4yhxfaqkGV857vZIASxi5Z0TT0FC10,24606
+titans_pytorch-0.2.5.dist-info/METADATA,sha256=x3RePuTDf3rUT3vtvge1X3Ry18Y3tV_swCgycbtSCjQ,6819
+titans_pytorch-0.2.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.2.5.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.2.5.dist-info/RECORD,,
titans_pytorch-0.2.1.dist-info/RECORD DELETED
@@ -1,9 +0,0 @@
-titans_pytorch/__init__.py,sha256=Y3m_ZlpEqYwp-Md1ARhNGJxq8bQp8ty1o039nZOOJo0,276
-titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/mac_transformer.py,sha256=kqW90mpbFf1ZJ_mMkd6v9EQ5J__TwKMPy5cjHJF_26A,26742
-titans_pytorch/memory_models.py,sha256=LI9T36XB6YXIvvGWRw0ZMDlGpRC6KIv03OPzME2VAaU,3772
-titans_pytorch/neural_memory.py,sha256=vmKPOAlXBPXBnYPODrg_reWaIcr1xwtfQmuptGS6e5A,25559
-titans_pytorch-0.2.1.dist-info/METADATA,sha256=HPdcQb4SlT-eLFzOYLMwGInEKegL4M4yIpKWt1a6DTs,6819
-titans_pytorch-0.2.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.2.1.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.2.1.dist-info/RECORD,,
{titans_pytorch-0.2.1.dist-info → titans_pytorch-0.2.5.dist-info}/WHEEL: File without changes
{titans_pytorch-0.2.1.dist-info → titans_pytorch-0.2.5.dist-info}/licenses/LICENSE: File without changes