titans-pytorch 0.1.18__tar.gz → 0.1.20__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {titans_pytorch-0.1.18 → titans_pytorch-0.1.20}/PKG-INFO +1 -1
- {titans_pytorch-0.1.18 → titans_pytorch-0.1.20}/pyproject.toml +1 -1
- {titans_pytorch-0.1.18 → titans_pytorch-0.1.20}/tests/test_titans.py +6 -2
- {titans_pytorch-0.1.18 → titans_pytorch-0.1.20}/titans_pytorch/mac_transformer.py +122 -12
- {titans_pytorch-0.1.18 → titans_pytorch-0.1.20}/.github/workflows/python-publish.yml +0 -0
- {titans_pytorch-0.1.18 → titans_pytorch-0.1.20}/.github/workflows/test.yaml +0 -0
- {titans_pytorch-0.1.18 → titans_pytorch-0.1.20}/.gitignore +0 -0
- {titans_pytorch-0.1.18 → titans_pytorch-0.1.20}/LICENSE +0 -0
- {titans_pytorch-0.1.18 → titans_pytorch-0.1.20}/README.md +0 -0
- {titans_pytorch-0.1.18 → titans_pytorch-0.1.20}/data/README.md +0 -0
- {titans_pytorch-0.1.18 → titans_pytorch-0.1.20}/data/enwik8.gz +0 -0
- {titans_pytorch-0.1.18 → titans_pytorch-0.1.20}/fig1.png +0 -0
- {titans_pytorch-0.1.18 → titans_pytorch-0.1.20}/fig2.png +0 -0
- {titans_pytorch-0.1.18 → titans_pytorch-0.1.20}/titans_pytorch/__init__.py +0 -0
- {titans_pytorch-0.1.18 → titans_pytorch-0.1.20}/titans_pytorch/associative_scan.py +0 -0
- {titans_pytorch-0.1.18 → titans_pytorch-0.1.20}/titans_pytorch/titans.py +0 -0
- {titans_pytorch-0.1.18 → titans_pytorch-0.1.20}/train_mac.py +0 -0
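Taken together, the changes below add key/value-cache support for incremental decoding: SegmentedAttention gains a dedicated forward_inference path, every attention forward variant now returns an AttnIntermediates namedtuple carrying both the value residual and the cached keys/values, and MemoryAsContextTransformer threads per-layer caches through its forward pass, trimming them to the local attention window before returning.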
tests/test_titans.py

@@ -107,14 +107,18 @@ def test_mac(
     logits = transformer(x)
     assert logits.shape == (1, seq_len, 256)
 
-def test_mac_sampling():
+@pytest.mark.parametrize('sliding', (False, True))
+def test_mac_sampling(sliding):
     transformer = MemoryAsContextTransformer(
         num_tokens = 256,
         dim = 256,
         depth = 2,
         segment_len = 32,
         num_persist_mem_tokens = 4,
-        num_longterm_mem_tokens =
+        num_longterm_mem_tokens = 0,
+        sliding_window_attn = sliding,
+        neural_memory_layers = (),
+        neural_mem_gate_attn_output = False
     )
 
     ids = torch.randint(0, 256, (1, 1023))
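The parametrization runs the sampling test once per local-attention mode. A minimal sketch of the two cases pytest collects from the decorator (manual calls shown for illustration only):

    # equivalent to the two parametrized test ids pytest generates
    test_mac_sampling(sliding = False)   # fixed segmented attention
    test_mac_sampling(sliding = True)    # sliding window attention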
titans_pytorch/mac_transformer.py

@@ -1,12 +1,14 @@
 from __future__ import annotations
 from typing import Callable
+
 from math import ceil
 from functools import partial
+from collections import namedtuple
 
 import tqdm
 
 import torch
-from torch import nn, cat
+from torch import nn, stack, cat
 import torch.nn.functional as F
 from torch.nn import Module, ModuleList, Linear
 
@@ -69,6 +71,8 @@ from titans_pytorch.titans import NeuralMemory
 
 LinearNoBias = partial(Linear, bias = False)
 
+AttnIntermediates = namedtuple('AttnIntermediates', ('value_residual', 'cached_key_values'))
+
 # helpers
 
 def exists(v):
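AttnIntermediates replaces the bare value-residual return value later in this diff (return out, orig_v becomes return out, AttnIntermediates(orig_v, next_cache)). Since namedtuples destructure positionally, callers can unpack it inline, as the transformer block loop does further down. A minimal sketch with hypothetical tensors:

    # hypothetical orig_v, k, v - illustrating the namedtuple only
    intermediates = AttnIntermediates(value_residual = orig_v, cached_key_values = (k, v))

    values, next_kv_cache = intermediates             # positional unpacking, as in x, (values, next_kv_cache) = attn(...)
    assert intermediates.value_residual is values     # fields are also accessible by name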
@@ -80,6 +84,9 @@ def default(v, d):
 def identity(t):
     return t
 
+def divisible_by(num, den):
+    return (num % den) == 0
+
 def round_up_multiple(seq, mult):
     return ceil(seq / mult) * mult
 
@@ -111,7 +118,7 @@ def pad_and_segment_with_inverse(seq, segment_len, fold_into_batch = True):
 
     def inverse(out):
         if fold_into_batch:
-            out = rearrange(out, '(b w) n d -> b (w n) d', b = batch)
+            out = rearrange(out, '(b w) ... n d -> b ... (w n) d', b = batch)
 
         if needs_pad:
             out = out[..., :-padding, :]
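The added ellipsis generalizes inverse to tensors with extra dimensions between the folded batch and the sequence, which the new caching path relies on when per-head keys and values are passed through inverse_segment. A small sketch of the generalized pattern (shapes illustrative only):

    import torch
    from einops import rearrange

    batch, windows, n, d = 2, 4, 32, 64

    # old behavior, no extra dims: (8, 32, 64) -> (2, 128, 64)
    out = torch.randn(batch * windows, n, d)
    out = rearrange(out, '(b w) ... n d -> b ... (w n) d', b = batch)

    # new: a heads dimension rides along untouched: (8, 8, 32, 64) -> (2, 8, 128, 64)
    out = torch.randn(batch * windows, 8, n, d)
    out = rearrange(out, '(b w) ... n d -> b ... (w n) d', b = batch)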
@@ -213,12 +220,75 @@ class SegmentedAttention(Module):
         self.segment_len = segment_len
         self.num_persist_mem_tokens = num_persist_mem_tokens
 
+    def forward_inference(
+        self,
+        token,
+        cache,
+        value_residual = None,
+        output_gating = None,
+    ):
+        batch = token.shape[0]
+
+        # attention
+
+        token = self.norm(token)
+
+        q, k, v = self.to_qkv(token).chunk(3, dim = -1)
+        q, k, v = map(self.split_heads, (q, k, v))
+
+        # value residual
+
+        orig_v = v
+
+        if exists(self.to_learned_v_mix):
+            mix = self.to_learned_v_mix(token)
+            v = v.lerp(value_residual, mix)
+
+        # caching
+
+        ck, cv = cache
+        k = cat((ck, k), dim = -2)
+        v = cat((cv, v), dim = -2)
+
+        next_cache = (k, v)
+
+        # relative positions
+
+        q, k = self.rotary_emb.rotate_queries_with_cached_keys(q, k)
+
+        # fold
+
+        q, k, v = tuple(rearrange(t, 'b h n d -> b h n d') for t in (q, k, v))
+
+        # take care of persistent memory key / values
+
+        pmk, pmv = repeat(self.persistent_memory, 'kv ... -> kv b ...', b = k.shape[0])
+
+        # persistent memory
+
+        k = cat((pmk, k), dim = -2)
+        v = cat((pmv, v), dim = -2)
+
+        # attention
+
+        out, _ = self.attend(q, k, v)
+
+        out = self.merge_heads(out)
+
+        out = self.to_out(out)
+
+        if exists(output_gating):
+            out = out * output_gating
+
+        return out, AttnIntermediates(orig_v, next_cache)
+
     def forward_flex(
         self,
         seq,
         value_residual = None,
         flex_attn_fn: Callable | None = None,
-        output_gating = None
+        output_gating = None,
+        cache = None
     ):
 
         assert not (exists(value_residual) ^ exists(self.to_learned_v_mix))
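forward_inference handles exactly one new token: the fresh key/value pair is appended to the running cache before rotary positions and persistent memory are applied, and the extended (k, v) is handed back so the caller can pass it in on the next step. A sketch of one decode step, assuming attn is a SegmentedAttention with 8 heads of dimension 64 and no learned value mixing:

    token = x[:, -1:]                          # (batch, 1, dim) - a single new token
    ck = cv = torch.zeros(1, 8, 0, 64)         # an empty cache works too; cat along dim -2 is a no-op

    out, (values, (k, v)) = attn(token, cache = (ck, cv))

    # k and v are one timestep longer than ck, cv - they become the cache for the next token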
@@ -240,6 +310,10 @@ class SegmentedAttention(Module):
             mix = self.to_learned_v_mix(seq)
             v = v.lerp(value_residual, mix)
 
+        # caching
+
+        next_cache = tuple(map(inverse_segment, (k, v)))
+
         # take care of persistent memory key / values
 
         pmk, pmv = repeat(self.persistent_memory, 'kv h n d -> kv b h n d', b = batch)
@@ -271,7 +345,7 @@ class SegmentedAttention(Module):
         if exists(output_gating):
             out = out * output_gating
 
-        return out, orig_v
+        return out, AttnIntermediates(orig_v, next_cache)
 
     def forward(
         self,
@@ -279,10 +353,17 @@ class SegmentedAttention(Module):
         value_residual = None,
         flex_attn_fn: Callable | None = None,
         disable_flex_attn = False,
-        output_gating = None
+        output_gating = None,
+        cache = None
     ):
+        is_inferencing = exists(cache)
+
+        if is_inferencing:
+            assert seq.shape[-2] == 1
+            return self.forward_inference(seq, cache, value_residual, output_gating = output_gating)
+
         if seq.is_cuda and self.use_flex_attn and not disable_flex_attn:
-            return self.forward_flex(seq, value_residual, flex_attn_fn, output_gating = output_gating)
+            return self.forward_flex(seq, value_residual, flex_attn_fn, output_gating = output_gating, cache = cache)
 
         assert not (exists(value_residual) ^ exists(self.to_learned_v_mix))
 
@@ -310,6 +391,10 @@ class SegmentedAttention(Module):
             mix = self.to_learned_v_mix(seq)
             v = v.lerp(value_residual, mix)
 
+        # caching
+
+        next_cache = tuple(map(inverse_segment, (k, v)))
+
         # relative positions
 
         q, k = self.rotary_emb.rotate_queries_with_cached_keys(q, k)
@@ -369,7 +454,7 @@ class SegmentedAttention(Module):
         if exists(output_gating):
             out = out * output_gating
 
-        return out, orig_v
+        return out, AttnIntermediates(orig_v, next_cache)
 
 # MAC transformer
 
@@ -413,6 +498,7 @@ class MemoryAsContextTransformer(Module):
         # maybe sliding window attn
 
         self.sliding_window_attn = sliding_window_attn
+        self.attn_window_size = segment_len + num_longterm_mem_tokens
 
         # hyper conection
 
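For instance, with segment_len = 32 and num_longterm_mem_tokens = 0 (as in the updated test above) the attention window is 32 tokens, while 4 interspersed long-term memory tokens would widen it to 36. This window size is what bounds the per-layer kv cache kept at the end of the forward pass.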
@@ -487,7 +573,6 @@ class MemoryAsContextTransformer(Module):
         assert not (use_flex_attn and not exists(flex_attention)), 'you need to be on the latest pytorch with a cuda device available'
         self.use_flex_attn = use_flex_attn
 
-        self.segment_len = segment_len
         self.num_persist_mem_tokens = num_persist_mem_tokens
 
     @torch.no_grad()
@@ -569,7 +654,7 @@ class MemoryAsContextTransformer(Module):
 
         # math
 
-        batch, seq_len, neural_mem_segment_len, segment_len, num_longterm_mem_tokens = *x.shape, self.neural_memory_segment_len, self.segment_len, self.num_longterm_mem_tokens
+        batch, seq_len, neural_mem_segment_len, segment_len, num_longterm_mem_tokens, attn_window_size = *x.shape, self.neural_memory_segment_len, self.segment_len, self.num_longterm_mem_tokens, self.attn_window_size
 
         # token embedding
 
@@ -603,6 +688,12 @@ class MemoryAsContextTransformer(Module):
             block_mask = create_mac_block_mask(seq_len_with_mem, segment_len + num_longterm_mem_tokens, self.num_persist_mem_tokens, self.sliding_window_attn)
             flex_attn_fn = partial(flex_attention, block_mask = block_mask)
 
+        # kv caching
+
+        is_inferencing = exists(cache)
+        cache = iter(default(cache, []))
+        next_kv_caches = []
+
         # value residual
 
         value_residual = None
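Turning the incoming cache into an iterator lets each attention layer pull its own entry with next(cache, None) inside the block loop, and the empty-list default means every layer sees None on a cache-less first pass. The pattern in isolation, with hypothetical per-layer entries:

    layer_caches = iter([cache_layer_1, cache_layer_2])

    next(layer_caches, None)   # cache for layer 1
    next(layer_caches, None)   # cache for layer 2
    next(layer_caches, None)   # None - iterator exhausted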
@@ -611,6 +702,11 @@ class MemoryAsContextTransformer(Module):
 
         kv_recon_losses = self.zero
 
+        # when inferencing, only do one token at a time
+
+        if is_inferencing:
+            x = x[:, -1:]
+
         # expand and reduce streams for hyper connections
 
         x = self.expand_streams(x)
@@ -620,6 +716,8 @@ class MemoryAsContextTransformer(Module):
             retrieved = None
             attn_out_gates = None
 
+            # maybe neural memory
+
             if exists(mem):
                 retrieved, mem_kv_aux_loss = mem(x, return_aux_kv_loss = True)
                 kv_recon_losses = kv_recon_losses + mem_kv_aux_loss
|
|
|
631
729
|
|
|
632
730
|
# attention
|
|
633
731
|
|
|
634
|
-
x, values = attn(
|
|
732
|
+
x, (values, next_kv_cache) = attn(
|
|
635
733
|
x,
|
|
636
734
|
value_residual = value_residual,
|
|
637
735
|
disable_flex_attn = disable_flex_attn,
|
|
638
736
|
flex_attn_fn = flex_attn_fn,
|
|
639
|
-
output_gating = attn_out_gates
|
|
737
|
+
output_gating = attn_out_gates,
|
|
738
|
+
cache = next(cache, None)
|
|
640
739
|
)
|
|
641
740
|
|
|
642
741
|
value_residual = default(value_residual, values)
|
|
643
742
|
|
|
743
|
+
next_kv_caches.append(next_kv_cache)
|
|
744
|
+
|
|
644
745
|
# feedforward
|
|
645
746
|
|
|
646
747
|
x = ff(x)
|
|
@@ -665,7 +766,16 @@ class MemoryAsContextTransformer(Module):
         if not return_cache:
             return logits
 
-
+        next_kv_caches = stack([stack(kv_cache) for kv_cache in next_kv_caches])
+
+        # handle kv cache length depending on local attention type
+
+        next_kv_caches = next_kv_caches[..., -attn_window_size:, :]
+
+        if not self.sliding_window_attn and divisible_by(seq_len_with_mem, attn_window_size):
+            next_kv_caches = next_kv_caches[..., 0:0, :]
+
+        return logits, next_kv_caches
 
         ar_loss = F.cross_entropy(rearrange(logits, 'b n l -> b l n'), labels)
 
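The trimming above keeps at most one attention window of keys/values per layer, and for fixed (non-sliding) segmented attention the cache is emptied whenever the sequence length lands exactly on a window boundary, since the next token then opens a fresh segment with nothing local to attend to. A hedged sketch of the greedy decoding loop these pieces enable, assuming a constructed transformer and a prompt of token ids (argument names taken from this diff):

    import torch

    ids = prompt                     # (batch, prompt_len) token ids
    cache = None

    for _ in range(64):
        logits, cache = transformer(ids, cache = cache, return_cache = True, disable_flex_attn = True)
        next_token = logits[:, -1].argmax(dim = -1, keepdim = True)
        ids = torch.cat((ids, next_token), dim = -1)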