titans-pytorch 0.1.7__tar.gz → 0.1.9__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {titans_pytorch-0.1.7 → titans_pytorch-0.1.9}/PKG-INFO +1 -1
- {titans_pytorch-0.1.7 → titans_pytorch-0.1.9}/pyproject.toml +1 -1
- {titans_pytorch-0.1.7 → titans_pytorch-0.1.9}/tests/test_titans.py +13 -0
- {titans_pytorch-0.1.7 → titans_pytorch-0.1.9}/titans_pytorch/mac_transformer.py +10 -2
- {titans_pytorch-0.1.7 → titans_pytorch-0.1.9}/titans_pytorch/titans.py +18 -8
- {titans_pytorch-0.1.7 → titans_pytorch-0.1.9}/.github/workflows/python-publish.yml +0 -0
- {titans_pytorch-0.1.7 → titans_pytorch-0.1.9}/.github/workflows/test.yaml +0 -0
- {titans_pytorch-0.1.7 → titans_pytorch-0.1.9}/.gitignore +0 -0
- {titans_pytorch-0.1.7 → titans_pytorch-0.1.9}/LICENSE +0 -0
- {titans_pytorch-0.1.7 → titans_pytorch-0.1.9}/README.md +0 -0
- {titans_pytorch-0.1.7 → titans_pytorch-0.1.9}/data/README.md +0 -0
- {titans_pytorch-0.1.7 → titans_pytorch-0.1.9}/data/enwik8.gz +0 -0
- {titans_pytorch-0.1.7 → titans_pytorch-0.1.9}/fig1.png +0 -0
- {titans_pytorch-0.1.7 → titans_pytorch-0.1.9}/fig2.png +0 -0
- {titans_pytorch-0.1.7 → titans_pytorch-0.1.9}/titans_pytorch/__init__.py +0 -0
- {titans_pytorch-0.1.7 → titans_pytorch-0.1.9}/titans_pytorch/associative_scan.py +0 -0
- {titans_pytorch-0.1.7 → titans_pytorch-0.1.9}/train_mac.py +0 -0
|
@@ -53,6 +53,19 @@ def test_titans_attn_memory():
|
|
|
53
53
|
|
|
54
54
|
assert seq.shape == retrieved.shape
|
|
55
55
|
|
|
56
|
+
def test_retrieve_store_diff_seq():
|
|
57
|
+
mem = NeuralMemory(
|
|
58
|
+
dim = 384,
|
|
59
|
+
chunk_size = (64, 32),
|
|
60
|
+
)
|
|
61
|
+
|
|
62
|
+
retrieve_seq = torch.randn(2, 64 * 64, 384)
|
|
63
|
+
store_seq = torch.randn(2, 64 * 32, 384)
|
|
64
|
+
|
|
65
|
+
retrieved = mem(retrieve_seq, store_seq = store_seq)
|
|
66
|
+
|
|
67
|
+
assert retrieve_seq.shape == retrieved.shape
|
|
68
|
+
|
|
56
69
|
@pytest.mark.parametrize('seq_len', (1023, 17))
|
|
57
70
|
@pytest.mark.parametrize('num_persist_mem_tokens', (0, 16))
|
|
58
71
|
@pytest.mark.parametrize('num_longterm_mem_tokens', (0, 16))
|
|
@@ -83,6 +83,14 @@ def identity(t):
|
|
|
83
83
|
def round_up_multiple(seq, mult):
|
|
84
84
|
return ceil(seq / mult) * mult
|
|
85
85
|
|
|
86
|
+
def pack_with_inverse(t, pattern):
|
|
87
|
+
packed, packed_shape = pack(t, pattern)
|
|
88
|
+
|
|
89
|
+
def inverse(out, inv_pattern = None):
|
|
90
|
+
return unpack(out, packed_shape, default(inv_pattern, pattern))
|
|
91
|
+
|
|
92
|
+
return packed, inverse
|
|
93
|
+
|
|
86
94
|
def pad_at_dim(t, pad, dim = -1, value = 0.):
|
|
87
95
|
dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
|
|
88
96
|
zeros = ((0, 0) * dims_from_right)
|
|
@@ -576,7 +584,7 @@ class MemoryAsContextTransformer(Module):
|
|
|
576
584
|
x, inverse_segment = pad_and_segment_with_inverse(x, segment_len)
|
|
577
585
|
|
|
578
586
|
mems = repeat(self.longterm_mems, 'n d -> b n d', b = x.shape[0])
|
|
579
|
-
x,
|
|
587
|
+
x, inverse_pack_mems = pack_with_inverse((x, mems), 'b * d')
|
|
580
588
|
|
|
581
589
|
x = inverse_segment(x)
|
|
582
590
|
|
|
@@ -634,7 +642,7 @@ class MemoryAsContextTransformer(Module):
|
|
|
634
642
|
|
|
635
643
|
x, inverse_segment = pad_and_segment_with_inverse(x, segment_len + num_longterm_mem_tokens)
|
|
636
644
|
|
|
637
|
-
x, _ =
|
|
645
|
+
x, _ = inverse_pack_mems(x)
|
|
638
646
|
|
|
639
647
|
x = inverse_segment(x)
|
|
640
648
|
|
|
@@ -44,6 +44,9 @@ def default(v, d):
|
|
|
44
44
|
def identity(t):
|
|
45
45
|
return t
|
|
46
46
|
|
|
47
|
+
def pair(v):
|
|
48
|
+
return (v, v) if not isinstance(v, tuple) else v
|
|
49
|
+
|
|
47
50
|
def round_down_multiple(seq, mult):
|
|
48
51
|
return seq // mult * mult
|
|
49
52
|
|
|
@@ -290,7 +293,7 @@ class NeuralMemory(Module):
|
|
|
290
293
|
def __init__(
|
|
291
294
|
self,
|
|
292
295
|
dim,
|
|
293
|
-
chunk_size = 1,
|
|
296
|
+
chunk_size: int | tuple[int, int] = 1,
|
|
294
297
|
dim_head = None,
|
|
295
298
|
heads = 1,
|
|
296
299
|
model: Module | None = None,
|
|
@@ -313,6 +316,8 @@ class NeuralMemory(Module):
|
|
|
313
316
|
super().__init__()
|
|
314
317
|
dim_head = default(dim_head, dim)
|
|
315
318
|
|
|
319
|
+
self.retrieve_chunk_size, self.store_chunk_size = pair(chunk_size)
|
|
320
|
+
|
|
316
321
|
# norms
|
|
317
322
|
|
|
318
323
|
self.retrieve_norm = nn.RMSNorm(dim) if pre_rmsnorm else nn.Identity()
|
|
@@ -380,8 +385,14 @@ class NeuralMemory(Module):
|
|
|
380
385
|
self.empty_memory_embed = nn.Parameter(torch.zeros(dim))
|
|
381
386
|
nn.init.normal_(self.empty_memory_embed, std = 0.02)
|
|
382
387
|
|
|
388
|
+
# `chunk_size` refers to chunk size used for storing to memory model weights
|
|
389
|
+
|
|
390
|
+
chunk_size = self.store_chunk_size
|
|
391
|
+
|
|
383
392
|
# whether to use averaging of chunks, or attention pooling
|
|
384
393
|
|
|
394
|
+
assert not (attn_pool_chunks and chunk_size == 1), '`attn_pool_chunks` cannot be set to True if `chunk_size` is set to 1'
|
|
395
|
+
|
|
385
396
|
if not attn_pool_chunks:
|
|
386
397
|
chunk_reduce_module = Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size)
|
|
387
398
|
else:
|
|
@@ -451,11 +462,11 @@ class NeuralMemory(Module):
|
|
|
451
462
|
past_state: tuple[dict[str, Tensor], dict[str, Tensor]],
|
|
452
463
|
return_aux_kv_loss = False
|
|
453
464
|
):
|
|
454
|
-
seq_len = seq.shape[-2]
|
|
465
|
+
seq_len, chunk_size = seq.shape[-2], self.store_chunk_size
|
|
455
466
|
|
|
456
467
|
# handle edge case
|
|
457
468
|
|
|
458
|
-
if seq_len < self.chunk_size:
|
|
469
|
+
if seq_len < chunk_size:
|
|
459
470
|
past_weight, _ = past_state
|
|
460
471
|
return TensorDict(past_weight).clone().zero_(), self.zero
|
|
461
472
|
|
|
@@ -464,8 +475,7 @@ class NeuralMemory(Module):
|
|
|
464
475
|
# curtail sequence by multiple of the chunk size
|
|
465
476
|
# only a complete chunk of the sequence provides the memory for the next chunk
|
|
466
477
|
|
|
467
|
-
|
|
468
|
-
round_down_seq_len = round_down_multiple(seq_len, self.chunk_size)
|
|
478
|
+
round_down_seq_len = round_down_multiple(seq_len, chunk_size)
|
|
469
479
|
|
|
470
480
|
seq = seq[:, :round_down_seq_len]
|
|
471
481
|
|
|
@@ -597,12 +607,12 @@ class NeuralMemory(Module):
|
|
|
597
607
|
seq,
|
|
598
608
|
past_weights: dict[str, Tensor] | None = None,
|
|
599
609
|
):
|
|
600
|
-
chunk_size = self.chunk_size
|
|
610
|
+
chunk_size = self.retrieve_chunk_size
|
|
601
611
|
batch, seq_len = seq.shape[:2]
|
|
602
612
|
|
|
603
613
|
seq = self.retrieve_norm(seq)
|
|
604
614
|
|
|
605
|
-
if seq_len < self.chunk_size:
|
|
615
|
+
if seq_len < chunk_size:
|
|
606
616
|
return self.init_empty_memory_embed(batch, seq_len)
|
|
607
617
|
|
|
608
618
|
seq = seq[:, (chunk_size - 1):]
|
|
@@ -674,7 +684,7 @@ class NeuralMemory(Module):
|
|
|
674
684
|
):
|
|
675
685
|
batch, seq_len = seq.shape[:2]
|
|
676
686
|
|
|
677
|
-
if seq_len < self.chunk_size:
|
|
687
|
+
if seq_len < self.retrieve_chunk_size:
|
|
678
688
|
out = self.init_empty_memory_embed(batch, seq_len)
|
|
679
689
|
|
|
680
690
|
if not return_aux_kv_loss:
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|