titans-pytorch 0.1.18.tar.gz → 0.1.21.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: titans-pytorch
- Version: 0.1.18
+ Version: 0.1.21
  Summary: Titans
  Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
  Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -196,3 +196,15 @@ $ python train_mac.py
  year = {2024}
  }
  ```
+
+ ```bibtex
+ @misc{wang2025testtimeregressionunifyingframework,
+ title = {Test-time regression: a unifying framework for designing sequence models with associative memory},
+ author = {Ke Alexander Wang and Jiaxin Shi and Emily B. Fox},
+ year = {2025},
+ eprint = {2501.12352},
+ archivePrefix = {arXiv},
+ primaryClass = {cs.LG},
+ url = {https://arxiv.org/abs/2501.12352},
+ }
+ ```
@@ -142,3 +142,15 @@ $ python train_mac.py
  year = {2024}
  }
  ```
+
+ ```bibtex
+ @misc{wang2025testtimeregressionunifyingframework,
+ title = {Test-time regression: a unifying framework for designing sequence models with associative memory},
+ author = {Ke Alexander Wang and Jiaxin Shi and Emily B. Fox},
+ year = {2025},
+ eprint = {2501.12352},
+ archivePrefix = {arXiv},
+ primaryClass = {cs.LG},
+ url = {https://arxiv.org/abs/2501.12352},
+ }
+ ```
@@ -1,6 +1,6 @@
  [project]
  name = "titans-pytorch"
- version = "0.1.18"
+ version = "0.1.21"
  description = "Titans"
  authors = [
  { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -13,6 +13,7 @@ def exists(v):
  @pytest.mark.parametrize('learned_mem_model_weights', (False, True))
  @pytest.mark.parametrize('attn_pool_chunks', (False, True))
  @pytest.mark.parametrize('momentum', (False, True))
+ @pytest.mark.parametrize('qk_rmsnorm', (False, True))
  @pytest.mark.parametrize('max_grad_norm', (None, 2.))
  @pytest.mark.parametrize('per_parameter_lr_modulation', (False, True))
  def test_titans(
@@ -21,6 +22,7 @@ def test_titans(
  learned_mem_model_weights,
  attn_pool_chunks,
  momentum,
+ qk_rmsnorm,
  max_grad_norm,
  per_parameter_lr_modulation
  ):
@@ -31,6 +33,7 @@ def test_titans(
  attn_pool_chunks = attn_pool_chunks,
  max_grad_norm = max_grad_norm,
  momentum = momentum,
+ qk_rmsnorm = qk_rmsnorm,
  per_parameter_lr_modulation = per_parameter_lr_modulation,
  learned_mem_model_weights = learned_mem_model_weights
  )
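
The three hunks above thread the new `qk_rmsnorm` flag through the test matrix for `NeuralMemory`. Below is a minimal usage sketch of switching the flag on; the `dim` / `chunk_size` values follow the project README example rather than anything in this diff, so treat them as assumptions:

```python
import torch
from titans_pytorch import NeuralMemory

# qk_rmsnorm is the new flag (defaults to False per the hunks further down);
# dim / chunk_size are illustrative values borrowed from the README example
mem = NeuralMemory(
    dim = 384,
    chunk_size = 64,
    qk_rmsnorm = True
)

seq = torch.randn(2, 1024, 384)
retrieved = mem(seq)

assert retrieved.shape == seq.shape
```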
@@ -107,14 +110,18 @@ def test_mac(
  logits = transformer(x)
  assert logits.shape == (1, seq_len, 256)

- def test_mac_sampling():
+ @pytest.mark.parametrize('sliding', (False, True))
+ def test_mac_sampling(sliding):
  transformer = MemoryAsContextTransformer(
  num_tokens = 256,
  dim = 256,
  depth = 2,
  segment_len = 32,
  num_persist_mem_tokens = 4,
- num_longterm_mem_tokens = 16,
+ num_longterm_mem_tokens = 0,
+ sliding_window_attn = sliding,
+ neural_memory_layers = (),
+ neural_mem_gate_attn_output = False
  )

  ids = torch.randint(0, 256, (1, 1023))
@@ -1,12 +1,14 @@
  from __future__ import annotations
  from typing import Callable
+
  from math import ceil
  from functools import partial
+ from collections import namedtuple

  import tqdm

  import torch
- from torch import nn, cat
+ from torch import nn, stack, cat
  import torch.nn.functional as F
  from torch.nn import Module, ModuleList, Linear

@@ -69,6 +71,8 @@ from titans_pytorch.titans import NeuralMemory

  LinearNoBias = partial(Linear, bias = False)

+ AttnIntermediates = namedtuple('AttnIntermediates', ('value_residual', 'cached_key_values'))
+
  # helpers

  def exists(v):
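
`AttnIntermediates` is a plain namedtuple, so it unpacks positionally like the 2-tuple it replaces. A tiny illustration (not from the package) of why the call sites further down can destructure the attention return as `x, (values, next_kv_cache) = attn(...)`:

```python
from collections import namedtuple

AttnIntermediates = namedtuple('AttnIntermediates', ('value_residual', 'cached_key_values'))

intermediates = AttnIntermediates(value_residual = 'orig_v', cached_key_values = ('k', 'v'))

# positional unpacking works exactly as it would for a plain tuple
values, kv_cache = intermediates
assert values == intermediates.value_residual
assert kv_cache == intermediates.cached_key_values
```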
@@ -80,6 +84,9 @@ def default(v, d):
  def identity(t):
  return t

+ def divisible_by(num, den):
+ return (num % den) == 0
+
  def round_up_multiple(seq, mult):
  return ceil(seq / mult) * mult

@@ -111,7 +118,7 @@ def pad_and_segment_with_inverse(seq, segment_len, fold_into_batch = True):

  def inverse(out):
  if fold_into_batch:
- out = rearrange(out, '(b w) n d -> b (w n) d', b = batch)
+ out = rearrange(out, '(b w) ... n d -> b ... (w n) d', b = batch)

  if needs_pad:
  out = out[..., :-padding, :]
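
The `...` added to the inverse rearrange lets the same `inverse_segment` closure undo the segmenting whether or not an extra heads dimension is present, which the new kv caching below relies on. A small self-contained check of that einops pattern (the shape values are arbitrary examples):

```python
import torch
from einops import rearrange

b, w, h, n, d = 2, 3, 4, 5, 8

# with a heads dimension, '...' captures it and passes it through
t = torch.randn(b * w, h, n, d)
assert rearrange(t, '(b w) ... n d -> b ... (w n) d', b = b).shape == (b, h, w * n, d)

# without one, '...' matches nothing and the old behavior is preserved
t = torch.randn(b * w, n, d)
assert rearrange(t, '(b w) ... n d -> b ... (w n) d', b = b).shape == (b, w * n, d)
```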
@@ -213,12 +220,75 @@ class SegmentedAttention(Module):
  self.segment_len = segment_len
  self.num_persist_mem_tokens = num_persist_mem_tokens

+ def forward_inference(
+ self,
+ token,
+ cache,
+ value_residual = None,
+ output_gating = None,
+ ):
+ batch = token.shape[0]
+
+ # attention
+
+ token = self.norm(token)
+
+ q, k, v = self.to_qkv(token).chunk(3, dim = -1)
+ q, k, v = map(self.split_heads, (q, k, v))
+
+ # value residual
+
+ orig_v = v
+
+ if exists(self.to_learned_v_mix):
+ mix = self.to_learned_v_mix(token)
+ v = v.lerp(value_residual, mix)
+
+ # caching
+
+ ck, cv = cache
+ k = cat((ck, k), dim = -2)
+ v = cat((cv, v), dim = -2)
+
+ next_cache = (k, v)
+
+ # relative positions
+
+ q, k = self.rotary_emb.rotate_queries_with_cached_keys(q, k)
+
+ # fold
+
+ q, k, v = tuple(rearrange(t, 'b h n d -> b h n d') for t in (q, k, v))
+
+ # take care of persistent memory key / values
+
+ pmk, pmv = repeat(self.persistent_memory, 'kv ... -> kv b ...', b = k.shape[0])
+
+ # persistent memory
+
+ k = cat((pmk, k), dim = -2)
+ v = cat((pmv, v), dim = -2)
+
+ # attention
+
+ out, _ = self.attend(q, k, v)
+
+ out = self.merge_heads(out)
+
+ out = self.to_out(out)
+
+ if exists(output_gating):
+ out = out * output_gating
+
+ return out, AttnIntermediates(orig_v, next_cache)
+
  def forward_flex(
  self,
  seq,
  value_residual = None,
  flex_attn_fn: Callable | None = None,
- output_gating = None
+ output_gating = None,
+ cache = None
  ):

  assert not (exists(value_residual) ^ exists(self.to_learned_v_mix))
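
The cache consumed and produced by `forward_inference` is just the running key / value tensors for the layer. A standalone illustration of the append step above, with made-up head and dimension sizes:

```python
import torch
from torch import cat

batch, heads, dim_head = 1, 8, 64

# keys / values cached from the previously decoded positions
ck = torch.randn(batch, heads, 10, dim_head)
cv = torch.randn(batch, heads, 10, dim_head)

# projections for the single new token are appended along the sequence dim
k_new = torch.randn(batch, heads, 1, dim_head)
v_new = torch.randn(batch, heads, 1, dim_head)

k = cat((ck, k_new), dim = -2)
v = cat((cv, v_new), dim = -2)

next_cache = (k, v)
assert k.shape[-2] == 11 and v.shape[-2] == 11
```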
@@ -240,6 +310,10 @@ class SegmentedAttention(Module):
  mix = self.to_learned_v_mix(seq)
  v = v.lerp(value_residual, mix)

+ # caching
+
+ next_cache = tuple(map(inverse_segment, (k, v)))
+
  # take care of persistent memory key / values

  pmk, pmv = repeat(self.persistent_memory, 'kv h n d -> kv b h n d', b = batch)
@@ -271,7 +345,7 @@ class SegmentedAttention(Module):
  if exists(output_gating):
  out = out * output_gating

- return out, orig_v
+ return out, AttnIntermediates(orig_v, next_cache)

  def forward(
  self,
@@ -279,10 +353,17 @@ class SegmentedAttention(Module):
  value_residual = None,
  flex_attn_fn: Callable | None = None,
  disable_flex_attn = False,
- output_gating = None
+ output_gating = None,
+ cache = None
  ):
+ is_inferencing = exists(cache)
+
+ if is_inferencing:
+ assert seq.shape[-2] == 1
+ return self.forward_inference(seq, cache, value_residual, output_gating = output_gating)
+
  if seq.is_cuda and self.use_flex_attn and not disable_flex_attn:
- return self.forward_flex(seq, value_residual, flex_attn_fn, output_gating = output_gating)
+ return self.forward_flex(seq, value_residual, flex_attn_fn, output_gating = output_gating, cache = cache)

  assert not (exists(value_residual) ^ exists(self.to_learned_v_mix))

@@ -310,6 +391,10 @@ class SegmentedAttention(Module):
  mix = self.to_learned_v_mix(seq)
  v = v.lerp(value_residual, mix)

+ # caching
+
+ next_cache = tuple(map(inverse_segment, (k, v)))
+
  # relative positions

  q, k = self.rotary_emb.rotate_queries_with_cached_keys(q, k)
@@ -369,7 +454,7 @@ class SegmentedAttention(Module):
  if exists(output_gating):
  out = out * output_gating

- return out, orig_v
+ return out, AttnIntermediates(orig_v, next_cache)

  # MAC transformer

@@ -413,6 +498,7 @@ class MemoryAsContextTransformer(Module):
  # maybe sliding window attn

  self.sliding_window_attn = sliding_window_attn
+ self.attn_window_size = segment_len + num_longterm_mem_tokens

  # hyper conection

@@ -487,7 +573,6 @@ class MemoryAsContextTransformer(Module):
  assert not (use_flex_attn and not exists(flex_attention)), 'you need to be on the latest pytorch with a cuda device available'
  self.use_flex_attn = use_flex_attn

- self.segment_len = segment_len
  self.num_persist_mem_tokens = num_persist_mem_tokens

  @torch.no_grad()
@@ -569,7 +654,7 @@ class MemoryAsContextTransformer(Module):

  # math

- batch, seq_len, neural_mem_segment_len, segment_len, num_longterm_mem_tokens = *x.shape, self.neural_memory_segment_len, self.segment_len, self.num_longterm_mem_tokens
+ batch, seq_len, neural_mem_segment_len, segment_len, num_longterm_mem_tokens, attn_window_size = *x.shape, self.neural_memory_segment_len, self.segment_len, self.num_longterm_mem_tokens, self.attn_window_size

  # token embedding

@@ -603,6 +688,12 @@ class MemoryAsContextTransformer(Module):
  block_mask = create_mac_block_mask(seq_len_with_mem, segment_len + num_longterm_mem_tokens, self.num_persist_mem_tokens, self.sliding_window_attn)
  flex_attn_fn = partial(flex_attention, block_mask = block_mask)

+ # kv caching
+
+ is_inferencing = exists(cache)
+ cache = iter(default(cache, []))
+ next_kv_caches = []
+
  # value residual

  value_residual = None
@@ -611,6 +702,11 @@ class MemoryAsContextTransformer(Module):

  kv_recon_losses = self.zero

+ # when inferencing, only do one token at a time
+
+ if is_inferencing:
+ x = x[:, -1:]
+
  # expand and reduce streams for hyper connections

  x = self.expand_streams(x)
@@ -620,6 +716,8 @@ class MemoryAsContextTransformer(Module):
  retrieved = None
  attn_out_gates = None

+ # maybe neural memory
+
  if exists(mem):
  retrieved, mem_kv_aux_loss = mem(x, return_aux_kv_loss = True)
  kv_recon_losses = kv_recon_losses + mem_kv_aux_loss
@@ -631,16 +729,19 @@ class MemoryAsContextTransformer(Module):

  # attention

- x, values = attn(
+ x, (values, next_kv_cache) = attn(
  x,
  value_residual = value_residual,
  disable_flex_attn = disable_flex_attn,
  flex_attn_fn = flex_attn_fn,
- output_gating = attn_out_gates
+ output_gating = attn_out_gates,
+ cache = next(cache, None)
  )

  value_residual = default(value_residual, values)

+ next_kv_caches.append(next_kv_cache)
+
  # feedforward

  x = ff(x)
@@ -665,7 +766,16 @@ class MemoryAsContextTransformer(Module):
  if not return_cache:
  return logits

- return logits, cache
+ next_kv_caches = stack([stack(kv_cache) for kv_cache in next_kv_caches])
+
+ # handle kv cache length depending on local attention type
+
+ next_kv_caches = next_kv_caches[..., -attn_window_size:, :]
+
+ if not self.sliding_window_attn and divisible_by(seq_len_with_mem, attn_window_size):
+ next_kv_caches = next_kv_caches[..., 0:0, :]
+
+ return logits, next_kv_caches

  ar_loss = F.cross_entropy(rearrange(logits, 'b n l -> b l n'), labels)

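
The returned cache now holds at most the last `attn_window_size` positions per layer, and is emptied entirely when windowed (non-sliding) attention lands exactly on a window boundary, since the next token then starts a fresh window. Below is a hedged sketch of a greedy decode loop built from the diffed `forward(..., cache = ..., return_cache = True)` signature and the updated test's constructor values; it is an assumption about usage, not code taken from the package:

```python
import torch
from titans_pytorch import MemoryAsContextTransformer

transformer = MemoryAsContextTransformer(
    num_tokens = 256,
    dim = 256,
    depth = 2,
    segment_len = 32,
    num_persist_mem_tokens = 4,
    num_longterm_mem_tokens = 0,
    sliding_window_attn = True,
    neural_memory_layers = (),
    neural_mem_gate_attn_output = False
)

ids = torch.randint(0, 256, (1, 16))
cache = None

for _ in range(8):
    # feed the growing ids plus the previous cache; only the last token is processed when a cache is passed
    logits, cache = transformer(ids, cache = cache, return_cache = True)
    next_id = logits[:, -1].argmax(dim = -1, keepdim = True)  # greedy next-token pick
    ids = torch.cat((ids, next_id), dim = -1)
```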
@@ -365,6 +365,7 @@ class NeuralMemory(Module):
  momentum = True,
  pre_rmsnorm = True,
  post_rmsnorm = True,
+ qk_rmsnorm = False,
  learned_mem_model_weights = True,
  max_grad_norm: float | None = None,
  use_accelerated_scan = False,
@@ -389,6 +390,9 @@ class NeuralMemory(Module):

  self.multihead_rmsnorm = MultiheadRMSNorm(dim_head, heads) if post_rmsnorm else nn.Identity()

+ self.q_norm = MultiheadRMSNorm(dim_head, heads) if qk_rmsnorm else nn.Identity()
+ self.k_norm = MultiheadRMSNorm(dim_head, heads) if qk_rmsnorm else nn.Identity()
+
  # maybe multi-headed

  dim_inner = dim_head * heads
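
For context, a per-head RMSNorm of the kind `MultiheadRMSNorm` provides typically looks like the sketch below; this is an assumption about the module (shape convention `(batch, heads, seq, dim_head)`, `nn.RMSNorm` needs PyTorch >= 2.4), not the package's exact implementation. Normalizing queries and keys before the memory write and read (the two hunks that follow) is presumably meant to stabilize training, analogous to qk rmsnorm in standard attention:

```python
import torch
from torch import nn

class PerHeadRMSNorm(nn.Module):
    def __init__(self, dim_head, heads):
        super().__init__()
        # normalize each head's features to unit rms, no learned affine inside the norm
        self.norm = nn.RMSNorm(dim_head, elementwise_affine = False)
        # one learned scale per head, initialized so the module starts as the identity
        self.gamma = nn.Parameter(torch.zeros(heads, 1, dim_head))

    def forward(self, x):
        # x: (batch, heads, seq, dim_head)
        return self.norm(x) * (self.gamma + 1.)
```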
@@ -577,6 +581,10 @@ class NeuralMemory(Module):

  batch = keys.shape[0]

+ # maybe qk rmsnorm
+
+ keys = self.k_norm(keys)
+
  # take care of chunking

  keys, values = tuple(rearrange(t, 'b (n c) d -> (b n) c d', c = chunk_size) for t in (keys, values))
@@ -683,6 +691,10 @@ class NeuralMemory(Module):

  queries = self.split_heads(queries)

+ # maybe qk rmsnorm
+
+ queries = self.q_norm(queries)
+
  # fetch values from memory model

  curr_weights = curr_weights.apply(lambda t: rearrange(t, 'b n ... -> (b n) ...'))