x-transformers 2.0.1__tar.gz → 2.0.2__tar.gz

Files changed (55)
  1. {x_transformers-2.0.1 → x_transformers-2.0.2}/PKG-INFO +2 -1
  2. {x_transformers-2.0.1 → x_transformers-2.0.2}/pyproject.toml +6 -2
  3. {x_transformers-2.0.1 → x_transformers-2.0.2}/train_parity.py +16 -8
  4. {x_transformers-2.0.1 → x_transformers-2.0.2}/x_transformers/x_transformers.py +8 -2
  5. {x_transformers-2.0.1 → x_transformers-2.0.2}/.github/FUNDING.yml +0 -0
  6. {x_transformers-2.0.1 → x_transformers-2.0.2}/.github/workflows/python-publish.yml +0 -0
  7. {x_transformers-2.0.1 → x_transformers-2.0.2}/.github/workflows/python-test.yaml +0 -0
  8. {x_transformers-2.0.1 → x_transformers-2.0.2}/.gitignore +0 -0
  9. {x_transformers-2.0.1 → x_transformers-2.0.2}/LICENSE +0 -0
  10. {x_transformers-2.0.1 → x_transformers-2.0.2}/README.md +0 -0
  11. {x_transformers-2.0.1 → x_transformers-2.0.2}/images/all-attention.png +0 -0
  12. {x_transformers-2.0.1 → x_transformers-2.0.2}/images/attention-on-attention.png +0 -0
  13. {x_transformers-2.0.1 → x_transformers-2.0.2}/images/cosine-sim-attention.png +0 -0
  14. {x_transformers-2.0.1 → x_transformers-2.0.2}/images/deepnorm.png +0 -0
  15. {x_transformers-2.0.1 → x_transformers-2.0.2}/images/dynamic-pos-bias-linear.png +0 -0
  16. {x_transformers-2.0.1 → x_transformers-2.0.2}/images/dynamic-pos-bias-log.png +0 -0
  17. {x_transformers-2.0.1 → x_transformers-2.0.2}/images/dynamic-pos-bias-sinusoidal.png +0 -0
  18. {x_transformers-2.0.1 → x_transformers-2.0.2}/images/dynamic-pos-bias.png +0 -0
  19. {x_transformers-2.0.1 → x_transformers-2.0.2}/images/enhanced-recurrence.png +0 -0
  20. {x_transformers-2.0.1 → x_transformers-2.0.2}/images/fcm.png +0 -0
  21. {x_transformers-2.0.1 → x_transformers-2.0.2}/images/ffglu.png +0 -0
  22. {x_transformers-2.0.1 → x_transformers-2.0.2}/images/flash-attention.png +0 -0
  23. {x_transformers-2.0.1 → x_transformers-2.0.2}/images/gate_values.png +0 -0
  24. {x_transformers-2.0.1 → x_transformers-2.0.2}/images/gating.png +0 -0
  25. {x_transformers-2.0.1 → x_transformers-2.0.2}/images/length-extrapolation-scale.png +0 -0
  26. {x_transformers-2.0.1 → x_transformers-2.0.2}/images/macaron-1.png +0 -0
  27. {x_transformers-2.0.1 → x_transformers-2.0.2}/images/macaron-2.png +0 -0
  28. {x_transformers-2.0.1 → x_transformers-2.0.2}/images/memory-transformer.png +0 -0
  29. {x_transformers-2.0.1 → x_transformers-2.0.2}/images/normformer.png +0 -0
  30. {x_transformers-2.0.1 → x_transformers-2.0.2}/images/pia.png +0 -0
  31. {x_transformers-2.0.1 → x_transformers-2.0.2}/images/qknorm-analysis.png +0 -0
  32. {x_transformers-2.0.1 → x_transformers-2.0.2}/images/resi_dual.png +0 -0
  33. {x_transformers-2.0.1 → x_transformers-2.0.2}/images/residual_attn.png +0 -0
  34. {x_transformers-2.0.1 → x_transformers-2.0.2}/images/rezero.png +0 -0
  35. {x_transformers-2.0.1 → x_transformers-2.0.2}/images/rotary.png +0 -0
  36. {x_transformers-2.0.1 → x_transformers-2.0.2}/images/sandwich-2.png +0 -0
  37. {x_transformers-2.0.1 → x_transformers-2.0.2}/images/sandwich.png +0 -0
  38. {x_transformers-2.0.1 → x_transformers-2.0.2}/images/sandwich_norm.png +0 -0
  39. {x_transformers-2.0.1 → x_transformers-2.0.2}/images/scalenorm.png +0 -0
  40. {x_transformers-2.0.1 → x_transformers-2.0.2}/images/talking-heads.png +0 -0
  41. {x_transformers-2.0.1 → x_transformers-2.0.2}/images/topk-attention.png +0 -0
  42. {x_transformers-2.0.1 → x_transformers-2.0.2}/images/xval.png +0 -0
  43. {x_transformers-2.0.1 → x_transformers-2.0.2}/tests/test_x_transformers.py +0 -0
  44. {x_transformers-2.0.1 → x_transformers-2.0.2}/train_copy.py +0 -0
  45. {x_transformers-2.0.1 → x_transformers-2.0.2}/train_enwik8.py +0 -0
  46. {x_transformers-2.0.1 → x_transformers-2.0.2}/x_transformers/__init__.py +0 -0
  47. {x_transformers-2.0.1 → x_transformers-2.0.2}/x_transformers/attend.py +0 -0
  48. {x_transformers-2.0.1 → x_transformers-2.0.2}/x_transformers/autoregressive_wrapper.py +0 -0
  49. {x_transformers-2.0.1 → x_transformers-2.0.2}/x_transformers/continuous.py +0 -0
  50. {x_transformers-2.0.1 → x_transformers-2.0.2}/x_transformers/dpo.py +0 -0
  51. {x_transformers-2.0.1 → x_transformers-2.0.2}/x_transformers/multi_input.py +0 -0
  52. {x_transformers-2.0.1 → x_transformers-2.0.2}/x_transformers/neo_mlp.py +0 -0
  53. {x_transformers-2.0.1 → x_transformers-2.0.2}/x_transformers/nonautoregressive_wrapper.py +0 -0
  54. {x_transformers-2.0.1 → x_transformers-2.0.2}/x_transformers/xl_autoregressive_wrapper.py +0 -0
  55. {x_transformers-2.0.1 → x_transformers-2.0.2}/x_transformers/xval.py +0 -0
{x_transformers-2.0.1 → x_transformers-2.0.2}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.0.1
+Version: 2.0.2
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
@@ -40,6 +40,7 @@ Requires-Dist: loguru
 Requires-Dist: packaging>=21.0
 Requires-Dist: torch>=2.0
 Provides-Extra: examples
+Requires-Dist: lion-pytorch; extra == 'examples'
 Requires-Dist: torchvision; extra == 'examples'
 Requires-Dist: tqdm; extra == 'examples'
 Provides-Extra: test
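
The examples extra now declares lion-pytorch alongside torchvision and tqdm. A quick, illustrative way to confirm the declaration from an installed distribution (a sketch that assumes x-transformers 2.0.2 is installed, e.g. via pip install "x-transformers[examples]"; the filter string is just a convenience):

from importlib.metadata import requires

# requirements gated behind the 'examples' extra, as declared in PKG-INFO
example_reqs = [r for r in (requires('x-transformers') or []) if 'examples' in r]
print(example_reqs)  # expected to include lion-pytorch, torchvision and tqdm entries
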
{x_transformers-2.0.1 → x_transformers-2.0.2}/pyproject.toml

@@ -1,6 +1,6 @@
 [project]
 name = "x-transformers"
-version = "2.0.1"
+version = "2.0.2"
 description = "X-Transformers"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -34,7 +34,11 @@ Homepage = "https://pypi.org/project/x-transformers/"
 Repository = "https://github.com/lucidrains/x-transformers"
 
 [project.optional-dependencies]
-examples = ["tqdm", "torchvision"]
+examples = [
+    "lion-pytorch",
+    "tqdm",
+    "torchvision"
+]
 
 test = [
     "pytest",

{x_transformers-2.0.1 → x_transformers-2.0.2}/train_parity.py

@@ -7,12 +7,16 @@ from x_transformers import TransformerWrapper, Decoder
 
 # constants
 
-NUM_BATCHES = 100000
 BATCH_SIZE = 256
 LEARNING_RATE = 3e-4
 EVAL_EVERY = 500
-TRAIN_MAX_LENGTH = 64
+
 EVAL_LENGTHS = (16, 32, 64, 128, 256, 512)
+TRAIN_MAX_LENGTH = EVAL_LENGTHS[-2]
+
+LOSS_THRES_INCREASE_LEN = 1e-3
+MEET_CRITERIA_THRES_INCREASE_LEN = 10
+
 HYBRIDIZE_WITH_RNN = True
 
 # rnn for fully resolving state tracking by hybridization
@@ -28,6 +32,7 @@ if HYBRIDIZE_WITH_RNN:
 
     decoder_kwargs = dict(
         attn_hybrid_fold_axial_dim = 4, # even if recurrence is every 4 tokens, can generalize for parity
+        attn_hybrid_learned_mix = True,
         attn_hybrid_module = GRU(dim, dim_head * heads, batch_first = True)
     )
 
@@ -48,7 +53,9 @@ model = TransformerWrapper(
 
 # optimizer
 
-adam = optim.Adam(model.parameters(), lr = LEARNING_RATE)
+from lion_pytorch.cautious_lion import Lion
+
+optimizer = Lion(model.parameters(), lr = LEARNING_RATE, cautious_factor = 0.1)
 
 # data generator
 
@@ -73,7 +80,8 @@ meet_criteria = 0
 train_seq_len = 1
 stop_length = EVAL_LENGTHS[-2]
 
-with tqdm.tqdm(range(NUM_BATCHES), mininterval = 10., desc = 'training') as pbar:
+with tqdm.tqdm(mininterval = 10., desc = 'training') as pbar:
+
     while train_seq_len < stop_length:
         model.train()
 
@@ -90,12 +98,12 @@ with tqdm.tqdm(range(NUM_BATCHES), mininterval = 10., desc = 'training') as pbar
         last_loss = loss[:, -1].mean()
         loss.mean().backward()
 
-        if last_loss.item() < 0.001:
+        if last_loss.item() < LOSS_THRES_INCREASE_LEN:
             meet_criteria += 1
         else:
             meet_criteria = 0
 
-        if meet_criteria >= 10:
+        if meet_criteria >= MEET_CRITERIA_THRES_INCREASE_LEN:
             meet_criteria = 0
             train_seq_len += 1
             print(f'criteria met, incrementing to {train_seq_len}')
@@ -103,8 +111,8 @@ with tqdm.tqdm(range(NUM_BATCHES), mininterval = 10., desc = 'training') as pbar
         print(f'({train_seq_len})| {i}: {last_loss.item()}')
         torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
 
-        adam.step()
-        adam.zero_grad()
+        optimizer.step()
+        optimizer.zero_grad()
 
         last_step = train_seq_len == stop_length
 
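train_parity.py now trains with the cautious variant of Lion instead of Adam (hence the new lion-pytorch entry in the examples extra), and the fixed NUM_BATCHES budget gives way to a length curriculum driven by LOSS_THRES_INCREASE_LEN and MEET_CRITERIA_THRES_INCREASE_LEN. Below is a minimal standalone sketch of the optimizer swap only; the nn.Linear stand-in model, tensor sizes, and the single step are illustrative assumptions, while the import path and the cautious_factor keyword are taken from the diff above:

import torch
from torch import nn
from lion_pytorch.cautious_lion import Lion  # provided by the new lion-pytorch dependency

model = nn.Linear(16, 2)  # stand-in for the TransformerWrapper used in train_parity.py

optimizer = Lion(model.parameters(), lr = 3e-4, cautious_factor = 0.1)

loss = model(torch.randn(4, 16)).sum()  # dummy loss for one illustrative step
loss.backward()

torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
optimizer.step()
optimizer.zero_grad()
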
{x_transformers-2.0.1 → x_transformers-2.0.2}/x_transformers/x_transformers.py

@@ -1204,6 +1204,7 @@ class Attention(Module):
         hybrid_module: Module | None = None,
         hybrid_mask_kwarg: str | None = None,
         hybrid_fold_axial_dim: int | None = None,
+        hybrid_learned_mix = False,
         one_kv_head = False,
         kv_heads = None,
         value_dim_head = None,
@@ -1446,7 +1447,7 @@ class Attention(Module):
 
         if exists(hybrid_module) and exists(hybrid_fold_axial_dim):
             hybrid_module = FoldAxially(axial_dim = hybrid_fold_axial_dim, fn = hybrid_module)
-        hybrid_mix = LinearNoBias(dim, heads)
+        hybrid_mix = LinearNoBias(dim, heads) if hybrid_learned_mix else None
 
         hybrid_norms = ModuleList([
             MultiheadRMSNorm(dim_head, heads = heads),
@@ -1779,7 +1780,12 @@ class Attention(Module):
             out = out_norm(out)
             hybrid_out = hybrid_out_norm(hybrid_out)
 
-            out = 0.5 * (out + hybrid_out)
+            if exists(self.hybrid_mix):
+                mix = self.hybrid_mix(x)
+                mix = rearrange(mix, 'b n h -> b h n 1')
+                out = out.lerp(hybrid_out, mix.sigmoid())
+            else:
+                out = 0.5 * (out + hybrid_out)
 
         # merge heads
 
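The library-side change behind the new attn_hybrid_learned_mix flag: when enabled, a bias-free linear projection of the layer input produces one gate per head and position, and the attention and hybrid (e.g. GRU) branches are blended with a sigmoid-weighted lerp instead of the fixed 0.5 average. The sketch below reproduces just that blending step with standalone tensors; the toy shapes and the plain nn.Linear standing in for LinearNoBias are illustrative assumptions, not the library's internals:

import torch
from einops import rearrange

b, h, n, d = 2, 8, 16, 64             # batch, heads, sequence length, head dim (assumed toy sizes)
dim = 512                             # model dimension fed to the gate projection (assumed)

x = torch.randn(b, n, dim)            # layer input
out = torch.randn(b, h, n, d)         # normalized attention branch output
hybrid_out = torch.randn(b, h, n, d)  # normalized hybrid (e.g. GRU) branch output

hybrid_mix = torch.nn.Linear(dim, h, bias = False)  # plays the role of LinearNoBias(dim, heads)

mix = hybrid_mix(x)                        # (b, n, h): one gate logit per position and head
mix = rearrange(mix, 'b n h -> b h n 1')   # broadcast over the head feature dim
out = out.lerp(hybrid_out, mix.sigmoid())  # gate 0 -> pure attention, 1 -> pure hybrid branch

print(out.shape)  # torch.Size([2, 8, 16, 64])
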