x-transformers 2.0.1__tar.gz → 2.0.2__tar.gz
- {x_transformers-2.0.1 → x_transformers-2.0.2}/PKG-INFO +2 -1
- {x_transformers-2.0.1 → x_transformers-2.0.2}/pyproject.toml +6 -2
- {x_transformers-2.0.1 → x_transformers-2.0.2}/train_parity.py +16 -8
- {x_transformers-2.0.1 → x_transformers-2.0.2}/x_transformers/x_transformers.py +8 -2
- {x_transformers-2.0.1 → x_transformers-2.0.2}/.github/FUNDING.yml +0 -0
- {x_transformers-2.0.1 → x_transformers-2.0.2}/.github/workflows/python-publish.yml +0 -0
- {x_transformers-2.0.1 → x_transformers-2.0.2}/.github/workflows/python-test.yaml +0 -0
- {x_transformers-2.0.1 → x_transformers-2.0.2}/.gitignore +0 -0
- {x_transformers-2.0.1 → x_transformers-2.0.2}/LICENSE +0 -0
- {x_transformers-2.0.1 → x_transformers-2.0.2}/README.md +0 -0
- {x_transformers-2.0.1 → x_transformers-2.0.2}/images/all-attention.png +0 -0
- {x_transformers-2.0.1 → x_transformers-2.0.2}/images/attention-on-attention.png +0 -0
- {x_transformers-2.0.1 → x_transformers-2.0.2}/images/cosine-sim-attention.png +0 -0
- {x_transformers-2.0.1 → x_transformers-2.0.2}/images/deepnorm.png +0 -0
- {x_transformers-2.0.1 → x_transformers-2.0.2}/images/dynamic-pos-bias-linear.png +0 -0
- {x_transformers-2.0.1 → x_transformers-2.0.2}/images/dynamic-pos-bias-log.png +0 -0
- {x_transformers-2.0.1 → x_transformers-2.0.2}/images/dynamic-pos-bias-sinusoidal.png +0 -0
- {x_transformers-2.0.1 → x_transformers-2.0.2}/images/dynamic-pos-bias.png +0 -0
- {x_transformers-2.0.1 → x_transformers-2.0.2}/images/enhanced-recurrence.png +0 -0
- {x_transformers-2.0.1 → x_transformers-2.0.2}/images/fcm.png +0 -0
- {x_transformers-2.0.1 → x_transformers-2.0.2}/images/ffglu.png +0 -0
- {x_transformers-2.0.1 → x_transformers-2.0.2}/images/flash-attention.png +0 -0
- {x_transformers-2.0.1 → x_transformers-2.0.2}/images/gate_values.png +0 -0
- {x_transformers-2.0.1 → x_transformers-2.0.2}/images/gating.png +0 -0
- {x_transformers-2.0.1 → x_transformers-2.0.2}/images/length-extrapolation-scale.png +0 -0
- {x_transformers-2.0.1 → x_transformers-2.0.2}/images/macaron-1.png +0 -0
- {x_transformers-2.0.1 → x_transformers-2.0.2}/images/macaron-2.png +0 -0
- {x_transformers-2.0.1 → x_transformers-2.0.2}/images/memory-transformer.png +0 -0
- {x_transformers-2.0.1 → x_transformers-2.0.2}/images/normformer.png +0 -0
- {x_transformers-2.0.1 → x_transformers-2.0.2}/images/pia.png +0 -0
- {x_transformers-2.0.1 → x_transformers-2.0.2}/images/qknorm-analysis.png +0 -0
- {x_transformers-2.0.1 → x_transformers-2.0.2}/images/resi_dual.png +0 -0
- {x_transformers-2.0.1 → x_transformers-2.0.2}/images/residual_attn.png +0 -0
- {x_transformers-2.0.1 → x_transformers-2.0.2}/images/rezero.png +0 -0
- {x_transformers-2.0.1 → x_transformers-2.0.2}/images/rotary.png +0 -0
- {x_transformers-2.0.1 → x_transformers-2.0.2}/images/sandwich-2.png +0 -0
- {x_transformers-2.0.1 → x_transformers-2.0.2}/images/sandwich.png +0 -0
- {x_transformers-2.0.1 → x_transformers-2.0.2}/images/sandwich_norm.png +0 -0
- {x_transformers-2.0.1 → x_transformers-2.0.2}/images/scalenorm.png +0 -0
- {x_transformers-2.0.1 → x_transformers-2.0.2}/images/talking-heads.png +0 -0
- {x_transformers-2.0.1 → x_transformers-2.0.2}/images/topk-attention.png +0 -0
- {x_transformers-2.0.1 → x_transformers-2.0.2}/images/xval.png +0 -0
- {x_transformers-2.0.1 → x_transformers-2.0.2}/tests/test_x_transformers.py +0 -0
- {x_transformers-2.0.1 → x_transformers-2.0.2}/train_copy.py +0 -0
- {x_transformers-2.0.1 → x_transformers-2.0.2}/train_enwik8.py +0 -0
- {x_transformers-2.0.1 → x_transformers-2.0.2}/x_transformers/__init__.py +0 -0
- {x_transformers-2.0.1 → x_transformers-2.0.2}/x_transformers/attend.py +0 -0
- {x_transformers-2.0.1 → x_transformers-2.0.2}/x_transformers/autoregressive_wrapper.py +0 -0
- {x_transformers-2.0.1 → x_transformers-2.0.2}/x_transformers/continuous.py +0 -0
- {x_transformers-2.0.1 → x_transformers-2.0.2}/x_transformers/dpo.py +0 -0
- {x_transformers-2.0.1 → x_transformers-2.0.2}/x_transformers/multi_input.py +0 -0
- {x_transformers-2.0.1 → x_transformers-2.0.2}/x_transformers/neo_mlp.py +0 -0
- {x_transformers-2.0.1 → x_transformers-2.0.2}/x_transformers/nonautoregressive_wrapper.py +0 -0
- {x_transformers-2.0.1 → x_transformers-2.0.2}/x_transformers/xl_autoregressive_wrapper.py +0 -0
- {x_transformers-2.0.1 → x_transformers-2.0.2}/x_transformers/xval.py +0 -0
{x_transformers-2.0.1 → x_transformers-2.0.2}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.0.1
+Version: 2.0.2
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
@@ -40,6 +40,7 @@ Requires-Dist: loguru
 Requires-Dist: packaging>=21.0
 Requires-Dist: torch>=2.0
 Provides-Extra: examples
+Requires-Dist: lion-pytorch; extra == 'examples'
 Requires-Dist: torchvision; extra == 'examples'
 Requires-Dist: tqdm; extra == 'examples'
 Provides-Extra: test
{x_transformers-2.0.1 → x_transformers-2.0.2}/pyproject.toml

@@ -1,6 +1,6 @@
 [project]
 name = "x-transformers"
-version = "2.0.1"
+version = "2.0.2"
 description = "X-Transformers"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -34,7 +34,11 @@ Homepage = "https://pypi.org/project/x-transformers/"
 Repository = "https://github.com/lucidrains/x-transformers"

 [project.optional-dependencies]
-examples = ["tqdm", "torchvision"]
+examples = [
+    "lion-pytorch",
+    "tqdm",
+    "torchvision"
+]

 test = [
     "pytest",
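The examples extra now declares lion-pytorch alongside tqdm and torchvision, so installing x-transformers[examples] also pulls in the optimizer used by the updated training script. A minimal, standard-library-only sketch (assuming x-transformers 2.0.2 is installed) that lists the requirements gated behind that extra:

from importlib.metadata import requires

# requires() returns the raw Requires-Dist specifiers from the installed
# package metadata, including environment markers such as extra == 'examples'
for spec in requires('x-transformers') or []:
    if "extra == 'examples'" in spec:
        print(spec)  # e.g. lion-pytorch; extra == 'examples'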
{x_transformers-2.0.1 → x_transformers-2.0.2}/train_parity.py

@@ -7,12 +7,16 @@ from x_transformers import TransformerWrapper, Decoder

 # constants

-NUM_BATCHES = 100000
 BATCH_SIZE = 256
 LEARNING_RATE = 3e-4
 EVAL_EVERY = 500
-
+
 EVAL_LENGTHS = (16, 32, 64, 128, 256, 512)
+TRAIN_MAX_LENGTH = EVAL_LENGTHS[-2]
+
+LOSS_THRES_INCREASE_LEN = 1e-3
+MEET_CRITERIA_THRES_INCREASE_LEN = 10
+
 HYBRIDIZE_WITH_RNN = True

 # rnn for fully resolving state tracking by hybridization
@@ -28,6 +32,7 @@ if HYBRIDIZE_WITH_RNN:

     decoder_kwargs = dict(
         attn_hybrid_fold_axial_dim = 4, # even if recurrence is every 4 tokens, can generalize for parity
+        attn_hybrid_learned_mix = True,
         attn_hybrid_module = GRU(dim, dim_head * heads, batch_first = True)
     )

@@ -48,7 +53,9 @@ model = TransformerWrapper(

 # optimizer

-
+from lion_pytorch.cautious_lion import Lion
+
+optimizer = Lion(model.parameters(), lr = LEARNING_RATE, cautious_factor = 0.1)

 # data generator

@@ -73,7 +80,8 @@ meet_criteria = 0
 train_seq_len = 1
 stop_length = EVAL_LENGTHS[-2]

-with tqdm.tqdm(range(NUM_BATCHES), mininterval = 10., desc = 'training') as pbar:
+with tqdm.tqdm(mininterval = 10., desc = 'training') as pbar:
+
     while train_seq_len < stop_length:
         model.train()

@@ -90,12 +98,12 @@ with tqdm.tqdm(range(NUM_BATCHES), mininterval = 10., desc = 'training') as pbar
         last_loss = loss[:, -1].mean()
         loss.mean().backward()

-        if last_loss.item() <
+        if last_loss.item() < LOSS_THRES_INCREASE_LEN:
             meet_criteria += 1
         else:
             meet_criteria = 0

-        if meet_criteria >=
+        if meet_criteria >= MEET_CRITERIA_THRES_INCREASE_LEN:
             meet_criteria = 0
             train_seq_len += 1
             print(f'criteria met, incrementing to {train_seq_len}')
@@ -103,8 +111,8 @@ with tqdm.tqdm(range(NUM_BATCHES), mininterval = 10., desc = 'training') as pbar
         print(f'({train_seq_len})| {i}: {last_loss.item()}')
         torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)

-
-
+        optimizer.step()
+        optimizer.zero_grad()

         last_step = train_seq_len == stop_length

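train_parity.py now builds its optimizer from the cautious Lion variant shipped in lion-pytorch, which is why that package joins the examples extra. A standalone sketch of the same optimizer setup and update step, with a throwaway linear module standing in for the TransformerWrapper (the model, data and loss here are placeholders, not the parity task):

import torch
from lion_pytorch.cautious_lion import Lion  # import path as used in the diff above

model = torch.nn.Linear(512, 512)  # stand-in for the TransformerWrapper

optimizer = Lion(model.parameters(), lr = 3e-4, cautious_factor = 0.1)

x = torch.randn(2, 512)
loss = model(x).sum()
loss.backward()

torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)  # same clipping as the script
optimizer.step()
optimizer.zero_grad()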
{x_transformers-2.0.1 → x_transformers-2.0.2}/x_transformers/x_transformers.py

@@ -1204,6 +1204,7 @@ class Attention(Module):
         hybrid_module: Module | None = None,
         hybrid_mask_kwarg: str | None = None,
         hybrid_fold_axial_dim: int | None = None,
+        hybrid_learned_mix = False,
         one_kv_head = False,
         kv_heads = None,
         value_dim_head = None,
@@ -1446,7 +1447,7 @@ class Attention(Module):

         if exists(hybrid_module) and exists(hybrid_fold_axial_dim):
             hybrid_module = FoldAxially(axial_dim = hybrid_fold_axial_dim, fn = hybrid_module)
-        hybrid_mix = LinearNoBias(dim, heads)
+        hybrid_mix = LinearNoBias(dim, heads) if hybrid_learned_mix else None

         hybrid_norms = ModuleList([
             MultiheadRMSNorm(dim_head, heads = heads),
@@ -1779,7 +1780,12 @@ class Attention(Module):
             out = out_norm(out)
             hybrid_out = hybrid_out_norm(hybrid_out)

-
+            if exists(self.hybrid_mix):
+                mix = self.hybrid_mix(x)
+                mix = rearrange(mix, 'b n h -> b h n 1')
+                out = out.lerp(hybrid_out, mix.sigmoid())
+            else:
+                out = 0.5 * (out + hybrid_out)

             # merge heads

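The Attention change makes the learned mix between the attention output and the hybrid module's output opt-in: with hybrid_learned_mix enabled, a LinearNoBias(dim, heads) projection of the input produces a per-head, per-token gate whose sigmoid interpolates the two streams via lerp; otherwise the two are averaged. The flag is reachable from Decoder through the usual attn_ prefix, as the train_parity.py change above shows. A minimal sketch of enabling it, with dim, heads, depth and token counts chosen only for illustration:

import torch
from torch.nn import GRU
from x_transformers import TransformerWrapper, Decoder

dim = 512
dim_head = 64
heads = 8

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 512,
    attn_layers = Decoder(
        dim = dim,
        depth = 4,
        heads = heads,
        attn_dim_head = dim_head,
        attn_hybrid_learned_mix = True,   # new in 2.0.2: per-head learned gate instead of a plain average
        attn_hybrid_fold_axial_dim = 4,
        attn_hybrid_module = GRU(dim, dim_head * heads, batch_first = True)
    )
)

seq = torch.randint(0, 256, (1, 128))
logits = model(seq)  # (1, 128, 256)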