titans-pytorch 0.1.18__tar.gz → 0.1.20__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.1.18
+Version: 0.1.20
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -1,6 +1,6 @@
 [project]
 name = "titans-pytorch"
-version = "0.1.18"
+version = "0.1.20"
 description = "Titans"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -107,14 +107,18 @@ def test_mac(
     logits = transformer(x)
     assert logits.shape == (1, seq_len, 256)
 
-def test_mac_sampling():
+@pytest.mark.parametrize('sliding', (False, True))
+def test_mac_sampling(sliding):
     transformer = MemoryAsContextTransformer(
         num_tokens = 256,
         dim = 256,
         depth = 2,
         segment_len = 32,
         num_persist_mem_tokens = 4,
-        num_longterm_mem_tokens = 16,
+        num_longterm_mem_tokens = 0,
+        sliding_window_attn = sliding,
+        neural_memory_layers = (),
+        neural_mem_gate_attn_output = False
     )
 
     ids = torch.randint(0, 256, (1, 1023))
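For orientation, here is a minimal sketch of exercising the same configuration outside pytest. The constructor arguments are copied from the parametrized test above; the top-level import path is an assumption (the tests import the class from the package), and the forward call mirrors the earlier `test_mac` assertion. Running `pytest -k test_mac_sampling` exercises both the fixed-window and sliding-window parametrizations.

```python
import torch
from titans_pytorch import MemoryAsContextTransformer  # assumed import path, mirroring the tests

# same configuration as the parametrized test above, with sliding window attention enabled
transformer = MemoryAsContextTransformer(
    num_tokens = 256,
    dim = 256,
    depth = 2,
    segment_len = 32,
    num_persist_mem_tokens = 4,
    num_longterm_mem_tokens = 0,
    sliding_window_attn = True,
    neural_memory_layers = (),
    neural_mem_gate_attn_output = False
)

ids = torch.randint(0, 256, (1, 1023))

logits = transformer(ids)   # (1, 1023, 256), as asserted in test_mac
```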
@@ -1,12 +1,14 @@
 from __future__ import annotations
 from typing import Callable
+
 from math import ceil
 from functools import partial
+from collections import namedtuple
 
 import tqdm
 
 import torch
-from torch import nn, cat
+from torch import nn, stack, cat
 import torch.nn.functional as F
 from torch.nn import Module, ModuleList, Linear
 
@@ -69,6 +71,8 @@ from titans_pytorch.titans import NeuralMemory
 
 LinearNoBias = partial(Linear, bias = False)
 
+AttnIntermediates = namedtuple('AttnIntermediates', ('value_residual', 'cached_key_values'))
+
 # helpers
 
 def exists(v):
@@ -80,6 +84,9 @@ def default(v, d):
 def identity(t):
     return t
 
+def divisible_by(num, den):
+    return (num % den) == 0
+
 def round_up_multiple(seq, mult):
     return ceil(seq / mult) * mult
 
@@ -111,7 +118,7 @@ def pad_and_segment_with_inverse(seq, segment_len, fold_into_batch = True):
 
     def inverse(out):
         if fold_into_batch:
-            out = rearrange(out, '(b w) n d -> b (w n) d', b = batch)
+            out = rearrange(out, '(b w) ... n d -> b ... (w n) d', b = batch)
 
         if needs_pad:
             out = out[..., :-padding, :]
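The added `...` lets the inverse rearrange pass tensors with extra leading dimensions (for example a heads dimension, as needed when the same inverse is applied to per-head keys/values for the new cache), not just `b n d`. A small einops sketch of the new pattern, with shapes chosen purely for illustration:

```python
import torch
from einops import rearrange

batch, windows, heads, seq, dim = 2, 3, 4, 8, 16

# a per-window tensor with an extra heads dimension, as produced for keys / values
t = torch.randn(batch * windows, heads, seq, dim)

# '(b w) ... n d -> b ... (w n) d' folds the window axis back into the sequence,
# while the ellipsis carries the heads dimension through untouched
out = rearrange(t, '(b w) ... n d -> b ... (w n) d', b = batch)

assert out.shape == (batch, heads, windows * seq, dim)
```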
@@ -213,12 +220,75 @@ class SegmentedAttention(Module):
         self.segment_len = segment_len
         self.num_persist_mem_tokens = num_persist_mem_tokens
 
+    def forward_inference(
+        self,
+        token,
+        cache,
+        value_residual = None,
+        output_gating = None,
+    ):
+        batch = token.shape[0]
+
+        # attention
+
+        token = self.norm(token)
+
+        q, k, v = self.to_qkv(token).chunk(3, dim = -1)
+        q, k, v = map(self.split_heads, (q, k, v))
+
+        # value residual
+
+        orig_v = v
+
+        if exists(self.to_learned_v_mix):
+            mix = self.to_learned_v_mix(token)
+            v = v.lerp(value_residual, mix)
+
+        # caching
+
+        ck, cv = cache
+        k = cat((ck, k), dim = -2)
+        v = cat((cv, v), dim = -2)
+
+        next_cache = (k, v)
+
+        # relative positions
+
+        q, k = self.rotary_emb.rotate_queries_with_cached_keys(q, k)
+
+        # fold
+
+        q, k, v = tuple(rearrange(t, 'b h n d -> b h n d') for t in (q, k, v))
+
+        # take care of persistent memory key / values
+
+        pmk, pmv = repeat(self.persistent_memory, 'kv ... -> kv b ...', b = k.shape[0])
+
+        # persistent memory
+
+        k = cat((pmk, k), dim = -2)
+        v = cat((pmv, v), dim = -2)
+
+        # attention
+
+        out, _ = self.attend(q, k, v)
+
+        out = self.merge_heads(out)
+
+        out = self.to_out(out)
+
+        if exists(output_gating):
+            out = out * output_gating
+
+        return out, AttnIntermediates(orig_v, next_cache)
+
     def forward_flex(
         self,
         seq,
         value_residual = None,
         flex_attn_fn: Callable | None = None,
-        output_gating = None
+        output_gating = None,
+        cache = None
     ):
 
         assert not (exists(value_residual) ^ exists(self.to_learned_v_mix))
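The decoding path above grows each layer's cache by concatenating the key/value of the single incoming token onto the cached ones along the sequence axis, then prepends persistent memory before attending, and hands the grown `(k, v)` pair back through `AttnIntermediates`. A standalone sketch of just the cache-append step, with illustrative tensor shapes:

```python
import torch
from torch import cat

heads, dim_head = 8, 64

# cached keys / values from previously decoded tokens: (batch, heads, n_cached, dim_head)
ck = torch.randn(1, heads, 10, dim_head)
cv = torch.randn(1, heads, 10, dim_head)

# key / value for the single new token: (batch, heads, 1, dim_head)
k_new = torch.randn(1, heads, 1, dim_head)
v_new = torch.randn(1, heads, 1, dim_head)

# append along the sequence dimension, as in forward_inference above
k = cat((ck, k_new), dim = -2)
v = cat((cv, v_new), dim = -2)

next_cache = (k, v)   # carried forward to the next decoding step
assert k.shape[-2] == 11
```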
@@ -240,6 +310,10 @@ class SegmentedAttention(Module):
             mix = self.to_learned_v_mix(seq)
             v = v.lerp(value_residual, mix)
 
+        # caching
+
+        next_cache = tuple(map(inverse_segment, (k, v)))
+
         # take care of persistent memory key / values
 
         pmk, pmv = repeat(self.persistent_memory, 'kv h n d -> kv b h n d', b = batch)
@@ -271,7 +345,7 @@ class SegmentedAttention(Module):
         if exists(output_gating):
             out = out * output_gating
 
-        return out, orig_v
+        return out, AttnIntermediates(orig_v, next_cache)
 
     def forward(
         self,
@@ -279,10 +353,17 @@ class SegmentedAttention(Module):
         value_residual = None,
         flex_attn_fn: Callable | None = None,
         disable_flex_attn = False,
-        output_gating = None
+        output_gating = None,
+        cache = None
     ):
+        is_inferencing = exists(cache)
+
+        if is_inferencing:
+            assert seq.shape[-2] == 1
+            return self.forward_inference(seq, cache, value_residual, output_gating = output_gating)
+
         if seq.is_cuda and self.use_flex_attn and not disable_flex_attn:
-            return self.forward_flex(seq, value_residual, flex_attn_fn, output_gating = output_gating)
+            return self.forward_flex(seq, value_residual, flex_attn_fn, output_gating = output_gating, cache = cache)
 
         assert not (exists(value_residual) ^ exists(self.to_learned_v_mix))
 
@@ -310,6 +391,10 @@ class SegmentedAttention(Module):
             mix = self.to_learned_v_mix(seq)
             v = v.lerp(value_residual, mix)
 
+        # caching
+
+        next_cache = tuple(map(inverse_segment, (k, v)))
+
         # relative positions
 
         q, k = self.rotary_emb.rotate_queries_with_cached_keys(q, k)
@@ -369,7 +454,7 @@ class SegmentedAttention(Module):
         if exists(output_gating):
             out = out * output_gating
 
-        return out, orig_v
+        return out, AttnIntermediates(orig_v, next_cache)
 
 # MAC transformer
 
@@ -413,6 +498,7 @@ class MemoryAsContextTransformer(Module):
         # maybe sliding window attn
 
         self.sliding_window_attn = sliding_window_attn
+        self.attn_window_size = segment_len + num_longterm_mem_tokens
 
         # hyper conection
@@ -487,7 +573,6 @@ class MemoryAsContextTransformer(Module):
         assert not (use_flex_attn and not exists(flex_attention)), 'you need to be on the latest pytorch with a cuda device available'
         self.use_flex_attn = use_flex_attn
 
-        self.segment_len = segment_len
         self.num_persist_mem_tokens = num_persist_mem_tokens
 
     @torch.no_grad()
@@ -569,7 +654,7 @@ class MemoryAsContextTransformer(Module):
 
         # math
 
-        batch, seq_len, neural_mem_segment_len, segment_len, num_longterm_mem_tokens = *x.shape, self.neural_memory_segment_len, self.segment_len, self.num_longterm_mem_tokens
+        batch, seq_len, neural_mem_segment_len, segment_len, num_longterm_mem_tokens, attn_window_size = *x.shape, self.neural_memory_segment_len, self.segment_len, self.num_longterm_mem_tokens, self.attn_window_size
 
         # token embedding
 
@@ -603,6 +688,12 @@ class MemoryAsContextTransformer(Module):
             block_mask = create_mac_block_mask(seq_len_with_mem, segment_len + num_longterm_mem_tokens, self.num_persist_mem_tokens, self.sliding_window_attn)
             flex_attn_fn = partial(flex_attention, block_mask = block_mask)
 
+        # kv caching
+
+        is_inferencing = exists(cache)
+        cache = iter(default(cache, []))
+        next_kv_caches = []
+
         # value residual
 
         value_residual = None
@@ -611,6 +702,11 @@ class MemoryAsContextTransformer(Module):
 
         kv_recon_losses = self.zero
 
+        # when inferencing, only do one token at a time
+
+        if is_inferencing:
+            x = x[:, -1:]
+
         # expand and reduce streams for hyper connections
 
         x = self.expand_streams(x)
@@ -620,6 +716,8 @@ class MemoryAsContextTransformer(Module):
             retrieved = None
             attn_out_gates = None
 
+            # maybe neural memory
+
             if exists(mem):
                 retrieved, mem_kv_aux_loss = mem(x, return_aux_kv_loss = True)
                 kv_recon_losses = kv_recon_losses + mem_kv_aux_loss
@@ -631,16 +729,19 @@ class MemoryAsContextTransformer(Module):
 
             # attention
 
-            x, values = attn(
+            x, (values, next_kv_cache) = attn(
                 x,
                 value_residual = value_residual,
                 disable_flex_attn = disable_flex_attn,
                 flex_attn_fn = flex_attn_fn,
-                output_gating = attn_out_gates
+                output_gating = attn_out_gates,
+                cache = next(cache, None)
             )
 
             value_residual = default(value_residual, values)
 
+            next_kv_caches.append(next_kv_cache)
+
             # feedforward
 
             x = ff(x)
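The per-layer cache plumbing above rests on a small Python idiom: `cache = iter(default(cache, []))` turns the incoming cache into an iterator (empty on the prefill pass), each attention layer pulls its own entry with `next(cache, None)`, and the fresh caches are collected in `next_kv_caches`. A self-contained sketch of that idiom; the layer loop and names are illustrative, not part of the diff:

```python
def default(v, d):
    return v if v is not None else d

def run_layers(num_layers, cache = None):
    cache = iter(default(cache, []))     # empty iterator on the first (prefill) pass
    next_caches = []

    for layer_index in range(num_layers):
        layer_cache = next(cache, None)  # None when no cache entry exists for this layer yet
        # ... attention for this layer would consume `layer_cache` here ...
        next_caches.append((layer_index, layer_cache))

    return next_caches

print(run_layers(3))                                   # every layer sees None
print(run_layers(3, cache = ['kv0', 'kv1', 'kv2']))    # each layer pulls its own entry in order
```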
@@ -665,7 +766,16 @@ class MemoryAsContextTransformer(Module):
         if not return_cache:
             return logits
 
-        return logits, cache
+        next_kv_caches = stack([stack(kv_cache) for kv_cache in next_kv_caches])
+
+        # handle kv cache length depending on local attention type
+
+        next_kv_caches = next_kv_caches[..., -attn_window_size:, :]
+
+        if not self.sliding_window_attn and divisible_by(seq_len_with_mem, attn_window_size):
+            next_kv_caches = next_kv_caches[..., 0:0, :]
+
+        return logits, next_kv_caches
 
         ar_loss = F.cross_entropy(rearrange(logits, 'b n l -> b l n'), labels)
 
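Taken together, these changes let the MAC transformer decode incrementally: a call with `return_cache` returns logits plus stacked per-layer key/value caches, trimmed to the local attention window and emptied when a fixed (non-sliding) window has just completed, and subsequent calls feed the previous cache back in while only the newest token is processed. A hedged sketch of a greedy decoding loop under that reading; the `cache` and `return_cache` keyword names are taken from the hunks above, the loop itself is illustrative and independent of whatever sampling helper the package ships:

```python
import torch

@torch.no_grad()
def greedy_decode(model, prompt_ids, num_new_tokens = 16):
    # prompt_ids: (batch, prompt_len) token ids
    out = prompt_ids
    cache = None

    for _ in range(num_new_tokens):
        # assumed keyword names, taken from the hunks above
        logits, cache = model(out, cache = cache, return_cache = True)

        # greedy pick of the next token from the last position's logits
        next_token = logits[:, -1].argmax(dim = -1, keepdim = True)
        out = torch.cat((out, next_token), dim = -1)

    return out[:, prompt_ids.shape[-1]:]
```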