titans-pytorch 0.1.7__tar.gz → 0.1.8__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {titans_pytorch-0.1.7 → titans_pytorch-0.1.8}/PKG-INFO +1 -1
- {titans_pytorch-0.1.7 → titans_pytorch-0.1.8}/pyproject.toml +1 -1
- {titans_pytorch-0.1.7 → titans_pytorch-0.1.8}/tests/test_titans.py +13 -0
- {titans_pytorch-0.1.7 → titans_pytorch-0.1.8}/titans_pytorch/titans.py +16 -8
- {titans_pytorch-0.1.7 → titans_pytorch-0.1.8}/.github/workflows/python-publish.yml +0 -0
- {titans_pytorch-0.1.7 → titans_pytorch-0.1.8}/.github/workflows/test.yaml +0 -0
- {titans_pytorch-0.1.7 → titans_pytorch-0.1.8}/.gitignore +0 -0
- {titans_pytorch-0.1.7 → titans_pytorch-0.1.8}/LICENSE +0 -0
- {titans_pytorch-0.1.7 → titans_pytorch-0.1.8}/README.md +0 -0
- {titans_pytorch-0.1.7 → titans_pytorch-0.1.8}/data/README.md +0 -0
- {titans_pytorch-0.1.7 → titans_pytorch-0.1.8}/data/enwik8.gz +0 -0
- {titans_pytorch-0.1.7 → titans_pytorch-0.1.8}/fig1.png +0 -0
- {titans_pytorch-0.1.7 → titans_pytorch-0.1.8}/fig2.png +0 -0
- {titans_pytorch-0.1.7 → titans_pytorch-0.1.8}/titans_pytorch/__init__.py +0 -0
- {titans_pytorch-0.1.7 → titans_pytorch-0.1.8}/titans_pytorch/associative_scan.py +0 -0
- {titans_pytorch-0.1.7 → titans_pytorch-0.1.8}/titans_pytorch/mac_transformer.py +0 -0
- {titans_pytorch-0.1.7 → titans_pytorch-0.1.8}/train_mac.py +0 -0
|
@@ -53,6 +53,19 @@ def test_titans_attn_memory():
|
|
|
53
53
|
|
|
54
54
|
assert seq.shape == retrieved.shape
|
|
55
55
|
|
|
56
|
+
def test_retrieve_store_diff_seq():
|
|
57
|
+
mem = NeuralMemory(
|
|
58
|
+
dim = 384,
|
|
59
|
+
chunk_size = (64, 32),
|
|
60
|
+
)
|
|
61
|
+
|
|
62
|
+
retrieve_seq = torch.randn(2, 64 * 64, 384)
|
|
63
|
+
store_seq = torch.randn(2, 64 * 32, 384)
|
|
64
|
+
|
|
65
|
+
retrieved = mem(retrieve_seq, store_seq = store_seq)
|
|
66
|
+
|
|
67
|
+
assert retrieve_seq.shape == retrieved.shape
|
|
68
|
+
|
|
56
69
|
@pytest.mark.parametrize('seq_len', (1023, 17))
|
|
57
70
|
@pytest.mark.parametrize('num_persist_mem_tokens', (0, 16))
|
|
58
71
|
@pytest.mark.parametrize('num_longterm_mem_tokens', (0, 16))
|
|
@@ -44,6 +44,9 @@ def default(v, d):
|
|
|
44
44
|
def identity(t):
|
|
45
45
|
return t
|
|
46
46
|
|
|
47
|
+
def pair(v):
|
|
48
|
+
return (v, v) if not isinstance(v, tuple) else v
|
|
49
|
+
|
|
47
50
|
def round_down_multiple(seq, mult):
|
|
48
51
|
return seq // mult * mult
|
|
49
52
|
|
|
@@ -290,7 +293,7 @@ class NeuralMemory(Module):
|
|
|
290
293
|
def __init__(
|
|
291
294
|
self,
|
|
292
295
|
dim,
|
|
293
|
-
chunk_size = 1,
|
|
296
|
+
chunk_size: int | tuple[int, int] = 1,
|
|
294
297
|
dim_head = None,
|
|
295
298
|
heads = 1,
|
|
296
299
|
model: Module | None = None,
|
|
@@ -313,6 +316,8 @@ class NeuralMemory(Module):
|
|
|
313
316
|
super().__init__()
|
|
314
317
|
dim_head = default(dim_head, dim)
|
|
315
318
|
|
|
319
|
+
self.retrieve_chunk_size, self.store_chunk_size = pair(chunk_size)
|
|
320
|
+
|
|
316
321
|
# norms
|
|
317
322
|
|
|
318
323
|
self.retrieve_norm = nn.RMSNorm(dim) if pre_rmsnorm else nn.Identity()
|
|
@@ -380,6 +385,10 @@ class NeuralMemory(Module):
|
|
|
380
385
|
self.empty_memory_embed = nn.Parameter(torch.zeros(dim))
|
|
381
386
|
nn.init.normal_(self.empty_memory_embed, std = 0.02)
|
|
382
387
|
|
|
388
|
+
# `chunk_size` refers to chunk size used for storing to memory model weights
|
|
389
|
+
|
|
390
|
+
chunk_size = self.store_chunk_size
|
|
391
|
+
|
|
383
392
|
# whether to use averaging of chunks, or attention pooling
|
|
384
393
|
|
|
385
394
|
if not attn_pool_chunks:
|
|
@@ -451,11 +460,11 @@ class NeuralMemory(Module):
|
|
|
451
460
|
past_state: tuple[dict[str, Tensor], dict[str, Tensor]],
|
|
452
461
|
return_aux_kv_loss = False
|
|
453
462
|
):
|
|
454
|
-
seq_len = seq.shape[-2]
|
|
463
|
+
seq_len, chunk_size = seq.shape[-2], self.store_chunk_size
|
|
455
464
|
|
|
456
465
|
# handle edge case
|
|
457
466
|
|
|
458
|
-
if seq_len < self.chunk_size:
|
|
467
|
+
if seq_len < chunk_size:
|
|
459
468
|
past_weight, _ = past_state
|
|
460
469
|
return TensorDict(past_weight).clone().zero_(), self.zero
|
|
461
470
|
|
|
@@ -464,8 +473,7 @@ class NeuralMemory(Module):
|
|
|
464
473
|
# curtail sequence by multiple of the chunk size
|
|
465
474
|
# only a complete chunk of the sequence provides the memory for the next chunk
|
|
466
475
|
|
|
467
|
-
|
|
468
|
-
round_down_seq_len = round_down_multiple(seq_len, self.chunk_size)
|
|
476
|
+
round_down_seq_len = round_down_multiple(seq_len, chunk_size)
|
|
469
477
|
|
|
470
478
|
seq = seq[:, :round_down_seq_len]
|
|
471
479
|
|
|
@@ -597,12 +605,12 @@ class NeuralMemory(Module):
|
|
|
597
605
|
seq,
|
|
598
606
|
past_weights: dict[str, Tensor] | None = None,
|
|
599
607
|
):
|
|
600
|
-
chunk_size = self.chunk_size
|
|
608
|
+
chunk_size = self.retrieve_chunk_size
|
|
601
609
|
batch, seq_len = seq.shape[:2]
|
|
602
610
|
|
|
603
611
|
seq = self.retrieve_norm(seq)
|
|
604
612
|
|
|
605
|
-
if seq_len < self.chunk_size:
|
|
613
|
+
if seq_len < chunk_size:
|
|
606
614
|
return self.init_empty_memory_embed(batch, seq_len)
|
|
607
615
|
|
|
608
616
|
seq = seq[:, (chunk_size - 1):]
|
|
@@ -674,7 +682,7 @@ class NeuralMemory(Module):
|
|
|
674
682
|
):
|
|
675
683
|
batch, seq_len = seq.shape[:2]
|
|
676
684
|
|
|
677
|
-
if seq_len < self.chunk_size:
|
|
685
|
+
if seq_len < self.retrieve_chunk_size:
|
|
678
686
|
out = self.init_empty_memory_embed(batch, seq_len)
|
|
679
687
|
|
|
680
688
|
if not return_aux_kv_loss:
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|