titans-pytorch 0.1.22__tar.gz → 0.1.26__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.1.22
+Version: 0.1.26
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -78,7 +78,7 @@ from titans_pytorch import NeuralMemory
 
 mem = NeuralMemory(
     dim = 384,
-    chunk_size = 64
+    chunk_size = 64 # set to smaller chunk size for better perf on smaller sequence lengths (but more memory usage)
 ).cuda()
 
 seq = torch.randn(2, 1024, 384).cuda()
@@ -24,7 +24,7 @@ from titans_pytorch import NeuralMemory
 
 mem = NeuralMemory(
     dim = 384,
-    chunk_size = 64
+    chunk_size = 64 # set to smaller chunk size for better perf on smaller sequence lengths (but more memory usage)
 ).cuda()
 
 seq = torch.randn(2, 1024, 384).cuda()
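
The comment added in the two identical README hunks above describes a tradeoff: a smaller chunk_size folds tokens into the neural memory weights more often, which helps on short sequences at the cost of more update steps and memory. A minimal sketch of dialing it down (the value 16 is illustrative, not from this release):

import torch
from titans_pytorch import NeuralMemory

# smaller chunks -> more frequent neural memory weight updates,
# better suited to short sequences, but more compute / memory per token
mem = NeuralMemory(
    dim = 384,
    chunk_size = 16 # illustrative; the README example keeps 64
).cuda()

seq = torch.randn(2, 1024, 384).cuda()
retrieved = mem(seq)

assert retrieved.shape == seq.shape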
@@ -1,6 +1,6 @@
 [project]
 name = "titans-pytorch"
-version = "0.1.22"
+version = "0.1.26"
 description = "Titans"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -1,3 +1,5 @@
+from contextlib import contextmanager
+
 import torch
 from torch import nn
 
@@ -5,12 +7,25 @@ import pytest
 from titans_pytorch import NeuralMemory
 from titans_pytorch.mac_transformer import flex_attention, SegmentedAttention, MemoryAsContextTransformer
 
+# functions
+
 def exists(v):
     return v is not None
 
+def diff(x, y):
+    return (x - y).abs().amax()
+
+@contextmanager
+def torch_default_dtype(dtype):
+    prev_dtype = torch.get_default_dtype()
+    torch.set_default_dtype(dtype)
+    yield
+    torch.set_default_dtype(prev_dtype)
+
+# main test
+
 @pytest.mark.parametrize('seq_len', (32, 1024, 77))
 @pytest.mark.parametrize('silu', (False, True))
-@pytest.mark.parametrize('learned_mem_model_weights', (False, True))
 @pytest.mark.parametrize('attn_pool_chunks', (False, True))
 @pytest.mark.parametrize('momentum', (False, True))
 @pytest.mark.parametrize('qk_rmsnorm', (False, True))
@@ -19,7 +34,6 @@ def exists(v):
 def test_titans(
     seq_len,
     silu,
-    learned_mem_model_weights,
     attn_pool_chunks,
     momentum,
     qk_rmsnorm,
@@ -35,7 +49,6 @@ def test_titans(
         momentum = momentum,
         qk_rmsnorm = qk_rmsnorm,
         per_parameter_lr_modulation = per_parameter_lr_modulation,
-        learned_mem_model_weights = learned_mem_model_weights
     )
 
     seq = torch.randn(2, seq_len, 384)
@@ -111,7 +124,11 @@ def test_mac(
     assert logits.shape == (1, seq_len, 256)
 
 @pytest.mark.parametrize('sliding', (False, True))
-def test_mac_sampling(sliding):
+@pytest.mark.parametrize('mem_layers', ((), None, (4,)))
+def test_mac_sampling(
+    sliding,
+    mem_layers
+):
     transformer = MemoryAsContextTransformer(
         num_tokens = 256,
         dim = 256,
@@ -120,7 +137,7 @@ def test_mac_sampling(sliding):
         num_persist_mem_tokens = 4,
         num_longterm_mem_tokens = 0,
         sliding_window_attn = sliding,
-        neural_memory_layers = (),
+        neural_memory_layers = mem_layers,
         neural_mem_gate_attn_output = False
     )
@@ -133,6 +150,38 @@ def test_mac_sampling(sliding):
 
     assert torch.allclose(sampled, sampled_with_cache)
 
+@pytest.mark.parametrize('seq_len', (2, 64, 256))
+@torch_default_dtype(torch.float64)
+def test_neural_mem_inference(
+    seq_len
+):
+    mem = NeuralMemory(
+        dim = 384,
+        chunk_size = 64,
+    )
+
+    seq = torch.randn(2, seq_len, 384)
+    parallel_retrieved = mem(seq)
+
+    assert seq.shape == parallel_retrieved.shape
+
+    state = None
+    sequential_retrieved = []
+
+    for ind, token in enumerate(seq.unbind(dim = 1)):
+
+        one_retrieved, state = mem.forward_inference(
+            token,
+            seq_index = ind,
+            state = state,
+        )
+
+        sequential_retrieved.append(one_retrieved)
+
+    sequential_retrieved = torch.cat(sequential_retrieved, dim = -2)
+
+    assert torch.allclose(parallel_retrieved, sequential_retrieved, atol = 1e-6)
+
 @pytest.mark.parametrize('seq_len', (1023, 17))
 @pytest.mark.parametrize('sliding', (True, False))
 def test_flex(
@@ -157,3 +206,28 @@ def test_flex(
     out_non_flex, _ = attn(seq, disable_flex_attn = True)
 
     assert torch.allclose(out_flex, out_non_flex, atol = 1e-5)
+
+@torch_default_dtype(torch.float64)
+def test_assoc_scan():
+    from titans_pytorch.titans import AssocScan
+    torch.set_default_dtype(torch.float64)
+
+    scan = AssocScan()
+
+    seq_len = 128
+    mid_point = seq_len // 2
+
+    gates = torch.randn(2, seq_len, 512).sigmoid()
+    inputs = torch.randn(2, seq_len, 512)
+
+    output = scan(gates, inputs)
+
+    gates1, gates2 = gates[:, :mid_point], gates[:, mid_point:]
+    inputs1, inputs2 = inputs[:, :mid_point], inputs[:, mid_point:]
+
+    first_half = scan(gates1, inputs1)
+
+    second_half = scan(gates2, inputs2, prev = first_half[:, -1])
+    assert second_half.shape == inputs2.shape
+
+    assert torch.allclose(output[:, -1], second_half[:, -1], atol = 1e-6)
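
A note on what test_assoc_scan (and the new prev argument) exercises: AssocScan parallelizes the first-order linear recurrence out[t] = gates[t] * out[t-1] + inputs[t], and prev seeds out[-1] so a sequence can be scanned in segments with identical results. A naive sequential reference (my sketch, assuming that standard recurrence; not part of the package):

import torch

def assoc_scan_reference(gates, inputs, prev = None):
    # out[t] = gates[t] * out[t-1] + inputs[t], with out[-1] = prev (or zeros)
    out = torch.zeros_like(inputs[:, 0]) if prev is None else prev
    outs = []
    for t in range(inputs.shape[1]):
        out = gates[:, t] * out + inputs[:, t]
        outs.append(out)
    return torch.stack(outs, dim = 1)

gates = torch.randn(2, 128, 512).sigmoid()
inputs = torch.randn(2, 128, 512)

full = assoc_scan_reference(gates, inputs)

# continuing from the midpoint state reproduces the full scan,
# mirroring the split / continue assertions in test_assoc_scan
second = assoc_scan_reference(gates[:, 64:], inputs[:, 64:], prev = full[:, 63])
assert torch.allclose(full[:, 64:], second)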
@@ -510,10 +510,7 @@ class MemoryAsContextTransformer(Module):
 
         layers = tuple(range(1, depth + 1))
 
-        if not exists(neural_memory_layers):
-            neural_memory_layers = layers if has_longterm_mems else ()
-
-        assert not (num_longterm_mem_tokens > 0 and len(neural_memory_layers) == 0), 'empty `neural_memory_layers` when longterm memory tokens are present'
+        neural_memory_layers = default(neural_memory_layers, layers)
 
         # mem, attn, and feedforward layers
 
@@ -535,9 +532,10 @@ class MemoryAsContextTransformer(Module):
             )
 
             mem = None
+            mem_hyper_conn = None
 
             if layer in neural_memory_layers:
-                assert has_longterm_mems, '`num_longterm_mem_tokens` must be greater than 0'
+                mem_hyper_conn = init_hyper_conn(dim = dim, add_branch_out_to_residual = not neural_mem_gate_attn_output)
 
                 mem = NeuralMemory(
                     dim = dim,
@@ -545,10 +543,12 @@ class MemoryAsContextTransformer(Module):
                     **neural_memory_kwargs
                 )
 
+
             ff = FeedForward(dim = dim, mult = ff_mult)
 
             self.layers.append(ModuleList([
-                init_hyper_conn(dim = dim, branch = mem, add_branch_out_to_residual = not neural_mem_gate_attn_output) if exists(mem) else None,
+                mem_hyper_conn,
+                mem,
                 init_hyper_conn(dim = dim, branch = attn),
                 init_hyper_conn(dim = dim, branch = ff)
             ]))
@@ -691,8 +691,18 @@ class MemoryAsContextTransformer(Module):
         # kv caching
 
         is_inferencing = exists(cache)
-        cache = iter(default(cache, []))
+        assert not (is_inferencing and self.num_longterm_mem_tokens > 0)
+
+        if not exists(cache):
+            cache = (None, None)
+
+        kv_caches, neural_mem_caches = cache
+
+        kv_caches = iter(default(kv_caches, []))
+        neural_mem_caches = iter(default(neural_mem_caches, []))
+
         next_kv_caches = []
+        next_neural_mem_caches = []
 
         # value residual
 
@@ -711,21 +721,37 @@ class MemoryAsContextTransformer(Module):
 
         x = self.expand_streams(x)
 
-        for mem, attn, ff in self.layers:
+        for mem_hyper_conn, mem, attn, ff in self.layers:
 
             retrieved = None
             attn_out_gates = None
+            next_neural_mem_cache = None
 
             # maybe neural memory
 
             if exists(mem):
-                retrieved, mem_kv_aux_loss = mem(x, return_aux_kv_loss = True)
-                kv_recon_losses = kv_recon_losses + mem_kv_aux_loss
+
+                mem_input, add_residual = mem_hyper_conn(x)
+
+                if not is_inferencing:
+                    retrieved, mem_kv_aux_loss = mem(
+                        mem_input,
+                        return_aux_kv_loss = True
+                    )
+
+                    kv_recon_losses = kv_recon_losses + mem_kv_aux_loss
+
+                else:
+                    retrieved, next_neural_mem_cache = mem.forward_inference(
+                        mem_input,
+                        seq_index = seq_len - 1,
+                        state = next(neural_mem_caches, None)
+                    )
 
                 if self.gate_attn_output:
                     attn_out_gates = retrieved.sigmoid()
                 else:
-                    seq = retrieved
+                    x = add_residual(retrieved)
 
             # attention
 
@@ -735,12 +761,15 @@ class MemoryAsContextTransformer(Module):
                 disable_flex_attn = disable_flex_attn,
                 flex_attn_fn = flex_attn_fn,
                 output_gating = attn_out_gates,
-                cache = next(cache, None)
+                cache = next(kv_caches, None)
             )
 
             value_residual = default(value_residual, values)
 
+            # caches
+
             next_kv_caches.append(next_kv_cache)
+            next_neural_mem_caches.append(next_neural_mem_cache)
 
             # feedforward
 
@@ -775,7 +804,7 @@ class MemoryAsContextTransformer(Module):
             if not self.sliding_window_attn and divisible_by(seq_len_with_mem, attn_window_size):
                 next_kv_caches = next_kv_caches[..., 0:0, :]
 
-            return logits, next_kv_caches
+            return logits, (next_kv_caches, next_neural_mem_caches)
 
         ar_loss = F.cross_entropy(rearrange(logits, 'b n l -> b l n'), labels)
 
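
With this change the cache threaded through MemoryAsContextTransformer.forward during decoding is a 2-tuple of per-layer attention kv caches and per-layer neural memory states, rather than a flat list of kv caches. A hypothetical greedy decode loop (the call convention is assumed from the return statement above, not confirmed by this diff; `transformer` is one constructed as in test_mac_sampling):

import torch

prompt = torch.randint(0, 256, (1, 4))

out = prompt
cache = None # becomes (kv_caches, neural_mem_caches) after the first forward

for _ in range(32):
    logits, cache = transformer(out, cache = cache) # assumed signature
    next_token = logits[:, -1].argmax(dim = -1, keepdim = True)
    out = torch.cat((out, next_token), dim = -1)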
@@ -44,6 +44,16 @@ def default(v, d):
 def xnor(x, y):
     return not (x ^ y)
 
+def safe_cat(inputs, dim = -2):
+    inputs = tuple(filter(exists, inputs))
+
+    if len(inputs) == 0:
+        return None
+    elif len(inputs) == 1:
+        return inputs[0]
+
+    return cat(inputs, dim = dim)
+
 def identity(t):
     return t
 
@@ -314,11 +324,26 @@ class AssocScan(Module):
         super().__init__()
         self.use_accelerated = use_accelerated
 
-    def forward(self, gates, inputs):
+    def forward(
+        self,
+        gates,
+        inputs,
+        prev = None,
+        remove_prev = None
+    ):
+        remove_prev = default(remove_prev, exists(prev))
+
+        if exists(prev):
+            inputs, _ = pack([prev, inputs], 'b * d')
+            gates = pad_at_dim(gates, (1, 0), value = 1., dim = -2)
 
         if not self.use_accelerated:
-            _, outputs = associative_scan(binary_operator, (gates, inputs))
-            return outputs
+            _, out = associative_scan(binary_operator, (gates, inputs))
+
+            if remove_prev:
+                out = out[:, 1:]
+
+            return out
 
         from accelerated_scan.triton import scan as triton_scan
         from accelerated_scan.warp import scan as warp_scan
@@ -341,7 +366,12 @@ class AssocScan(Module):
             outputs = rearrange(outputs, 'b d n -> b n d')
             return outputs
 
-        return accelerate_scan_fn(gates, inputs)
+        out = accelerate_scan_fn(gates, inputs)
+
+        if remove_prev:
+            out = out[:, 1:]
+
+        return out
 
 # main neural memory
 
@@ -370,7 +400,6 @@ class NeuralMemory(Module):
         post_rmsnorm = True,
         qk_rmsnorm = False,
         accept_value_residual = False,
-        learned_mem_model_weights = True,
         max_grad_norm: float | None = None,
         use_accelerated_scan = False,
         activation: Module | None = None,
@@ -418,9 +447,6 @@ class NeuralMemory(Module):
         if not exists(model):
             model = MemoryMLP(dim_head, **default_model_kwargs)
 
-        if not learned_mem_model_weights:
-            model.requires_grad_(False)
-
         assert not exists(next(model.buffers(), None)), 'model cannot have buffers for now'
 
         # the memory is the weights of the model
@@ -522,16 +548,9 @@ class NeuralMemory(Module):
 
         self.register_buffer('zero', torch.tensor(0.), persistent = False)
 
-    def init_weights_and_momentum(self, zero_weights = False):
-        params = TensorDict(dict(self.memory_model.named_parameters()))
-
-        init_weights = params
-        init_momentum = params.clone().zero_()
-
-        if zero_weights:
-            init_weights = params.clone().zero_()
-
-        return init_weights, init_momentum
+    def init_weights(self):
+        weights = TensorDict(dict(self.memory_model.named_parameters()))
+        return weights
 
     def init_empty_memory_embed(self, batch, seq_len):
         return repeat(self.empty_memory_embed, 'd -> b n d', b = batch, n = seq_len)
@@ -539,7 +558,8 @@ class NeuralMemory(Module):
     def store_memories(
         self,
         seq,
-        past_state: tuple[dict[str, Tensor], dict[str, Tensor]],
+        weights: dict[str, Tensor],
+        past_state: tuple[dict[str, Tensor], dict[str, Tensor]] | None = None,
         return_aux_kv_loss = False,
         chunk_size = None,
         value_residual = None
@@ -551,8 +571,7 @@ class NeuralMemory(Module):
         # handle edge case
 
         if seq_len < chunk_size:
-            past_weight, _ = past_state
-            return TensorDict(past_weight).clone().zero_(), self.zero
+            return TensorDict(weights).clone().zero_(), self.zero
 
         seq = self.store_norm(seq)
 
@@ -563,10 +582,9 @@ class NeuralMemory(Module):
 
         seq = seq[:, :round_down_seq_len]
 
-        # get the weights of the memory network
+        # weights of the memory network
 
-        past_state = tuple(TensorDict(d) for d in past_state)
-        curr_weights, past_momentum = past_state
+        weights = TensorDict(weights)
 
         # derive learned hparams for optimization of memory network
 
@@ -616,7 +634,7 @@ class NeuralMemory(Module):
 
         # get grads and extra auxiliary loss (for backwarding through qkv projection in base neural memory module)
 
-        grads, aux_kv_recon_loss = self.per_sample_grad_fn(dict(curr_weights), keys, adaptive_lr, values)
+        grads, aux_kv_recon_loss = self.per_sample_grad_fn(dict(weights), keys, adaptive_lr, values)
 
         grads = TensorDict(grads)
 
@@ -638,12 +656,23 @@ class NeuralMemory(Module):
 
         surprises = grads.apply(lambda t: -t)
 
+        # past states
+
+        if not exists(past_state):
+            empty_dict = {key: None for key in weights.keys()}
+            past_state = (empty_dict, empty_dict)
+
+        past_last_update, past_last_momentum = past_state
+
         # momentum + weight decay - momentum is the new contribution, as most linear RNNs have learned forgetting gates
 
         next_momentum = TensorDict() if has_momentum else None
         updates = TensorDict()
 
-        for param_name, surprise in surprises.items():
+        next_last_update = TensorDict()
+        next_last_momentum = TensorDict()
+
+        for (param_name, surprise), (_, last_update), (_, last_momentum) in zip(surprises.items(), past_last_update.items(), past_last_momentum.items()):
 
             surprise, inverse_pack = pack_one_with_inverse(surprise, 'b n *')
 
@@ -652,23 +681,27 @@ class NeuralMemory(Module):
             # derive momentum with associative scan - eq (10)
 
             if has_momentum:
-                update = self.assoc_scan(adaptive_momentum, surprise) # momentum is S / surprise in the paper
+                update = self.assoc_scan(adaptive_momentum, surprise, prev = last_momentum) # momentum is S / surprise in the paper
                 momentum = update
+                next_last_momentum[param_name] = momentum[:, -1]
 
             # use associative scan again for learned forgetting (weight decay) - eq (13)
 
-            update = self.assoc_scan(1. - decay_factor, update)
+            update = self.assoc_scan(1. - decay_factor, update, prev = last_update)
+            next_last_update[param_name] = update[:, -1]
 
             updates[param_name] = inverse_pack(update)
 
             if has_momentum:
                 next_momentum[param_name] = inverse_pack(momentum)
 
-        # compute the next weight per batch
+        # compute next states for inference, or titans-xl like training
 
-        last_update = updates.apply(lambda t: t[:, -1])
+        next_state = (next_last_update, next_last_momentum)
 
-        output = (updates, orig_values)
+        # returns
+
+        output = (updates, next_state, orig_values)
 
         if not return_aux_kv_loss:
             return output
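
The two chained scans realize the paper's update rule: eq (10) is the momentum recurrence S_t = eta_t * S_{t-1} + u_t (u_t being the surprise, i.e. the negated gradient), and eq (13) is the forgetting / weight-decay recurrence W_t = (1 - alpha_t) * W_{t-1} + S_t. The new next_last_update / next_last_momentum dicts simply carry W and S at a chunk's final position forward, so the next chunk (or the next inference step) continues the recurrence instead of restarting from zero. A per-parameter reference sketch of one chunk (mine, with flattened (batch, n, dim) shapes assumed for readability):

import torch

def chunk_update_reference(surprise, adaptive_momentum, decay_factor, last_update = None, last_momentum = None):
    # surprise: (b, n, d) negated gradients; gates: (b, n, 1), broadcastable
    S = torch.zeros_like(surprise[:, 0]) if last_momentum is None else last_momentum
    W = torch.zeros_like(surprise[:, 0]) if last_update is None else last_update
    updates = []
    for t in range(surprise.shape[1]):
        S = adaptive_momentum[:, t] * S + surprise[:, t]  # eq (10), momentum
        W = (1. - decay_factor[:, t]) * W + S             # eq (13), forgetting
        updates.append(W)
    # (W, S) at the last position is what gets threaded to the next chunk
    return torch.stack(updates, dim = 1), (W, S)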
@@ -678,7 +711,7 @@ class NeuralMemory(Module):
     def retrieve_memories(
         self,
         seq,
-        past_weights: dict[str, Tensor] | None = None,
+        past_weights: dict[str, Tensor],
         chunk_size = None
     ):
        chunk_size = default(chunk_size, self.retrieve_chunk_size)
@@ -700,13 +733,7 @@ class NeuralMemory(Module):
         # the parameters of the memory model stores the memories of the key / values
         # when the MLP has only 1 weight matrix, it is equivalent to `kv` fast weight memories from linear attention literature (recall fetching of memories is q @ (kv)) / schmidhuber's paper
 
-        curr_weights = TensorDict(dict(self.memory_model.named_parameters()))
-
-        if exists(past_weights):
-            past_weights = TensorDict(past_weights)
-            assert past_weights.keys() == curr_weights.keys()
-
-            curr_weights = curr_weights + past_weights
+        curr_weights = TensorDict(past_weights)
 
         # sequence Float['b n d'] to queries
 
@@ -753,10 +780,77 @@ class NeuralMemory(Module):
 
         return values[:, :seq_len]
 
+    def forward_inference(
+        self,
+        token: Tensor,
+        seq_index = None, # the index of the token in the sequence, starts at 0
+        state = None,
+    ):
+
+        # unpack previous state
+
+        if not exists(state):
+            state = (None, None, None)
+
+        cache_store_seq, past_states, updates = state
+
+        seq_index = default(seq_index, 0)
+        curr_seq_len = seq_index + 1
+        batch = token.shape[0]
+
+        if token.ndim == 2:
+            token = rearrange(token, 'b d -> b 1 d')
+
+        # get memory model weights
+
+        weights = self.init_weights()
+
+        # increment the sequence cache which is at most the chunk size
+
+        cache_store_seq = safe_cat((cache_store_seq, token), dim = -2)
+
+        # early return empty memory, when no memories are stored for steps < first chunk size
+
+        if curr_seq_len < self.chunk_size:
+            empty_mem = self.init_empty_memory_embed(batch, 1)
+
+            return empty_mem, (cache_store_seq, past_states, updates)
+
+        # store if storage sequence cache hits the chunk size
+
+        next_states = past_states
+        store_seq_cache_len = cache_store_seq.shape[-2]
+
+        if not exists(updates):
+            updates = weights.clone().zero_()
+            updates = updates.apply(lambda t: repeat(t, '... -> b 1 ...', b = batch))
+
+        if store_seq_cache_len == self.chunk_size:
+
+            next_updates, next_states, _ = self.store_memories(
+                cache_store_seq,
+                weights,
+                past_state = past_states
+            )
+
+            updates = next_updates
+            cache_store_seq = None
+
+        # retrieve
+
+        retrieved = self.retrieve_memories(token, updates + weights, chunk_size = 1)
+
+        # next state tuple
+
+        next_state = (cache_store_seq, next_states, updates)
+
+        return retrieved, next_state
+
     def forward(
         self,
         seq,
         store_seq = None,
+        mem_model_weights: dict[str, Tensor] | None = None,
         past_state: tuple[dict[str, Tensor], dict[str, Tensor]] | None = None,
         return_aux_kv_loss = False,
         chunk_size = None,
@@ -773,20 +867,15 @@ class NeuralMemory(Module):
 
             return out, self.zero
 
-        if exists(past_state):
-            past_state = tuple(TensorDict(d) for d in past_state)
-
-        if not exists(past_state):
-            past_state = self.init_weights_and_momentum()
+        if not exists(mem_model_weights):
+            mem_model_weights = self.init_weights()
 
         store_seq = default(store_seq, seq)
         store_chunk_size = default(store_chunk_size, chunk_size)
 
-        (updates, values), aux_kv_recon_loss = self.store_memories(store_seq, past_state, chunk_size = store_chunk_size, return_aux_kv_loss = True)
-
-        past_weights, _ = past_state
+        (updates, next_state, values), aux_kv_recon_loss = self.store_memories(store_seq, mem_model_weights, chunk_size = store_chunk_size, return_aux_kv_loss = True)
 
-        retrieved = self.retrieve_memories(seq, past_weights + updates, chunk_size = chunk_size)
+        retrieved = self.retrieve_memories(seq, mem_model_weights + updates, chunk_size = chunk_size)
 
         output = retrieved
 
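
Taken together with the new test above: the state that forward_inference threads from step to step is the 3-tuple (cache_store_seq, past_states, updates), i.e. the pending tokens not yet folded into a chunk, the (last_update, last_momentum) scan states, and the accumulated per-parameter weight updates. Usage mirrors test_neural_mem_inference:

import torch
from titans_pytorch import NeuralMemory

mem = NeuralMemory(dim = 384, chunk_size = 64)

state = None # (cache_store_seq, past_states, updates)

for ind in range(256):
    token = torch.randn(2, 384) # one token per step: (batch, dim)

    # below the first full chunk this returns the learned empty-memory embedding;
    # once the cached tokens reach chunk_size they are stored and the cache resets
    retrieved, state = mem.forward_inference(token, seq_index = ind, state = state)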