x-transformers 1.32.5__tar.gz → 1.32.7__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (21)
  1. {x_transformers-1.32.5/x_transformers.egg-info → x_transformers-1.32.7}/PKG-INFO +1 -1
  2. {x_transformers-1.32.5 → x_transformers-1.32.7}/setup.py +1 -1
  3. {x_transformers-1.32.5 → x_transformers-1.32.7}/x_transformers/multi_input.py +5 -20
  4. {x_transformers-1.32.5 → x_transformers-1.32.7/x_transformers.egg-info}/PKG-INFO +1 -1
  5. {x_transformers-1.32.5 → x_transformers-1.32.7}/LICENSE +0 -0
  6. {x_transformers-1.32.5 → x_transformers-1.32.7}/README.md +0 -0
  7. {x_transformers-1.32.5 → x_transformers-1.32.7}/setup.cfg +0 -0
  8. {x_transformers-1.32.5 → x_transformers-1.32.7}/tests/test_x_transformers.py +0 -0
  9. {x_transformers-1.32.5 → x_transformers-1.32.7}/x_transformers/__init__.py +0 -0
  10. {x_transformers-1.32.5 → x_transformers-1.32.7}/x_transformers/attend.py +0 -0
  11. {x_transformers-1.32.5 → x_transformers-1.32.7}/x_transformers/autoregressive_wrapper.py +0 -0
  12. {x_transformers-1.32.5 → x_transformers-1.32.7}/x_transformers/continuous.py +0 -0
  13. {x_transformers-1.32.5 → x_transformers-1.32.7}/x_transformers/dpo.py +0 -0
  14. {x_transformers-1.32.5 → x_transformers-1.32.7}/x_transformers/nonautoregressive_wrapper.py +0 -0
  15. {x_transformers-1.32.5 → x_transformers-1.32.7}/x_transformers/x_transformers.py +0 -0
  16. {x_transformers-1.32.5 → x_transformers-1.32.7}/x_transformers/xl_autoregressive_wrapper.py +0 -0
  17. {x_transformers-1.32.5 → x_transformers-1.32.7}/x_transformers/xval.py +0 -0
  18. {x_transformers-1.32.5 → x_transformers-1.32.7}/x_transformers.egg-info/SOURCES.txt +0 -0
  19. {x_transformers-1.32.5 → x_transformers-1.32.7}/x_transformers.egg-info/dependency_links.txt +0 -0
  20. {x_transformers-1.32.5 → x_transformers-1.32.7}/x_transformers.egg-info/requires.txt +0 -0
  21. {x_transformers-1.32.5 → x_transformers-1.32.7}/x_transformers.egg-info/top_level.txt +0 -0
{x_transformers-1.32.5/x_transformers.egg-info → x_transformers-1.32.7}/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: x-transformers
- Version: 1.32.5
+ Version: 1.32.7
  Summary: X-Transformers - Pytorch
  Home-page: https://github.com/lucidrains/x-transformers
  Author: Phil Wang
{x_transformers-1.32.5 → x_transformers-1.32.7}/setup.py
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
  setup(
    name = 'x-transformers',
    packages = find_packages(exclude=['examples']),
-   version = '1.32.5',
+   version = '1.32.7',
    license='MIT',
    description = 'X-Transformers - Pytorch',
    author = 'Phil Wang',
{x_transformers-1.32.5 → x_transformers-1.32.7}/x_transformers/multi_input.py
@@ -1,3 +1,5 @@
+ from __future__ import annotations
+
  import torch
  from torch import nn, Tensor
  from torch.nn import Module, ModuleDict
@@ -15,7 +17,7 @@ from x_transformers.x_transformers import (
      LayerNorm,
      always,
      pad_at_dim,
-     is_empty
+     is_empty,
  )

  # helper functions
@@ -43,10 +45,7 @@ class MultiInputTransformerWrapper(Module):
          post_emb_norm = False,
          num_memory_tokens = None,
          memory_tokens_interspersed_every = None,
-         tie_embedding = False,
-         logits_dim = None,
          return_only_embed = False,
-         num_output_heads = 1,
          use_abs_pos_emb = True,
          scaled_sinu_pos_emb = False,
          emb_frac_gradient = 1., # GLM-130B and Cogview successfully used this, set at 0.1
@@ -85,23 +84,12 @@ class MultiInputTransformerWrapper(Module):
          self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
          self.attn_layers = attn_layers

-         assert num_output_heads > 0
-
          # output head, usually to logits of num_tokens

-         logits_dim = default(logits_dim, num_tokens)
-
-         self.has_multiple_heads = False
-
          if return_only_embed:
              self.to_logits = None
-         elif tie_embedding:
-             self.to_logits = lambda t: t @ self.token_emb.emb.weight.t()
-         elif num_output_heads > 1:
-             self.has_multiple_heads = True
-             self.to_logits = ModuleList([nn.Linear(dim, logits_dim, bias = False) for _ in range(num_output_heads)])
          else:
-             self.to_logits = nn.Linear(dim, logits_dim, bias = False)
+             self.to_logits = ModuleDict({name: nn.Linear(dim, logits_dim, bias = False) for name, logits_dim in num_tokens.items()})

          # memory tokens (like [cls]) from Memory Transformers paper
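Note: this hunk collapses the output-head options inherited from TransformerWrapper (tie_embedding, logits_dim, num_output_heads) into a single rule: one nn.Linear projection per named input, stored in a ModuleDict keyed like num_tokens. Below is a minimal standalone sketch of that pattern; dim and the entries of num_tokens are made-up values for illustration, not library defaults.

# Sketch of the per-input output heads introduced in 1.32.7.
# `dim` and the vocab sizes in `num_tokens` are illustrative only.
import torch
from torch import nn
from torch.nn import ModuleDict

dim = 512
num_tokens = {'pitch': 128, 'velocity': 32}   # hypothetical per-input vocab sizes

# one projection head per named input, mirroring the new `self.to_logits`
to_logits = ModuleDict({
    name: nn.Linear(dim, logits_dim, bias = False)
    for name, logits_dim in num_tokens.items()
})

x = torch.randn(2, 10, dim)                   # (batch, seq, dim) transformer output
logits = {name: fn(x) for name, fn in to_logits.items()}

print({name: t.shape for name, t in logits.items()})
# {'pitch': torch.Size([2, 10, 128]), 'velocity': torch.Size([2, 10, 32])}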
@@ -251,10 +239,7 @@ class MultiInputTransformerWrapper(Module):
          # projecting to logits

          if not return_embeddings:
-             if self.has_multiple_heads:
-                 logits = tuple(fn(x) for fn in self.to_logits)
-             else:
-                 logits = self.to_logits(x)
+             logits = {name: fn(x) for name, fn in self.to_logits.items()}

          # different returns
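Correspondingly, when return_embeddings is not requested, the forward pass now returns a dict of logit tensors keyed by input name rather than a single tensor or a tuple of head outputs. A hedged usage sketch, assuming the constructor mirrors TransformerWrapper (num_tokens given as a dict, plus max_seq_len and attn_layers) and that forward consumes a dict of token-id tensors with the same keys; the names and sizes here are illustrative:

import torch
from x_transformers import Decoder
from x_transformers.multi_input import MultiInputTransformerWrapper

# hypothetical input names and vocab sizes
model = MultiInputTransformerWrapper(
    num_tokens = dict(pitch = 128, velocity = 32),
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 6, heads = 8)
)

x = dict(
    pitch = torch.randint(0, 128, (2, 1024)),
    velocity = torch.randint(0, 32, (2, 1024))
)

logits = model(x)
# as of 1.32.7: a dict keyed by input name,
# e.g. logits['pitch'].shape == (2, 1024, 128)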
{x_transformers-1.32.5 → x_transformers-1.32.7/x_transformers.egg-info}/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: x-transformers
- Version: 1.32.5
+ Version: 1.32.7
  Summary: X-Transformers - Pytorch
  Home-page: https://github.com/lucidrains/x-transformers
  Author: Phil Wang