titans-pytorch 0.1.0__py3-none-any.whl → 0.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- titans_pytorch/mac_transformer.py +58 -6
- titans_pytorch/titans.py +6 -1
- {titans_pytorch-0.1.0.dist-info → titans_pytorch-0.1.2.dist-info}/METADATA +8 -3
- titans_pytorch-0.1.2.dist-info/RECORD +8 -0
- titans_pytorch-0.1.0.dist-info/RECORD +0 -8
- {titans_pytorch-0.1.0.dist-info → titans_pytorch-0.1.2.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.1.0.dist-info → titans_pytorch-0.1.2.dist-info}/licenses/LICENSE +0 -0
titans_pytorch/mac_transformer.py
CHANGED
@@ -3,6 +3,8 @@ from typing import Callable
 from math import ceil
 from functools import partial
 
+import tqdm
+
 import torch
 from torch import nn, cat
 import torch.nn.functional as F
@@ -88,12 +90,6 @@ def pad_at_dim(t, pad, dim = -1, value = 0.):
 
 def pad_and_segment_with_inverse(seq, segment_len, fold_into_batch = True):
     batch, seq_len = seq.shape[:2]
-
-    need_segment = seq_len >= segment_len
-
-    if not need_segment:
-        return seq, identity
-
     next_seq_len_mult = round_up_multiple(seq_len, segment_len)
 
     padding = next_seq_len_mult - seq_len
@@ -116,6 +112,29 @@ def pad_and_segment_with_inverse(seq, segment_len, fold_into_batch = True):
 
     return seq, inverse
 
+# sampling related
+
+def log(t, eps = 1e-20):
+    return torch.log(t.clamp(min = eps))
+
+def gumbel_noise(t):
+    noise = torch.rand_like(t)
+    return -log(-log(noise))
+
+def gumbel_sample(t, temperature = 1.):
+    if temperature > 0.:
+        t = t / temperature + gumbel_noise(t)
+    return t.argmax(dim = -1, keepdim = True)
+
+# min_p
+# https://arxiv.org/abs/2407.01082
+
+def min_p_filter(logits, min_p = 0.1):
+    probs = logits.softmax(dim = -1)
+    max_probs = probs.amax(dim = -1, keepdim = True)
+    limit = min_p * max_probs
+    return torch.where(probs < limit, float('-inf'), logits)
+
 # feedforward and attention
 
 class GEGLU(Module):
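The new module-level helpers implement min-p filtering (per the linked paper) followed by Gumbel-max sampling: logits whose probability falls below `min_p` times the top token's probability are masked to `-inf`, and a token is then drawn by adding Gumbel noise and taking the argmax. A minimal standalone sketch, assuming the helpers remain importable from `titans_pytorch.mac_transformer` (they are defined at module level in the hunk above):

```python
import torch

# assumed import path: the helpers are plain module-level functions in mac_transformer.py
from titans_pytorch.mac_transformer import min_p_filter, gumbel_sample

logits = torch.randn(2, 256)                             # dummy (batch, vocab) logits
filtered = min_p_filter(logits, min_p = 0.1)             # mask tokens below 10% of the max token probability
token_ids = gumbel_sample(filtered, temperature = 1.5)   # (2, 1) sampled token ids
```

With `temperature = 0.` the Gumbel noise is skipped and `gumbel_sample` reduces to greedy argmax decoding.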
@@ -500,6 +519,39 @@ class MemoryAsContextTransformer(Module):
         self.segment_len = segment_len
         self.num_persist_mem_tokens = num_persist_mem_tokens
 
+    @torch.no_grad()
+    def sample(
+        self,
+        prompt: Tensor,
+        seq_len: int,
+        temperature = 1.5,
+        filter_fn: Callable = min_p_filter,
+        filter_kwargs: dict = dict(
+            min_p = 0.1,
+        ),
+        show_progress = True
+    ):
+        was_training = self.training
+        self.eval()
+
+        prompt_seq_len, out = prompt.shape[-1], prompt.clone()
+        sample_num_times = max(0, seq_len - prompt_seq_len)
+
+        iter_wrap = tqdm.tqdm if show_progress else identity
+
+        for _ in iter_wrap(range(sample_num_times)):
+            logits = self.forward(out, disable_flex_attn = True)
+            logits = logits[:, -1]
+
+            logits = filter_fn(logits, **filter_kwargs)
+            sample = gumbel_sample(logits, temperature = temperature)
+
+            out = torch.cat((out, sample), dim = -1)
+
+        self.train(was_training)
+
+        return out[..., prompt_seq_len:]
+
     def forward(
         self,
         x,
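The new `sample` method switches to eval mode, autoregressively appends one min-p/Gumbel-sampled token per step (with flex attention disabled), restores the previous training mode, and returns only the tokens generated after the prompt. A hedged usage sketch; the import and constructor arguments below mirror the project README and are assumptions rather than part of this diff:

```python
import torch
from titans_pytorch import MemoryAsContextTransformer

# illustrative configuration (assumed, following the project README)
transformer = MemoryAsContextTransformer(
    num_tokens = 256,
    dim = 256,
    depth = 2,
    segment_len = 128,
    num_persist_mem_tokens = 4,
    num_longterm_mem_tokens = 16,
)

prompt = torch.randint(0, 256, (1, 4))

generated = transformer.sample(
    prompt,
    64,                                   # total sequence length to reach
    temperature = 1.5,
    filter_kwargs = dict(min_p = 0.1),
    show_progress = False,
)

full_sequence = torch.cat((prompt, generated), dim = -1)  # (1, 64); `generated` alone is (1, 60)
```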
titans_pytorch/titans.py
CHANGED
@@ -592,7 +592,12 @@ class NeuralMemory(Module):
         batch, seq_len = seq.shape[:2]
 
         if seq_len < self.chunk_size:
-            return self.init_empty_memory_embed(batch, seq_len)
+            out = self.init_empty_memory_embed(batch, seq_len)
+
+            if not return_aux_kv_loss:
+                return out
+
+            return out, self.zero
 
         if exists(past_state):
             past_state = tuple(TensorDict(d) for d in past_state)
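With this change, the short-sequence early exit also honors `return_aux_kv_loss`, returning `self.zero` as the auxiliary key/value loss alongside the empty memory embedding, which matters once decoding can start from prompts shorter than the memory chunk size. A minimal sketch, assuming the `NeuralMemory(dim = 384, chunk_size = 64)` configuration shown in the project README:

```python
import torch
from titans_pytorch import NeuralMemory

# dim / chunk_size are illustrative, following the project README
mem = NeuralMemory(dim = 384, chunk_size = 64)

short_seq = torch.randn(2, 16, 384)   # 16 < chunk_size, so the early-exit branch runs

retrieved = mem(short_seq)                                             # empty memory embedding only
retrieved, aux_kv_loss = mem(short_seq, return_aux_kv_loss = True)     # now also returns a zero aux loss
```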
{titans_pytorch-0.1.0.dist-info → titans_pytorch-0.1.2.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.1.0
+Version: 0.1.2
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -43,9 +43,9 @@ Requires-Dist: ninja
 Requires-Dist: rotary-embedding-torch
 Requires-Dist: tensordict
 Requires-Dist: torch>=2.2
+Requires-Dist: tqdm
 Requires-Dist: x-transformers
 Provides-Extra: examples
-Requires-Dist: tqdm; extra == 'examples'
 Requires-Dist: wandb; extra == 'examples'
 Provides-Extra: test
 Requires-Dist: pytest; extra == 'test'
@@ -104,7 +104,12 @@ transformer = MemoryAsContextTransformer(
 
 token_ids = torch.randint(0, 256, (1, 1023))
 
-
+loss = transformer(token_ids, return_loss = True) # (1, 1023, 256)
+loss.backward()
+
+# after much training
+
+sampled = transformer.sample(token_ids[:, :4], 512)
 ```
 
 ## Experiments
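The updated README excerpt trains with `return_loss = True` and then samples from a short prefix. A hedged completion of that fragment into a runnable sketch; the constructor arguments and optimizer settings are assumptions, not part of the diff:

```python
import torch
from torch.optim import Adam
from titans_pytorch import MemoryAsContextTransformer

# illustrative configuration (assumed, following the project README)
transformer = MemoryAsContextTransformer(
    num_tokens = 256,
    dim = 256,
    depth = 2,
    segment_len = 128,
    num_persist_mem_tokens = 4,
    num_longterm_mem_tokens = 16,
)

optim = Adam(transformer.parameters(), lr = 3e-4)

for _ in range(100):                                    # stand-in training loop
    token_ids = torch.randint(0, 256, (1, 1023))        # replace with real data
    loss = transformer(token_ids, return_loss = True)
    loss.backward()
    optim.step()
    optim.zero_grad()

# after much training

sampled = transformer.sample(token_ids[:, :4], 512)     # (1, 508): only the newly generated tokens
```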
titans_pytorch-0.1.2.dist-info/RECORD
ADDED
@@ -0,0 +1,8 @@
+titans_pytorch/__init__.py,sha256=I-4oF1JPEmUvhLszEocM5cDgHYNFcNp0Q9nbDYSPFqU,195
+titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+titans_pytorch/mac_transformer.py,sha256=YowJzQC2p3fWgzFCe9jXrw1b3wWPKN-dcLs5sX-3r8Y,19123
+titans_pytorch/titans.py,sha256=cGWJHkOYmIeE6X383mZvyjusECBwbplVvK0cfgfhBxg,18634
+titans_pytorch-0.1.2.dist-info/METADATA,sha256=FWq5JIp1WY9dYpzatfGzfkcAGQFk-mEPwxF0wCrbM5w,4684
+titans_pytorch-0.1.2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.1.2.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.1.2.dist-info/RECORD,,
titans_pytorch-0.1.0.dist-info/RECORD
DELETED
@@ -1,8 +0,0 @@
-titans_pytorch/__init__.py,sha256=I-4oF1JPEmUvhLszEocM5cDgHYNFcNp0Q9nbDYSPFqU,195
-titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/mac_transformer.py,sha256=7PHBCbeB1LhHY5s3zAyYF0L3Mm7CNy4TOBbcpLX6bNE,17686
-titans_pytorch/titans.py,sha256=L3Mu6pDnimD4MNn_832trFEJgXOPjxSdTrB9jiSUSTk,18533
-titans_pytorch-0.1.0.dist-info/METADATA,sha256=LuWDzv-NbGxYKOMThN_WKQWDueyIsOAMSwwiE_BDraI,4595
-titans_pytorch-0.1.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.1.0.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.1.0.dist-info/RECORD,,
{titans_pytorch-0.1.0.dist-info → titans_pytorch-0.1.2.dist-info}/WHEEL
File without changes

{titans_pytorch-0.1.0.dist-info → titans_pytorch-0.1.2.dist-info}/licenses/LICENSE
File without changes