titans-pytorch 0.2.9__py3-none-any.whl → 0.2.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
titans_pytorch/mac_transformer.py

@@ -481,6 +481,7 @@ class MemoryAsContextTransformer(Module):
  neural_memory_add_value_residual = False,
  num_longterm_mem_tokens = 0,
  num_persist_mem_tokens = 0,
+ neural_memory_batch_size = None,
  dim_head = 64,
  heads = 8,
  ff_mult = 4,
@@ -488,11 +489,8 @@ class MemoryAsContextTransformer(Module):
  neural_memory_model: Module | None = None,
  neural_memory_kwargs: dict = dict(),
  neural_memory_layers: tuple[int, ...] | None = None,
- aux_kv_recon_loss_weight = 1.,
  use_flex_attn = False,
  sliding_window_attn = False,
- weight_tie_memory_model = False,
- prev_neural_mem_update_for_weights = None
  ):
  super().__init__()
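Taken together, the two hunks above change the public constructor: `neural_memory_batch_size` is added, while `aux_kv_recon_loss_weight`, `weight_tie_memory_model` and `prev_neural_mem_update_for_weights` are removed. A minimal, hedged construction sketch against the 0.2.11 signature; the keyword arguments other than `neural_memory_batch_size` follow the upstream README and the sizes are purely illustrative:

    import torch
    from titans_pytorch import MemoryAsContextTransformer

    transformer = MemoryAsContextTransformer(
        num_tokens = 256,
        dim = 384,
        depth = 2,
        segment_len = 32,
        num_persist_mem_tokens = 4,
        num_longterm_mem_tokens = 4,
        neural_memory_batch_size = 64,  # new in 0.2.11: tokens stored before the memory weights are committed
    )

    ids = torch.randint(0, 256, (1, 512))  # token ids reused in the training sketch further down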
@@ -526,16 +524,6 @@ class MemoryAsContextTransformer(Module):

  neural_memory_layers = default(neural_memory_layers, layers)

- # weight tying neural memory model
-
- maybe_copy = deepcopy if not weight_tie_memory_model else identity
-
- if weight_tie_memory_model:
- assert exists(neural_memory_model), '`neural_memory_model` must be explicitly set'
-
- self.weight_tie_memory_model = weight_tie_memory_model
- self.prev_neural_mem_update_for_weights = default(prev_neural_mem_update_for_weights, weight_tie_memory_model)
-
  # mem, attn, and feedforward layers

  for layer in layers:
@@ -564,7 +552,8 @@ class MemoryAsContextTransformer(Module):
  mem = NeuralMemory(
  dim = dim,
  chunk_size = self.neural_memory_segment_len,
- model = maybe_copy(neural_memory_model),
+ batch_size = neural_memory_batch_size,
+ model = deepcopy(neural_memory_model),
  **neural_memory_kwargs
  )

@@ -585,10 +574,7 @@ class MemoryAsContextTransformer(Module):

  self.gate_attn_output = neural_mem_gate_attn_output

- # auxiliary loss on kv recon
-
- self.has_aux_kv_recon_loss = aux_kv_recon_loss_weight > 0.
- self.aux_kv_recon_loss_weight = aux_kv_recon_loss_weight
+ # zero for maybe aux loss + device

  self.register_buffer('zero', torch.tensor(0.), persistent = False)

@@ -696,7 +682,7 @@ class MemoryAsContextTransformer(Module):

  # math

- batch, seq_len, neural_mem_segment_len, segment_len, num_longterm_mem_tokens, attn_window_size, prev_neural_mem_update_for_weights = *x.shape, self.neural_memory_segment_len, self.segment_len, self.num_longterm_mem_tokens, self.attn_window_size, self.prev_neural_mem_update_for_weights
+ batch, seq_len, neural_mem_segment_len, segment_len, num_longterm_mem_tokens, attn_window_size = *x.shape, self.neural_memory_segment_len, self.segment_len, self.num_longterm_mem_tokens, self.attn_window_size

  seq_len_with_mem = self.seq_len_with_longterm_mem(seq_len)

@@ -749,18 +735,10 @@ class MemoryAsContextTransformer(Module):
  next_kv_caches = []
  next_neural_mem_caches = []

- # weight tied neural memory
-
- neural_memory_updates = None
-
  # value residual

  value_residual = None

- # aux losses
-
- kv_recon_losses = self.zero
-
  # when inferencing, only do one token at a time

  if is_inferencing:
@@ -784,24 +762,16 @@ class MemoryAsContextTransformer(Module):
  mem_input, add_residual = mem_hyper_conn(x)

  if not is_inferencing:
- (retrieved, next_neural_mem_cache), mem_kv_aux_loss = mem(
- mem_input,
- return_aux_kv_loss = True,
- prev_layer_updates = neural_memory_updates
+ retrieved, next_neural_mem_cache = mem(
+ mem_input
  )

- kv_recon_losses = kv_recon_losses + mem_kv_aux_loss
-
  else:
  (retrieved, next_neural_mem_cache) = mem.forward_inference(
  mem_input,
  state = next(neural_mem_caches, None),
- prev_layer_updates = neural_memory_updates
  )

- if prev_neural_mem_update_for_weights:
- neural_memory_updates = next_neural_mem_cache.updates
-
  if self.gate_attn_output:
  attn_out_gates = retrieved.sigmoid()
  else:
@@ -883,14 +853,4 @@ class MemoryAsContextTransformer(Module):

  return logits, next_cache

- ar_loss = F.cross_entropy(rearrange(logits, 'b n l -> b l n'), labels)
-
- losses = ar_loss
-
- if self.has_aux_kv_recon_loss:
- losses = losses + kv_recon_losses * self.aux_kv_recon_loss_weight
-
- if not return_loss_breakdown:
- return losses
-
- return losses, (ar_loss, kv_recon_losses)
+ return F.cross_entropy(rearrange(logits, 'b n l -> b l n'), labels)
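With the auxiliary loss machinery gone, the training path returns the plain autoregressive cross entropy as a single scalar rather than a loss/breakdown tuple. A hedged training-step sketch, reusing `transformer` and `ids` from the constructor sketch above and assuming the existing `return_loss = True` flag that leads into this branch:

    loss = transformer(ids, return_loss = True)  # single cross-entropy scalar in 0.2.11
    loss.backward()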
titans_pytorch/neural_memory.py

@@ -6,7 +6,7 @@ from functools import partial
  from collections import namedtuple

  import torch
- from torch import nn, cat, Tensor
+ from torch import nn, cat, tensor, Tensor
  import torch.nn.functional as F
  from torch.nn import Linear, Module, Parameter, ParameterList
  from torch.func import functional_call, vmap, grad
@@ -38,7 +38,13 @@ w - num memory network weight parameters

  LinearNoBias = partial(Linear, bias = False)

- NeuralMemCache = namedtuple('NeuralMemCache', ['seq', 'cache_store_segment', 'states', 'updates'])
+ NeuralMemCache = namedtuple('NeuralMemCache', [
+ 'seq_index',
+ 'weights',
+ 'cache_store_segment',
+ 'states',
+ 'updates',
+ ])

  # functions
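The cache namedtuple grows from four to five fields: `seq` becomes `seq_index` and the committed memory-model `weights` now travel with the state, so any caller that unpacked the old 4-tuple positionally must be updated. A small sketch of the new layout; per the hunks further down, the empty state that `forward` and `forward_inference` fall back to is `(0, None, None, None, None)`:

    from titans_pytorch.neural_memory import NeuralMemCache

    state = NeuralMemCache(
        seq_index = 0,              # number of tokens stored so far
        weights = None,             # committed memory-model weights (lazily initialized TensorDict)
        cache_store_segment = None, # tokens buffered until a full store chunk is available
        states = None,              # (last_update, last_momentum) carried between chunks
        updates = None,             # per-chunk weight updates used for retrieval
    )

    seq_index, weights, cache_store_seq, past_states, updates = state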
@@ -57,6 +63,9 @@ def identity(t):
  def xnor(x, y):
  return not (x ^ y)

+ def divisible_by(num, den):
+ return (num % den) == 0
+
  def safe_cat(inputs, dim = -2):
  inputs = tuple(filter(exists, inputs))

@@ -67,9 +76,18 @@ def safe_cat(inputs, dim = -2):

  return cat(inputs, dim = dim)

+ def is_empty_tensor(t):
+ return t.numel() == 0
+
  def dict_get_shape(td):
  return {k: v.shape for k, v in td.items()}

+ def rearrange_dict_values(td, pattern, **kwargs):
+ return td.apply(lambda t: rearrange(t, pattern, **kwargs))
+
+ def repeat_dict_values(td, pattern, **kwargs):
+ return td.apply(lambda t: repeat(t, pattern, **kwargs))
+
  def pair(v):
  return (v, v) if not isinstance(v, tuple) else v
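The two new `*_dict_values` helpers just map an einops pattern over every tensor held in a TensorDict. A standalone sketch of what `repeat_dict_values` does to a parameter dict (shapes are illustrative):

    import torch
    from einops import repeat
    from tensordict import TensorDict

    def repeat_dict_values(td, pattern, **kwargs):
        return td.apply(lambda t: repeat(t, pattern, **kwargs))

    params = TensorDict(dict(w1 = torch.randn(16, 16), w2 = torch.randn(16, 16)))
    batched = repeat_dict_values(params, '... -> bh ...', bh = 8)

    assert batched['w1'].shape == (8, 16, 16)  # every weight gains a leading batch * heads dimension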
@@ -106,6 +124,9 @@ def softclamp_max(t, max_value):
  return ((t / half_max_value).tanh() * half_max_value) + half_max_value

  def softclamp_grad_norm(t, max_value):
+ if is_empty_tensor(t):
+ return t
+
  t, inverse = pack_one_with_inverse(t, 'bn *')

  norm = t.norm(dim = -1, keepdim = True)
@@ -195,6 +216,12 @@ class AssocScan(Module):
  ):
  remove_prev = default(remove_prev, exists(prev))

+ inputs, inverse_pack_weight_shape = pack_one_with_inverse(inputs, 'b n *')
+ gates, _ = pack_one_with_inverse(gates, 'b n *')
+
+ if exists(prev):
+ prev, _ = pack_one_with_inverse(prev, 'b *')
+
  if exists(prev):
  inputs, _ = pack([prev, inputs], 'b * d')
  gates = pad_at_dim(gates, (1, 0), value = 1., dim = -2)
@@ -205,7 +232,7 @@ class AssocScan(Module):
  if remove_prev:
  out = out[:, 1:]

- return out
+ return inverse_pack_weight_shape(out)

  from accelerated_scan.triton import scan as triton_scan
  from accelerated_scan.warp import scan as warp_scan
@@ -226,6 +253,7 @@ class AssocScan(Module):

  outputs = outputs[..., :seq_len]
  outputs = rearrange(outputs, 'b d n -> b n d')
+
  return outputs

  out = accelerate_scan_fn(gates, inputs)
@@ -233,7 +261,7 @@ class AssocScan(Module):
  if remove_prev:
  out = out[:, 1:]

- return out
+ return inverse_pack_weight_shape(out)
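Both the reference branch and the accelerated branch now pack any trailing dimensions into a single feature axis before scanning and restore them on return, so the scan can run directly over matrix-shaped per-parameter updates instead of only `[batch, seq, dim]` tensors. A hedged sketch of the intended call shape, on the non-accelerated path and with gates shaped like the inputs:

    import torch
    from titans_pytorch.neural_memory import AssocScan

    scan = AssocScan(use_accelerated = False)

    gates  = torch.rand(2, 6, 16, 16)   # forgetting gates in [0, 1]
    inputs = torch.randn(2, 6, 16, 16)  # (batch, steps, *arbitrary trailing shape)

    out = scan(gates, inputs)           # trailing shape is packed for the scan, then restored
    assert out.shape == inputs.shape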
  # main neural memory

@@ -248,12 +276,13 @@ class NeuralMemory(Module):
  self,
  dim,
  chunk_size: int | tuple[int, int] = 1,
+ batch_size = None,
  dim_head = None,
  heads = 1,
  model: Module | None = None,
  store_memory_loss_fn: Callable = default_loss_fn,
  adaptive_step_transform: Callable | None = None,
- default_step_transform_max_lr = 1e-2,
+ default_step_transform_max_lr = 1.,
  per_parameter_lr_modulation = False, # allow outer network to control learning rate per weight matrix of memory network
  max_mem_layer_modulation = 1e1, # max of 10.
  attn_pool_chunks = False,
@@ -274,6 +303,13 @@ class NeuralMemory(Module):

  self.retrieve_chunk_size, self.store_chunk_size = pair(chunk_size)

+ # batch size
+
+ if exists(batch_size):
+ assert divisible_by(batch_size, self.store_chunk_size)
+
+ self.batch_size = batch_size
+
  # associative scan

  self.assoc_scan = AssocScan(use_accelerated = use_accelerated_scan)
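The standalone `NeuralMemory` module gains the same knob: `batch_size` sets how many stored tokens accumulate before the memory weights are committed, and the assertion above requires it to be a multiple of the store chunk size. A hedged construction sketch with illustrative sizes:

    import torch
    from titans_pytorch import NeuralMemory

    mem = NeuralMemory(
        dim = 384,
        chunk_size = 4,   # store / retrieve chunk size
        batch_size = 64,  # new in 0.2.11: commit weights every 64 stored tokens (a multiple of chunk_size)
    )

    seq = torch.randn(1, 512, 384)
    retrieved, state = mem(seq)  # retrieved memories plus a NeuralMemCache for the next call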
@@ -342,14 +378,13 @@ class NeuralMemory(Module):
  pred = functional_call(self.memory_model, params, inputs)
  loss = self.store_memory_loss_fn(pred, target) # simple mse loss in paper - eq (12) - |M(k) - v|²
  weighted_loss = loss * loss_weights
- return weighted_loss.sum(), weighted_loss.mean()
+ return weighted_loss.sum()

  # two functions

- grad_fn = grad(forward_and_loss, has_aux = True)
+ grad_fn = grad(forward_and_loss)

- self.per_sample_grad_fn = vmap(grad_fn, in_dims = (None, 0, 0, 0))
- self.per_sample_grad_fn_expanded_weights = vmap(grad_fn, in_dims = (0,) * 4)
+ self.per_sample_grad_fn = vmap(grad_fn, in_dims = (0, 0, 0, 0))
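The auxiliary output is dropped from the store loss, and the memory-model weights now always carry a leading batch dimension, so a single `vmap(grad(...), in_dims = (0, 0, 0, 0))` covers every case. A self-contained torch.func sketch of that per-sample-gradient pattern (toy model and shapes, not the library's internals):

    import torch
    from torch import nn
    from torch.func import functional_call, vmap, grad

    model = nn.Linear(8, 8, bias = False)

    def forward_and_loss(params, inputs, loss_weights, target):
        pred = functional_call(model, params, inputs)
        loss = (pred - target).pow(2).mean(dim = -1)  # stand-in for the store loss
        return (loss * loss_weights).sum()

    per_sample_grad_fn = vmap(grad(forward_and_loss), in_dims = (0, 0, 0, 0))

    b, n = 4, 16
    params = {'weight': torch.randn(b, 8, 8)}  # one weight copy per sample, matching the batched weights above
    inputs, target = torch.randn(b, n, 8), torch.randn(b, n, 8)
    loss_weights = torch.ones(b, n)

    grads = per_sample_grad_fn(params, inputs, loss_weights, target)
    assert grads['weight'].shape == (b, 8, 8)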
  # queries for retrieving from the model

@@ -417,56 +452,58 @@ class NeuralMemory(Module):

  self.register_buffer('zero', torch.tensor(0.), persistent = False)

- def init_weights(self):
+ def init_weights(
+ self,
+ batch,
+ ):
  weights = TensorDict(dict(self.memory_model.named_parameters()))
+ weights = repeat_dict_values(weights, '... -> bh ...', bh = batch * self.heads)
  return weights

+ def init_momentum(
+ self,
+ batch,
+ ):
+ weights = TensorDict(dict(self.memory_model.named_parameters()))
+ zeros = weights.clone().zero_()
+ zeros = repeat_dict_values(zeros, '... -> bh ...', bh = batch * self.heads)
+ return zeros
+
  def store_memories(
  self,
  seq,
- weights: dict[str, Tensor],
+ weights: dict[str, Tensor] | None = None,
  past_state: tuple[dict[str, Tensor], dict[str, Tensor]] | None = None,
- prev_layer_updates: dict[str, Tensor] | None = None,
- return_aux_kv_loss = False,
- chunk_size = None,
+ seq_index = 0
  ):
- seq_len, heads, chunk_size = seq.shape[-2], self.heads, default(chunk_size, self.store_chunk_size)
-
- # handle edge case
-
- if seq_len < chunk_size:
- return TensorDict(weights).clone().zero_(), self.zero
-
- seq = self.store_norm(seq)
+ batch, seq_len, heads, chunk_size = *seq.shape[:2], self.heads, self.store_chunk_size

  # curtail sequence by multiple of the chunk size
  # only a complete chunk of the sequence provides the memory for the next chunk

  round_down_seq_len = round_down_multiple(seq_len, chunk_size)
+ num_chunks = round_down_seq_len // chunk_size

  seq, remainder = seq[:, :round_down_seq_len], seq[:, round_down_seq_len:]

- # per sample grad function
-
- per_sample_grad_fn = self.per_sample_grad_fn
+ next_seq_len_index = seq_index + round_down_seq_len

+ # init weights if needed
  # weights of the memory network

+ if not exists(weights):
+ weights = self.init_weights(batch)
+
  weights = TensorDict(weights)

  # allow for neural memory of a previous layer to influence surprise of current layer

- weights_for_surprise = weights
-
- if exists(prev_layer_updates):
- prev_layer_updates = TensorDict(prev_layer_updates)
-
- weights_for_surprise = weights_for_surprise + prev_layer_updates
-
- per_sample_grad_fn = self.per_sample_grad_fn_expanded_weights # the weights will now have a batch * chunk dimension
+ weights_for_surprise = repeat_dict_values(weights, 'b ... -> b n ...', n = num_chunks)

  # derive learned hparams for optimization of memory network

+ seq = self.store_norm(seq)
+
  adaptive_lr = self.to_adaptive_step(seq)
  adaptive_lr = self.adaptive_step_transform(adaptive_lr)
@@ -474,7 +511,7 @@ class NeuralMemory(Module):

  decay_factor = self.to_decay_factor(chunked_seq).sigmoid()

- need_layer_lr_mod = exists(self.to_layer_modulation)
+ need_layer_lr_mod = exists(self.to_layer_modulation) and num_chunks > 0
  has_momentum = exists(self.to_momentum)

  if has_momentum:
@@ -505,12 +542,11 @@ class NeuralMemory(Module):

  # flatten batch and time if surprise depends on previous layer memory model

- if exists(prev_layer_updates):
- weights_for_surprise = weights_for_surprise.apply(lambda t: rearrange(t, 'b n ... -> (b n) ...'))
+ weights_for_surprise = rearrange_dict_values(weights_for_surprise, 'b n ... -> (b n) ...')

  # get grads and extra auxiliary loss (for backwarding through qkv projection in base neural memory module)

- grads, aux_kv_recon_loss = per_sample_grad_fn(dict(weights_for_surprise), keys, adaptive_lr, values)
+ grads = self.per_sample_grad_fn(dict(weights_for_surprise), keys, adaptive_lr, values)

  grads = TensorDict(grads)
@@ -521,7 +557,7 @@ class NeuralMemory(Module):

  # restore batch and sequence dimension

- grads = grads.apply(lambda t: rearrange(t, '(b n) ... -> b n ...', b = batch * heads))
+ grads = rearrange_dict_values(grads, '(b n) ... -> b n ...', b = batch * heads)

  # maybe per layer modulation
@@ -535,19 +571,25 @@ class NeuralMemory(Module):
  # past states

  if not exists(past_state):
- empty_dict = {key: None for key in weights.keys()}
-
  # minibatch_init_weight corresponds to W0 in figure 7 of TTT paper

  minibatch_init_weight = weights
+ init_momentum = self.init_momentum(batch)

- if dict_get_shape(weights) == self.init_weight_shape:
- minibatch_init_weight = weights.apply(lambda t: repeat(t, '... -> b 1 (...)', b = batch * heads))
-
- past_state = (minibatch_init_weight, empty_dict)
+ past_state = (minibatch_init_weight, init_momentum)

  past_last_update, past_last_momentum = past_state

+ # early return if sequence length less than chunk size
+
+ if num_chunks == 0:
+ updates = rearrange_dict_values(weights, 'bh ... -> bh 1 ...')
+ next_store_state = NeuralMemCache(next_seq_len_index, weights, remainder, past_state, updates)
+
+ output = (updates, next_store_state)
+
+ return output
+
  # momentum + weight decay - momentum is the new contribution, as most linear RNNs have learned forgetting gates

  next_momentum = TensorDict() if has_momentum else None
@@ -558,8 +600,6 @@ class NeuralMemory(Module):

  for (param_name, surprise), (_, last_update), (_, last_momentum) in zip(surprises.items(), past_last_update.items(), past_last_momentum.items()):

- surprise, inverse_pack = pack_one_with_inverse(surprise, 'b n *')
-
  update = surprise

  # derive momentum with associative scan - eq (10)
@@ -571,62 +611,51 @@ class NeuralMemory(Module):

  # use associative scan again for learned forgetting (weight decay) - eq (13)

- update = self.assoc_scan(1. - decay_factor, update, prev = last_update)
+ update = self.assoc_scan(1. - decay_factor, update, prev = last_update, remove_prev = False)
  next_last_update[param_name] = update[:, -1]

- updates[param_name] = inverse_pack(update)
+ updates[param_name] = update

  if has_momentum:
- next_momentum[param_name] = inverse_pack(momentum)
+ next_momentum[param_name] = momentum

  # determine next state for the storing of memories

  next_state = (next_last_update, next_last_momentum)

- next_store_state = NeuralMemCache(seq_len, remainder, next_state, updates)
+ next_store_state = NeuralMemCache(next_seq_len_index, weights, remainder, next_state, updates)

  # returns

  output = (updates, next_store_state)

- if not return_aux_kv_loss:
- return output
-
- return output, aux_kv_recon_loss.mean()
+ return output

  def retrieve_memories(
  self,
  seq,
  past_weights: dict[str, Tensor],
- chunk_size = None,
- prev_layer_updates: dict[str, Tensor] | None = None
  ):
- chunk_size = default(chunk_size, self.retrieve_chunk_size)
+ chunk_size = self.retrieve_chunk_size
  batch, seq_len = seq.shape[:2]

  seq = self.retrieve_norm(seq)

- assert seq_len >= chunk_size, 'must be handled outside of retrieve'
-
  needs_pad = chunk_size > 1

- if needs_pad:
- seq = pad_at_dim(seq, (1, 0), dim = 1)
- seq_len_plus_one = seq.shape[-2]
+ seq = pad_at_dim(seq, (1, 0), dim = 1)
+ seq_len_plus_one = seq.shape[-2]

- next_seq_len = round_up_multiple(seq_len_plus_one, chunk_size)
+ next_seq_len = round_up_multiple(seq_len_plus_one, chunk_size)

- padding = next_seq_len - seq_len_plus_one
- seq = pad_at_dim(seq, (0, padding), dim = 1)
+ padding = next_seq_len - seq_len_plus_one
+ seq = pad_at_dim(seq, (0, padding), dim = 1)

  # the parameters of the memory model stores the memories of the key / values
  # when the MLP has only 1 weight matrix, it is equivalent to `kv` fast weight memories from linear attention literature (recall fetching of memories is q @ (kv)) / schmidhuber's paper

  curr_weights = TensorDict(past_weights)

- if exists(prev_layer_updates):
- curr_weights = curr_weights + TensorDict(prev_layer_updates)
-
  # sequence Float['b n d'] to queries

  queries = self.to_queries(seq)
@@ -642,7 +671,7 @@ class NeuralMemory(Module):
  # fetch values from memory model

  if dict_get_shape(curr_weights) != self.init_weight_shape:
- curr_weights = curr_weights.apply(lambda t: rearrange(t, 'b n ... -> (b n) ...'))
+ curr_weights = rearrange_dict_values(curr_weights, 'b n ... -> (b n) ...')

  queries = rearrange(queries, 'b h (n c) d -> (b h n) c d', c = chunk_size)
@@ -669,8 +698,7 @@ class NeuralMemory(Module):

  # restore, pad with empty memory embed

- if needs_pad:
- values = values[:, 1:(seq_len + 1)]
+ values = values[:, 1:(seq_len + 1)]

  return values
@@ -678,16 +706,14 @@ class NeuralMemory(Module):
  def forward_inference(
  self,
  token: Tensor,
- state = None,
- prev_layer_updates: dict[str, Tensor] | None = None,
+ state: NeuralMemCache | None = None,
  ):
-
  # unpack previous state

  if not exists(state):
- state = (0, None, None, None)
+ state = (0, None, None, None, None)

- seq_index, cache_store_seq, past_states, updates = state
+ seq_index, weights, cache_store_seq, past_states, updates = state

  curr_seq_len = seq_index + 1
  batch = token.shape[0]
@@ -695,9 +721,7 @@ class NeuralMemory(Module):
  if token.ndim == 2:
  token = rearrange(token, 'b d -> b 1 d')

- # get memory model weights
-
- weights = self.init_weights()
+ assert token.shape[1] == 1

  # increment the sequence cache which is at most the chunk size
@@ -708,7 +732,7 @@ class NeuralMemory(Module):
  if curr_seq_len < self.chunk_size:
  retrieve = self.retrieve_memories(token, weights, chunk_size = 1)

- output = retrieve, NeuralMemCache(curr_seq_len, cache_store_seq, past_states, updates)
+ output = retrieve, NeuralMemCache(curr_seq_len, weights, cache_store_seq, past_states, updates)

  return output
@@ -719,21 +743,16 @@ class NeuralMemory(Module):

  if not exists(updates):
  updates = weights.clone().zero_()
- updates = updates.apply(lambda t: repeat(t, '... -> b 1 ...', b = batch))
+ updates = repeat_dict_values(updates, '... -> b 1 ...', b = batch)
  else:
  updates = updates.apply(lambda t: t[:, -1:])

- if exists(prev_layer_updates):
- prev_layer_updates = TensorDict(prev_layer_updates)
- prev_layer_updates = prev_layer_updates.apply(lambda t: t[:, -1:])
-
  if store_seq_cache_len == self.chunk_size:

  next_updates, store_state = self.store_memories(
  cache_store_seq,
  weights,
  past_state = past_states,
- prev_layer_updates = prev_layer_updates,
  )

  updates = next_updates
@@ -746,7 +765,7 @@ class NeuralMemory(Module):

  # next state tuple

- next_store_state = NeuralMemCache(curr_seq_len, cache_store_seq, next_states, updates)
+ next_store_state = NeuralMemCache(curr_seq_len, weights, cache_store_seq, next_states, updates)

  return retrieved, next_store_state
@@ -754,63 +773,99 @@ class NeuralMemory(Module):
  self,
  seq,
  store_seq = None,
- mem_model_weights: dict[str, Tensor] | None = None,
- past_state: tuple[dict[str, Tensor], dict[str, Tensor]] | None = None,
- return_aux_kv_loss = False,
- chunk_size = None,
- store_chunk_size = None,
- return_next_state = False,
- prev_layer_updates: dict[str, Tensor] | None = None
+ state: NeuralMemCache | None = None,
  ):
- batch, seq_len = seq.shape[:2]
+ if not exists(state):
+ state = (0, None, None, None, None)

- if not exists(mem_model_weights):
- mem_model_weights = self.init_weights()
+ seq_index, weights, cache_store_seq, past_state, updates = state

- if seq_len < self.retrieve_chunk_size:
- retrieved = self.retrieve_memories(seq, mem_model_weights, chunk_size = 1)
+ assert not exists(cache_store_seq) or is_empty_tensor(cache_store_seq)

- next_store_state = NeuralMemCache(seq_len, seq, None, None)
+ # store

- out = (retrieved, next_store_state)
+ store_seq = default(store_seq, seq)

- if not return_aux_kv_loss:
- return out
+ # functions

- return out, self.zero
+ # compute split sizes of sequence
+ # for now manually update weights to last update at the correct boundaries

- # store
+ store_seq_len, chunk_size, batch_size = store_seq.shape[-2], self.chunk_size, self.batch_size

- store_seq = default(store_seq, seq)
+ need_update_weights = exists(batch_size)

- (updates, next_store_state), aux_kv_recon_loss = self.store_memories(
- store_seq,
- mem_model_weights,
- chunk_size = store_chunk_size,
- prev_layer_updates = prev_layer_updates,
- return_aux_kv_loss = True
- )
+ # determine split sizes and when to update

- # retrieve
+ if need_update_weights:
+ update_after_final_store = divisible_by(seq_index + store_seq_len, batch_size)
+
+ seq_range = torch.arange(store_seq_len) + seq_index + 1
+ batch_boundary = divisible_by(seq_range, batch_size)
+
+ indices = seq_range[batch_boundary] - seq_index
+
+ indices = F.pad(indices, (1, 0), value = 0)
+
+ if indices[-1] != store_seq_len:
+ indices = F.pad(indices, (0, 1), value = store_seq_len)
+
+ split_sizes = (indices[1:] - indices[:-1]).tolist()
+
+ assert sum(split_sizes) == store_seq_len
+ else:
+ split_sizes = (store_seq_len,)
+ update_after_final_store = False

- retrieve_chunk_size = default(chunk_size, self.retrieve_chunk_size)
+ # accumulate updates

- if retrieve_chunk_size != 1:
- if exists(prev_layer_updates):
- prev_layer_updates = prev_layer_updates.apply(lambda t: pad_at_dim(t, (1, 0), dim = 1))
+ updates = None

- updates = updates.apply(lambda t: pad_at_dim(t, (1, 0), dim = 1))
+ def accum_updates(past_updates, future_updates):
+ if not exists(past_updates):
+ return future_updates
+
+ return TensorDict({param_name: cat((past_update[:, :-1], future_update), dim = 1) for (param_name, past_update), (_, future_update) in zip(past_updates.items(), future_updates.items())})
+
+ # loop through chunks of store sequences
+
+ store_seqs = store_seq.split(split_sizes, dim = -2)
+
+ for ind, store_seq_chunk in enumerate(store_seqs):
+ is_last = ind == (len(store_seqs) - 1)
+
+ # store
+
+ next_updates, next_neural_mem_state = self.store_memories(
+ store_seq_chunk,
+ weights,
+ seq_index = seq_index,
+ past_state = past_state,
+ )
+
+ seq_index = next_neural_mem_state.seq_index
+ past_state = next_neural_mem_state.states
+
+ updates = accum_updates(updates, next_updates)
+
+ if is_last and not update_after_final_store:
+ continue
+
+ # update weights once batch size is fulfilled
+
+ last_update, _ = past_state
+
+ weights = last_update
+
+ next_neural_mem_state = list(next_neural_mem_state)
+ next_neural_mem_state[1] = last_update
+ next_neural_mem_state = NeuralMemCache(*next_neural_mem_state)
+
+ # retrieve

  retrieved = self.retrieve_memories(
  seq,
- updates,
- chunk_size = chunk_size,
- prev_layer_updates = prev_layer_updates
+ updates
  )

- output = (retrieved, next_store_state)
-
- if not return_aux_kv_loss:
- return output
-
- return output, aux_kv_recon_loss
+ return retrieved, next_neural_mem_state
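The public `forward` now keeps all recurrent state inside the returned `NeuralMemCache`, committing weights whenever the stored token count crosses a `batch_size` boundary as it walks the split store sequence. A hedged sketch of chaining state across calls, continuing from the `mem` module constructed earlier (segment lengths kept at multiples of the chunk size so the store buffer stays empty between calls):

    seq = torch.randn(1, 1024, 384)
    first, second = seq[:, :512], seq[:, 512:]

    retrieved_1, state = mem(first)                  # state is a NeuralMemCache
    retrieved_2, state = mem(second, state = state)  # resumes at seq_index 512 with the committed weights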
titans_pytorch-0.2.9.dist-info/METADATA → titans_pytorch-0.2.11.dist-info/METADATA

@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: titans-pytorch
- Version: 0.2.9
+ Version: 0.2.11
  Summary: Titans
  Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
  Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
titans_pytorch-0.2.11.dist-info/RECORD (added)

@@ -0,0 +1,9 @@
+ titans_pytorch/__init__.py,sha256=Y3m_ZlpEqYwp-Md1ARhNGJxq8bQp8ty1o039nZOOJo0,276
+ titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+ titans_pytorch/mac_transformer.py,sha256=RfJ1SvQH5_4PmlB7g-13wPAqYtCCUJxfmtaL0oBrRCU,24563
+ titans_pytorch/memory_models.py,sha256=Q9SAIyAbStF5Tz0EhvRbn3yAdE3nk3xKc1ndieIe714,4671
+ titans_pytorch/neural_memory.py,sha256=1wX8dbGENHWk7sfz7IFF1G8KY4U5tsNh3cqSDxTUf2U,26150
+ titans_pytorch-0.2.11.dist-info/METADATA,sha256=CMLW5FSamLp0cPhIohOD_yXjCXoxqCPzwJrA0e83vQE,6812
+ titans_pytorch-0.2.11.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+ titans_pytorch-0.2.11.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+ titans_pytorch-0.2.11.dist-info/RECORD,,
titans_pytorch-0.2.9.dist-info/RECORD (removed)

@@ -1,9 +0,0 @@
- titans_pytorch/__init__.py,sha256=Y3m_ZlpEqYwp-Md1ARhNGJxq8bQp8ty1o039nZOOJo0,276
- titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
- titans_pytorch/mac_transformer.py,sha256=UOJAMv7nTgkefBB7M7K3U0NnFkz75tFRG5WLXRdfnLw,26039
- titans_pytorch/memory_models.py,sha256=Q9SAIyAbStF5Tz0EhvRbn3yAdE3nk3xKc1ndieIe714,4671
- titans_pytorch/neural_memory.py,sha256=YVbKl7DYKFWUgCawDTxXIEgJAcl7nq5OaZytmovIl8Q,24899
- titans_pytorch-0.2.9.dist-info/METADATA,sha256=fSFt54zXLKB5gRhLTJd9551O0pF2qcYNlR7039yJiD0,6811
- titans_pytorch-0.2.9.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
- titans_pytorch-0.2.9.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
- titans_pytorch-0.2.9.dist-info/RECORD,,