titans-pytorch 0.2.8__py3-none-any.whl → 0.2.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- titans_pytorch/mac_transformer.py +6 -48
- titans_pytorch/neural_memory.py +90 -119
- {titans_pytorch-0.2.8.dist-info → titans_pytorch-0.2.10.dist-info}/METADATA +1 -1
- titans_pytorch-0.2.10.dist-info/RECORD +9 -0
- titans_pytorch-0.2.8.dist-info/RECORD +0 -9
- {titans_pytorch-0.2.8.dist-info → titans_pytorch-0.2.10.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.2.8.dist-info → titans_pytorch-0.2.10.dist-info}/licenses/LICENSE +0 -0
titans_pytorch/mac_transformer.py CHANGED
@@ -488,11 +488,8 @@ class MemoryAsContextTransformer(Module):
         neural_memory_model: Module | None = None,
         neural_memory_kwargs: dict = dict(),
         neural_memory_layers: tuple[int, ...] | None = None,
-        aux_kv_recon_loss_weight = 1.,
         use_flex_attn = False,
         sliding_window_attn = False,
-        weight_tie_memory_model = False,
-        prev_neural_mem_update_for_weights = None
     ):
         super().__init__()

@@ -526,16 +523,6 @@ class MemoryAsContextTransformer(Module):

         neural_memory_layers = default(neural_memory_layers, layers)

-        # weight tying neural memory model
-
-        maybe_copy = deepcopy if not weight_tie_memory_model else identity
-
-        if weight_tie_memory_model:
-            assert exists(neural_memory_model), '`neural_memory_model` must be explicitly set'
-
-        self.weight_tie_memory_model = weight_tie_memory_model
-        self.prev_neural_mem_update_for_weights = default(prev_neural_mem_update_for_weights, weight_tie_memory_model)
-
         # mem, attn, and feedforward layers

         for layer in layers:
@@ -564,7 +551,7 @@ class MemoryAsContextTransformer(Module):
                 mem = NeuralMemory(
                     dim = dim,
                     chunk_size = self.neural_memory_segment_len,
-                    model = maybe_copy(neural_memory_model),
+                    model = deepcopy(neural_memory_model),
                     **neural_memory_kwargs
                 )

@@ -585,10 +572,7 @@ class MemoryAsContextTransformer(Module):

         self.gate_attn_output = neural_mem_gate_attn_output

-        #
-
-        self.has_aux_kv_recon_loss = aux_kv_recon_loss_weight > 0.
-        self.aux_kv_recon_loss_weight = aux_kv_recon_loss_weight
+        # zero for maybe aux loss + device

         self.register_buffer('zero', torch.tensor(0.), persistent = False)

@@ -696,7 +680,7 @@ class MemoryAsContextTransformer(Module):

         # math

-        batch, seq_len, neural_mem_segment_len, segment_len, num_longterm_mem_tokens, attn_window_size
+        batch, seq_len, neural_mem_segment_len, segment_len, num_longterm_mem_tokens, attn_window_size = *x.shape, self.neural_memory_segment_len, self.segment_len, self.num_longterm_mem_tokens, self.attn_window_size

         seq_len_with_mem = self.seq_len_with_longterm_mem(seq_len)

@@ -749,18 +733,10 @@ class MemoryAsContextTransformer(Module):
         next_kv_caches = []
         next_neural_mem_caches = []

-        # weight tied neural memory
-
-        neural_memory_updates = None
-
         # value residual

         value_residual = None

-        # aux losses
-
-        kv_recon_losses = self.zero
-
         # when inferencing, only do one token at a time

         if is_inferencing:
@@ -784,24 +760,16 @@ class MemoryAsContextTransformer(Module):
                 mem_input, add_residual = mem_hyper_conn(x)

                 if not is_inferencing:
-                    (retrieved, next_neural_mem_cache), mem_kv_aux_loss = mem(
-                        mem_input,
-                        return_aux_kv_loss = True,
-                        prev_layer_updates = neural_memory_updates
+                    retrieved, next_neural_mem_cache = mem(
+                        mem_input
                     )

-                    kv_recon_losses = kv_recon_losses + mem_kv_aux_loss
-
                 else:
                     (retrieved, next_neural_mem_cache) = mem.forward_inference(
                         mem_input,
                         state = next(neural_mem_caches, None),
-                        prev_layer_updates = neural_memory_updates
                     )

-                    if prev_neural_mem_update_for_weights:
-                        neural_memory_updates = next_neural_mem_cache.updates
-
                 if self.gate_attn_output:
                     attn_out_gates = retrieved.sigmoid()
                 else:
@@ -883,14 +851,4 @@ class MemoryAsContextTransformer(Module):

             return logits, next_cache

-
-
-        losses = ar_loss
-
-        if self.has_aux_kv_recon_loss:
-            losses = losses + kv_recon_losses * self.aux_kv_recon_loss_weight
-
-        if not return_loss_breakdown:
-            return losses
-
-        return losses, (ar_loss, kv_recon_losses)
+        return F.cross_entropy(rearrange(logits, 'b n l -> b l n'), labels)
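The net effect of the mac_transformer.py changes above is that training no longer carries an auxiliary kv-reconstruction loss: the forward pass ends in a single cross entropy over the rearranged logits. A minimal sketch of that loss computation, with placeholder shapes rather than the model's real logits:

import torch
import torch.nn.functional as F
from einops import rearrange

batch, seq_len, vocab = 2, 8, 256

logits = torch.randn(batch, seq_len, vocab)           # (b, n, logits), as the model emits them
labels = torch.randint(0, vocab, (batch, seq_len))    # next-token targets

# F.cross_entropy wants the class dimension second, hence 'b n l -> b l n'
loss = F.cross_entropy(rearrange(logits, 'b n l -> b l n'), labels)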
titans_pytorch/neural_memory.py CHANGED
@@ -38,7 +38,13 @@ w - num memory network weight parameters

 LinearNoBias = partial(Linear, bias = False)

-NeuralMemCache = namedtuple('NeuralMemCache', ['seq', 'cache_store_segment', 'states', 'updates'])
+NeuralMemCache = namedtuple('NeuralMemCache', [
+    'seq',
+    'weights',
+    'cache_store_segment',
+    'states',
+    'updates',
+])

 # functions

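The cache widens from four fields to five: the memory-model weights now ride along in NeuralMemCache instead of being re-initialized on every call. A small sketch of building and unpacking the new tuple; the tensors below are placeholders, not the library's real state:

from collections import namedtuple

import torch

NeuralMemCache = namedtuple('NeuralMemCache', [
    'seq',
    'weights',
    'cache_store_segment',
    'states',
    'updates',
])

# a fresh inference state now has five slots (see the forward_inference change further down)
state = NeuralMemCache(0, None, None, None, None)

# after a store step, the per-head weights travel inside the cache
cache = NeuralMemCache(
    seq = 16,
    weights = {'weight': torch.zeros(2, 8, 8)},   # placeholder (batch * heads, ...) tensors
    cache_store_segment = None,
    states = None,
    updates = None,
)

seq_index, weights, cache_store_seq, past_states, updates = cache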
@@ -70,6 +76,12 @@ def safe_cat(inputs, dim = -2):
 def dict_get_shape(td):
     return {k: v.shape for k, v in td.items()}

+def rearrange_dict_values(td, pattern, **kwargs):
+    return td.apply(lambda t: rearrange(t, pattern, **kwargs))
+
+def repeat_dict_values(td, pattern, **kwargs):
+    return td.apply(lambda t: repeat(t, pattern, **kwargs))
+
 def pair(v):
     return (v, v) if not isinstance(v, tuple) else v

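rearrange_dict_values and repeat_dict_values simply map an einops pattern over every tensor held in a TensorDict. A plain-dict stand-in that shows the same behavior without the tensordict dependency:

import torch
from einops import rearrange, repeat

def repeat_dict_values(td, pattern, **kwargs):
    # same idea as the TensorDict helper: apply the einops op to every value
    return {k: repeat(v, pattern, **kwargs) for k, v in td.items()}

def rearrange_dict_values(td, pattern, **kwargs):
    return {k: rearrange(v, pattern, **kwargs) for k, v in td.items()}

weights = {'w1': torch.randn(4, 4), 'w2': torch.randn(4, 8)}

# give every parameter a leading (batch * heads) dimension, as init_weights now does
batched = repeat_dict_values(weights, '... -> bh ...', bh = 6)
assert batched['w1'].shape == (6, 4, 4)

# and flatten batch with chunks, as store_memories does before the per-sample grad call
flat = rearrange_dict_values(
    repeat_dict_values(batched, 'b ... -> b n ...', n = 2),
    'b n ... -> (b n) ...',
)
assert flat['w1'].shape == (12, 4, 4)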
@@ -106,6 +118,9 @@ def softclamp_max(t, max_value):
     return ((t / half_max_value).tanh() * half_max_value) + half_max_value

 def softclamp_grad_norm(t, max_value):
+    if t.numel() == 0:
+        return t
+
     t, inverse = pack_one_with_inverse(t, 'bn *')

     norm = t.norm(dim = -1, keepdim = True)
@@ -195,6 +210,12 @@ class AssocScan(Module):
     ):
         remove_prev = default(remove_prev, exists(prev))

+        inputs, inverse_pack_weight_shape = pack_one_with_inverse(inputs, 'b n *')
+        gates, _ = pack_one_with_inverse(gates, 'b n *')
+
+        if exists(prev):
+            prev, _ = pack_one_with_inverse(prev, 'b *')
+
         if exists(prev):
             inputs, _ = pack([prev, inputs], 'b * d')
             gates = pad_at_dim(gates, (1, 0), value = 1., dim = -2)
@@ -205,7 +226,7 @@ class AssocScan(Module):
             if remove_prev:
                 out = out[:, 1:]

-            return out
+            return inverse_pack_weight_shape(out)

         from accelerated_scan.triton import scan as triton_scan
         from accelerated_scan.warp import scan as warp_scan
@@ -226,6 +247,7 @@ class AssocScan(Module):

             outputs = outputs[..., :seq_len]
             outputs = rearrange(outputs, 'b d n -> b n d')
+
             return outputs

         out = accelerate_scan_fn(gates, inputs)
@@ -233,7 +255,7 @@ class AssocScan(Module):
         if remove_prev:
             out = out[:, 1:]

-        return out
+        return inverse_pack_weight_shape(out)

 # main neural memory

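The scan now packs its inputs down to (batch, seq, dim) before scanning and restores the original trailing shape on the way out, which is what lets it operate directly on stacks of weight matrices. A sketch of the pack-with-inverse idiom, assuming the repo's pack_one_with_inverse behaves like this einops-based version:

import torch
from einops import pack, unpack

def pack_one_with_inverse(t, pattern):
    packed, packed_shape = pack([t], pattern)

    def inverse(out):
        return unpack(out, packed_shape, pattern)[0]

    return packed, inverse

# e.g. a per-chunk stack of 4x4 weight matrices
x = torch.randn(2, 8, 4, 4)

flat, inverse_pack_weight_shape = pack_one_with_inverse(x, 'b n *')
assert flat.shape == (2, 8, 16)

# ... the associative scan would run on `flat` here ...

restored = inverse_pack_weight_shape(flat)
assert restored.shape == (2, 8, 4, 4)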
@@ -253,7 +275,7 @@ class NeuralMemory(Module):
         model: Module | None = None,
         store_memory_loss_fn: Callable = default_loss_fn,
         adaptive_step_transform: Callable | None = None,
-        default_step_transform_max_lr =
+        default_step_transform_max_lr = 1.,
         per_parameter_lr_modulation = False, # allow outer network to control learning rate per weight matrix of memory network
         max_mem_layer_modulation = 1e1, # max of 10.
         attn_pool_chunks = False,
@@ -342,14 +364,13 @@ class NeuralMemory(Module):
             pred = functional_call(self.memory_model, params, inputs)
             loss = self.store_memory_loss_fn(pred, target) # simple mse loss in paper - eq (12) - |M(k) - v|²
             weighted_loss = loss * loss_weights
-            return weighted_loss.sum()
+            return weighted_loss.sum()

         # two functions

-        grad_fn = grad(forward_and_loss
+        grad_fn = grad(forward_and_loss)

-        self.per_sample_grad_fn = vmap(grad_fn, in_dims = (
-        self.per_sample_grad_fn_expanded_weights = vmap(grad_fn, in_dims = (0,) * 4)
+        self.per_sample_grad_fn = vmap(grad_fn, in_dims = (0, 0, 0, 0))

         # queries for retrieving from the model

@@ -417,56 +438,56 @@ class NeuralMemory(Module):

         self.register_buffer('zero', torch.tensor(0.), persistent = False)

-    def init_weights(self):
+    def init_weights(
+        self,
+        batch,
+    ):
         weights = TensorDict(dict(self.memory_model.named_parameters()))
+        weights = repeat_dict_values(weights, '... -> bh ...', bh = batch * self.heads)
         return weights

+    def init_momentum(
+        self,
+        batch,
+    ):
+        weights = TensorDict(dict(self.memory_model.named_parameters()))
+        zeros = weights.clone().zero_()
+        zeros = repeat_dict_values(zeros, '... -> bh ...', bh = batch * self.heads)
+        return zeros
+
     def store_memories(
         self,
         seq,
-        weights: dict[str, Tensor],
+        weights: dict[str, Tensor] | None = None,
         past_state: tuple[dict[str, Tensor], dict[str, Tensor]] | None = None,
-        prev_layer_updates: dict[str, Tensor] | None = None,
-        return_aux_kv_loss = False,
         chunk_size = None,
     ):
-        seq_len, heads, chunk_size = seq.shape[
-
-        # handle edge case
-
-        if seq_len < chunk_size:
-            return TensorDict(weights).clone().zero_(), self.zero
-
-        seq = self.store_norm(seq)
+        batch, seq_len, heads, chunk_size = *seq.shape[:2], self.heads, default(chunk_size, self.store_chunk_size)

         # curtail sequence by multiple of the chunk size
         # only a complete chunk of the sequence provides the memory for the next chunk

         round_down_seq_len = round_down_multiple(seq_len, chunk_size)
+        num_chunks = round_down_seq_len // chunk_size

         seq, remainder = seq[:, :round_down_seq_len], seq[:, round_down_seq_len:]

-        #
-
-        per_sample_grad_fn = self.per_sample_grad_fn
-
+        # init weights if needed
         # weights of the memory network

+        if not exists(weights):
+            weights = self.init_weights(batch)
+
         weights = TensorDict(weights)

         # allow for neural memory of a previous layer to influence surprise of current layer

-        weights_for_surprise = weights
-
-        if exists(prev_layer_updates):
-            prev_layer_updates = TensorDict(prev_layer_updates)
-
-            weights_for_surprise = weights_for_surprise + prev_layer_updates
-
-            per_sample_grad_fn = self.per_sample_grad_fn_expanded_weights # the weights will now have a batch * chunk dimension
+        weights_for_surprise = repeat_dict_values(weights, 'b ... -> b n ...', n = num_chunks)

         # derive learned hparams for optimization of memory network

+        seq = self.store_norm(seq)
+
         adaptive_lr = self.to_adaptive_step(seq)
         adaptive_lr = self.adaptive_step_transform(adaptive_lr)

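init_weights (and the new init_momentum) now hand back one copy of the memory model's parameters per (batch * heads), rather than the raw unbatched parameters. A sketch of that expansion with plain dicts (the real code wraps them in a TensorDict):

import torch
from torch import nn
from einops import repeat

memory_model = nn.Linear(8, 8, bias = False)   # stand-in for the memory MLP
batch, heads = 2, 4

# one weight copy per (batch * heads), as init_weights(batch) now returns
weights = {
    name: repeat(param, '... -> bh ...', bh = batch * heads)
    for name, param in memory_model.named_parameters()
}

# init_momentum has the same shapes, but zeroed
momentum = {name: torch.zeros_like(w) for name, w in weights.items()}

assert weights['weight'].shape == (batch * heads, 8, 8)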
@@ -474,7 +495,7 @@ class NeuralMemory(Module):

         decay_factor = self.to_decay_factor(chunked_seq).sigmoid()

-        need_layer_lr_mod = exists(self.to_layer_modulation)
+        need_layer_lr_mod = exists(self.to_layer_modulation) and num_chunks > 0
         has_momentum = exists(self.to_momentum)

         if has_momentum:
@@ -505,12 +526,11 @@ class NeuralMemory(Module):

         # flatten batch and time if surprise depends on previous layer memory model

-
-        weights_for_surprise = weights_for_surprise.apply(lambda t: rearrange(t, 'b n ... -> (b n) ...'))
+        weights_for_surprise = rearrange_dict_values(weights_for_surprise, 'b n ... -> (b n) ...')

         # get grads and extra auxiliary loss (for backwarding through qkv projection in base neural memory module)

-        grads
+        grads = self.per_sample_grad_fn(dict(weights_for_surprise), keys, adaptive_lr, values)

         grads = TensorDict(grads)

@@ -521,7 +541,7 @@ class NeuralMemory(Module):

         # restore batch and sequence dimension

-        grads = grads
+        grads = rearrange_dict_values(grads, '(b n) ... -> b n ...', b = batch * heads)

         # maybe per layer modulation

@@ -535,19 +555,25 @@ class NeuralMemory(Module):
         # past states

         if not exists(past_state):
-            empty_dict = {key: None for key in weights.keys()}
-
             # minibatch_init_weight corresponds to W0 in figure 7 of TTT paper

             minibatch_init_weight = weights
+            init_momentum = self.init_momentum(batch)

-
-            minibatch_init_weight = weights.apply(lambda t: repeat(t, '... -> b 1 (...)', b = batch * heads))
-
-            past_state = (minibatch_init_weight, empty_dict)
+            past_state = (minibatch_init_weight, init_momentum)

         past_last_update, past_last_momentum = past_state

+        # early return if sequence length less than chunk size
+
+        if num_chunks == 0:
+            updates = rearrange_dict_values(weights, 'bh ... -> bh 1 ...')
+            next_store_state = NeuralMemCache(seq_len, weights, remainder, past_state, updates)
+
+            output = (updates, next_store_state)
+
+            return output
+
         # momentum + weight decay - momentum is the new contribution, as most linear RNNs have learned forgetting gates

         next_momentum = TensorDict() if has_momentum else None
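The handling of too-short sequences moved: instead of bailing out at the top of store_memories, the chunk count is computed up front and a num_chunks == 0 early return packages the current weights (with a chunk axis of 1) into the cache. The chunking arithmetic, with the helper assumed to be the usual round-down:

# assumed helper, matching its use in the diff
def round_down_multiple(n, mult):
    return n // mult * mult

chunk_size = 16

for seq_len in (37, 10):
    round_down_seq_len = round_down_multiple(seq_len, chunk_size)
    num_chunks = round_down_seq_len // chunk_size
    remainder_len = seq_len - round_down_seq_len
    print(seq_len, num_chunks, remainder_len)

# seq_len = 37 -> 2 full chunks, a remainder of 5 tokens is carried in the cache
# seq_len = 10 -> num_chunks == 0, so the early-return path above fires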
@@ -558,8 +584,6 @@ class NeuralMemory(Module):

         for (param_name, surprise), (_, last_update), (_, last_momentum) in zip(surprises.items(), past_last_update.items(), past_last_momentum.items()):

-            surprise, inverse_pack = pack_one_with_inverse(surprise, 'b n *')
-
             update = surprise

             # derive momentum with associative scan - eq (10)
@@ -571,62 +595,52 @@ class NeuralMemory(Module):

             # use associative scan again for learned forgetting (weight decay) - eq (13)

-            update = self.assoc_scan(1. - decay_factor, update, prev = last_update)
+            update = self.assoc_scan(1. - decay_factor, update, prev = last_update, remove_prev = False)
             next_last_update[param_name] = update[:, -1]

-            updates[param_name] =
+            updates[param_name] = update

             if has_momentum:
-                next_momentum[param_name] =
+                next_momentum[param_name] = momentum

         # determine next state for the storing of memories

         next_state = (next_last_update, next_last_momentum)

-        next_store_state = NeuralMemCache(seq_len, remainder, next_state, updates)
+        next_store_state = NeuralMemCache(seq_len, weights, remainder, next_state, updates)

         # returns

         output = (updates, next_store_state)

-        if not return_aux_kv_loss:
-            return output
-
-        return output, aux_kv_recon_loss.mean()
+        return output

     def retrieve_memories(
         self,
         seq,
         past_weights: dict[str, Tensor],
         chunk_size = None,
-        prev_layer_updates: dict[str, Tensor] | None = None
     ):
         chunk_size = default(chunk_size, self.retrieve_chunk_size)
         batch, seq_len = seq.shape[:2]

         seq = self.retrieve_norm(seq)

-        assert seq_len >= chunk_size, 'must be handled outside of retrieve'
-
         needs_pad = chunk_size > 1

-
-
-        seq_len_plus_one = seq.shape[-2]
+        seq = pad_at_dim(seq, (1, 0), dim = 1)
+        seq_len_plus_one = seq.shape[-2]

-
+        next_seq_len = round_up_multiple(seq_len_plus_one, chunk_size)

-
-
+        padding = next_seq_len - seq_len_plus_one
+        seq = pad_at_dim(seq, (0, padding), dim = 1)

         # the parameters of the memory model stores the memories of the key / values
         # when the MLP has only 1 weight matrix, it is equivalent to `kv` fast weight memories from linear attention literature (recall fetching of memories is q @ (kv)) / schmidhuber's paper

         curr_weights = TensorDict(past_weights)

-        if exists(prev_layer_updates):
-            curr_weights = curr_weights + TensorDict(prev_layer_updates)
-
         # sequence Float['b n d'] to queries

         queries = self.to_queries(seq)
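Retrieval now always shifts the sequence right by one position and pads it up to a multiple of the chunk size, instead of asserting that the caller handled short sequences. A sketch of that padding, assuming pad_at_dim and round_up_multiple are the usual F.pad-based helpers:

import torch
import torch.nn.functional as F

def round_up_multiple(n, mult):
    return -(-n // mult) * mult

def pad_at_dim(t, pad, dim = -1, value = 0.):
    # pad (left, right) along `dim`, zeros elsewhere
    dims_from_right = (t.ndim - dim - 1) if dim >= 0 else (-dim - 1)
    zeros = (0, 0) * dims_from_right
    return F.pad(t, (*zeros, *pad), value = value)

seq = torch.randn(1, 10, 16)
chunk_size = 4

seq = pad_at_dim(seq, (1, 0), dim = 1)      # shift right: position 0 reads from "no memory"
seq_len_plus_one = seq.shape[-2]            # 11

next_seq_len = round_up_multiple(seq_len_plus_one, chunk_size)   # 12
seq = pad_at_dim(seq, (0, next_seq_len - seq_len_plus_one), dim = 1)

assert seq.shape[1] == next_seq_len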
@@ -642,7 +656,7 @@ class NeuralMemory(Module):
         # fetch values from memory model

         if dict_get_shape(curr_weights) != self.init_weight_shape:
-            curr_weights = curr_weights
+            curr_weights = rearrange_dict_values(curr_weights, 'b n ... -> (b n) ...')

         queries = rearrange(queries, 'b h (n c) d -> (b h n) c d', c = chunk_size)

@@ -669,8 +683,7 @@ class NeuralMemory(Module):

         # restore, pad with empty memory embed

-
-        values = values[:, 1:(seq_len + 1)]
+        values = values[:, 1:(seq_len + 1)]

         return values

@@ -679,15 +692,14 @@ class NeuralMemory(Module):
         self,
         token: Tensor,
         state = None,
-        prev_layer_updates: dict[str, Tensor] | None = None,
     ):

         # unpack previous state

         if not exists(state):
-            state = (0, None, None, None)
+            state = (0, None, None, None, None)

-        seq_index, cache_store_seq, past_states, updates = state
+        seq_index, weights, cache_store_seq, past_states, updates = state

         curr_seq_len = seq_index + 1
         batch = token.shape[0]
@@ -695,10 +707,6 @@ class NeuralMemory(Module):
         if token.ndim == 2:
             token = rearrange(token, 'b d -> b 1 d')

-        # get memory model weights
-
-        weights = self.init_weights()
-
         # increment the sequence cache which is at most the chunk size

         cache_store_seq = safe_cat((cache_store_seq, token), dim = -2)
@@ -708,7 +716,7 @@ class NeuralMemory(Module):
         if curr_seq_len < self.chunk_size:
             retrieve = self.retrieve_memories(token, weights, chunk_size = 1)

-            output = retrieve, NeuralMemCache(curr_seq_len, cache_store_seq, past_states, updates)
+            output = retrieve, NeuralMemCache(curr_seq_len, weights, cache_store_seq, past_states, updates)

             return output

@@ -719,21 +727,16 @@ class NeuralMemory(Module):

         if not exists(updates):
             updates = weights.clone().zero_()
-            updates = updates
+            updates = repeat_dict_values(updates, '... -> b 1 ...', b = batch)
         else:
             updates = updates.apply(lambda t: t[:, -1:])

-        if exists(prev_layer_updates):
-            prev_layer_updates = TensorDict(prev_layer_updates)
-            prev_layer_updates = prev_layer_updates.apply(lambda t: t[:, -1:])
-
         if store_seq_cache_len == self.chunk_size:

             next_updates, store_state = self.store_memories(
                 cache_store_seq,
                 weights,
                 past_state = past_states,
-                prev_layer_updates = prev_layer_updates,
             )

             updates = next_updates
@@ -742,11 +745,11 @@ class NeuralMemory(Module):

         # retrieve

-        retrieved = self.retrieve_memories(token,
+        retrieved = self.retrieve_memories(token, updates, chunk_size = 1)

         # next state tuple

-        next_store_state = NeuralMemCache(curr_seq_len, cache_store_seq, next_states, updates)
+        next_store_state = NeuralMemCache(curr_seq_len, weights, cache_store_seq, next_states, updates)

         return retrieved, next_store_state

@@ -756,62 +759,30 @@ class NeuralMemory(Module):
         store_seq = None,
         mem_model_weights: dict[str, Tensor] | None = None,
         past_state: tuple[dict[str, Tensor], dict[str, Tensor]] | None = None,
-        return_aux_kv_loss = False,
         chunk_size = None,
         store_chunk_size = None,
         return_next_state = False,
-        prev_layer_updates: dict[str, Tensor] | None = None
     ):
         batch, seq_len = seq.shape[:2]

-        if not exists(mem_model_weights):
-            mem_model_weights = self.init_weights()
-
-        if seq_len < self.retrieve_chunk_size:
-            retrieved = self.retrieve_memories(seq, mem_model_weights, chunk_size = 1)
-
-            next_store_state = NeuralMemCache(seq_len, seq, None, None)
-
-            out = (retrieved, next_store_state)
-
-            if not return_aux_kv_loss:
-                return out
-
-            return out, self.zero
-
         # store

         store_seq = default(store_seq, seq)

-        (updates, next_store_state), aux_kv_recon_loss = self.store_memories(
+        updates, next_store_state = self.store_memories(
             store_seq,
             mem_model_weights,
             chunk_size = store_chunk_size,
-            prev_layer_updates = prev_layer_updates,
-            return_aux_kv_loss = True
         )

         # retrieve

-        retrieve_chunk_size = default(chunk_size, self.retrieve_chunk_size)
-
-        if retrieve_chunk_size != 1:
-            if exists(prev_layer_updates):
-                prev_layer_updates = prev_layer_updates.apply(lambda t: pad_at_dim(t, (1, 0), dim = 1))
-
-            updates = updates.apply(lambda t: pad_at_dim(t, (1, 0), dim = 1))
-
-
         retrieved = self.retrieve_memories(
             seq,
-
+            updates,
             chunk_size = chunk_size,
-            prev_layer_updates = prev_layer_updates
         )

         output = (retrieved, next_store_state)

-        if not return_aux_kv_loss:
-            return output
-
-        return output, aux_kv_recon_loss
+        return output
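Taken together, NeuralMemory.forward in 0.2.10 always returns the pair (retrieved, next_store_state); the return_aux_kv_loss and prev_layer_updates paths are gone. A usage sketch along the lines of the project README, with illustrative shapes:

import torch
from titans_pytorch import NeuralMemory

mem = NeuralMemory(
    dim = 384,
    chunk_size = 64,
)

seq = torch.randn(2, 1024, 384)

retrieved, next_store_state = mem(seq)   # no auxiliary loss in the return anymore

assert retrieved.shape == seq.shape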
titans_pytorch-0.2.10.dist-info/RECORD ADDED
@@ -0,0 +1,9 @@
+titans_pytorch/__init__.py,sha256=Y3m_ZlpEqYwp-Md1ARhNGJxq8bQp8ty1o039nZOOJo0,276
+titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+titans_pytorch/mac_transformer.py,sha256=dmS37yBN0j9OqoMCsojuIPfT1EXLN8ackRdZwPb8xDY,24463
+titans_pytorch/memory_models.py,sha256=Q9SAIyAbStF5Tz0EhvRbn3yAdE3nk3xKc1ndieIe714,4671
+titans_pytorch/neural_memory.py,sha256=kc-cV7dK3WhdqRfOCrPW91nA0F56jUK94TE1irckQ34,23487
+titans_pytorch-0.2.10.dist-info/METADATA,sha256=k7u9eQDNAWG3QqzGqhcdN21D6LYWWRWdd5wZFb560q0,6812
+titans_pytorch-0.2.10.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.2.10.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.2.10.dist-info/RECORD,,
titans_pytorch-0.2.8.dist-info/RECORD DELETED
@@ -1,9 +0,0 @@
-titans_pytorch/__init__.py,sha256=Y3m_ZlpEqYwp-Md1ARhNGJxq8bQp8ty1o039nZOOJo0,276
-titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/mac_transformer.py,sha256=UOJAMv7nTgkefBB7M7K3U0NnFkz75tFRG5WLXRdfnLw,26039
-titans_pytorch/memory_models.py,sha256=Q9SAIyAbStF5Tz0EhvRbn3yAdE3nk3xKc1ndieIe714,4671
-titans_pytorch/neural_memory.py,sha256=WAeR-nOpy1XbBP590By1-tCgirulqPbFGut4H1B77-g,24910
-titans_pytorch-0.2.8.dist-info/METADATA,sha256=4fLUv34KqloeYMWjHBUmp-3iEw0Xq47fjRrwlkyTEsM,6811
-titans_pytorch-0.2.8.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.2.8.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.2.8.dist-info/RECORD,,
{titans_pytorch-0.2.8.dist-info → titans_pytorch-0.2.10.dist-info}/WHEEL: file without changes
{titans_pytorch-0.2.8.dist-info → titans_pytorch-0.2.10.dist-info}/licenses/LICENSE: file without changes