x-transformers 1.23.5-py3-none-any.whl → 1.23.6-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
--- a/x_transformers/__init__.py
+++ b/x_transformers/__init__.py
@@ -1,13 +1,19 @@
- import torch
- from packaging import version
-
- if version.parse(torch.__version__) >= version.parse('2.0.0'):
-     from einops._torch_specific import allow_ops_in_compiled_graph
-     allow_ops_in_compiled_graph()
-
- from x_transformers.x_transformers import XTransformer, Encoder, Decoder, CrossAttender, Attention, TransformerWrapper, ViTransformerWrapper, ContinuousTransformerWrapper
+ from x_transformers.x_transformers import (
+     XTransformer,
+     Encoder,
+     Decoder,
+     CrossAttender,
+     Attention,
+     TransformerWrapper,
+     ViTransformerWrapper
+ )

  from x_transformers.autoregressive_wrapper import AutoregressiveWrapper
  from x_transformers.nonautoregressive_wrapper import NonAutoregressiveWrapper
- from x_transformers.continuous_autoregressive_wrapper import ContinuousAutoregressiveWrapper
+
+ from x_transformers.continuous_autoregressive_wrapper import (
+     ContinuousTransformerWrapper,
+     ContinuousAutoregressiveWrapper
+ )
+
  from x_transformers.xl_autoregressive_wrapper import XLAutoregressiveWrapper
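In 1.23.6, `ContinuousTransformerWrapper` is re-exported from `x_transformers.continuous_autoregressive_wrapper` instead of `x_transformers.x_transformers`, and the torch-version check that enabled einops ops inside compiled graphs is dropped from `__init__.py`. Top-level imports are unchanged; a minimal sketch of downstream code that is unaffected by the move (not taken from the diff):

from x_transformers import ContinuousTransformerWrapper, ContinuousAutoregressiveWrapper

# still resolves in 1.23.6, because __init__.py re-exports the class from its new module;
# a deep import from x_transformers.x_transformers would now raise ImportError
print(ContinuousTransformerWrapper.__module__)  # 'x_transformers.continuous_autoregressive_wrapper'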
--- a/x_transformers/continuous_autoregressive_wrapper.py
+++ b/x_transformers/continuous_autoregressive_wrapper.py
@@ -2,11 +2,147 @@ import torch
  from torch import nn
  import torch.nn.functional as F

+ from x_transformers.x_transformers import (
+     AttentionLayers,
+     ScaledSinusoidalEmbedding,
+     AbsolutePositionalEmbedding
+ )
+
+ # helper functions
+
  def exists(val):
      return val is not None

+ def default(val, d):
+     if exists(val):
+         return val
+     return d() if callable(d) else d
+
+ # main classes
+
+ class ContinuousTransformerWrapper(nn.Module):
+     def __init__(
+         self,
+         *,
+         max_seq_len,
+         attn_layers: AttentionLayers,
+         dim_in = None,
+         dim_out = None,
+         emb_dim = None,
+         max_mem_len = 0,
+         num_memory_tokens = None,
+         post_emb_norm = False,
+         emb_dropout = 0.,
+         use_abs_pos_emb = True,
+         scaled_sinu_pos_emb = False
+     ):
+         super().__init__()
+         dim = attn_layers.dim
+
+         self.max_seq_len = max_seq_len
+
+         self.max_mem_len = max_mem_len
+
+         if not (use_abs_pos_emb and not attn_layers.has_pos_emb):
+             self.pos_emb = always(0)
+         elif scaled_sinu_pos_emb:
+             self.pos_emb = ScaledSinusoidalEmbedding(dim)
+         else:
+             self.pos_emb = AbsolutePositionalEmbedding(dim, max_seq_len)
+
+         self.post_emb_norm = nn.LayerNorm(dim) if post_emb_norm else nn.Identity()
+         self.emb_dropout = nn.Dropout(emb_dropout)
+
+         # memory tokens
+
+         num_memory_tokens = default(num_memory_tokens, 0)
+         self.has_memory_tokens = num_memory_tokens > 0
+
+         if num_memory_tokens > 0:
+             self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim))
+
+         # attention layers
+
+         self.attn_layers = attn_layers
+
+         # project in and out
+
+         self.project_in = nn.Linear(dim_in, dim) if exists(dim_in) else nn.Identity()
+         self.project_out = nn.Linear(dim, dim_out) if exists(dim_out) else nn.Identity()
+
+     def forward(
+         self,
+         x,
+         return_embeddings = False,
+         return_intermediates = False,
+         return_mems = False,
+         mask = None,
+         return_attn = False,
+         mems = None,
+         pos = None,
+         prepend_embeds = None,
+         **kwargs
+     ):
+         batch = x.shape[0]
+
+         x = self.project_in(x)
+         x = x + self.pos_emb(x, pos = pos)
+
+         x = self.post_emb_norm(x)
+
+         # memory tokens
+
+         if self.has_memory_tokens:
+             m = repeat(self.memory_tokens, 'm d -> b m d', b = batch)
+             x, mem_ps = pack([m, x], 'b * d')
+
+             if exists(mask):
+                 num_mems = m.shape[-2]
+                 mask = pad_at_dim(mask, (num_mems, 0), dim = -1, value = True)
+
+         # whether to append embeds, as in PaLI, for image embeddings
+
+         if exists(prepend_embeds):
+             _, prepend_dim = prepend_embeds.shape[1:]
+             assert prepend_dim == x.shape[-1], 'prepended embeddings need to have same dimensions as model dimensions'
+
+             x = torch.cat((prepend_embeds, x), dim = -2)
+
+         x = self.emb_dropout(x)
+
+         # attention layers
+
+         x, intermediates = self.attn_layers(x, mask = mask, mems = mems, return_hiddens = True, **kwargs)
+
+         # splice out memory tokens
+
+         if self.has_memory_tokens:
+             m, x = unpack(x, mem_ps, 'b * d')
+             intermediates.memory_tokens = m
+
+         out = self.project_out(x) if not return_embeddings else x
+
+         if return_intermediates:
+             return out, intermediates
+
+         if return_mems:
+             hiddens = intermediates.hiddens
+             new_mems = list(map(lambda t: t[..., -self.max_mem_len:, :].detach(), hiddens))
+             return out, new_mems
+
+         if return_attn:
+             attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
+             return out, attn_maps
+
+         return out
+
  class ContinuousAutoregressiveWrapper(nn.Module):
-     def __init__(self, net, ignore_index = -100, pad_value = 0):
+     def __init__(
+         self,
+         net: ContinuousTransformerWrapper,
+         ignore_index = -100,
+         pad_value = 0
+     ):
          super().__init__()
          self.net = net
          self.max_seq_len = net.max_seq_len
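For orientation, a usage sketch consistent with the constructor and forward signatures added above (dimensions are illustrative, not taken from the diff):

import torch
from x_transformers import ContinuousTransformerWrapper, Encoder

model = ContinuousTransformerWrapper(
    dim_in = 32,
    dim_out = 100,
    max_seq_len = 1024,
    attn_layers = Encoder(dim = 512, depth = 6, heads = 8)
)

x = torch.randn(1, 256, 32)        # (batch, seq, dim_in) continuous features
mask = torch.ones(1, 256).bool()   # optional padding mask
out = model(x, mask = mask)        # projected back out to (1, 256, 100)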
--- a/x_transformers/x_transformers.py
+++ b/x_transformers/x_transformers.py
@@ -6,7 +6,6 @@ from torch import nn, einsum, Tensor
  import torch.nn.functional as F

  from functools import partial, wraps
- from inspect import isfunction
  from collections import namedtuple
  from dataclasses import dataclass
  from typing import List, Callable, Optional
@@ -38,7 +37,7 @@ def exists(val):
  def default(val, d):
      if exists(val):
          return val
-     return d() if isfunction(d) else d
+     return d() if callable(d) else d

  def cast_tuple(val, depth):
      return val if isinstance(val, tuple) else (val,) * depth
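The switch from `inspect.isfunction` to `callable` widens what `default` will lazily invoke: `isfunction` only recognizes plain Python functions and lambdas, while `callable` also covers builtins, `functools.partial` objects, classes, and anything defining `__call__`. A small illustration (hypothetical values, not from the package):

from inspect import isfunction
from functools import partial

factory = partial(dict, a = 1)

isfunction(factory)   # False -- the old default() would return the partial object itself
callable(factory)     # True  -- the new default() calls it and returns {'a': 1}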
@@ -1340,7 +1339,7 @@ class ViTransformerWrapper(nn.Module):
          *,
          image_size,
          patch_size,
-         attn_layers,
+         attn_layers: Encoder,
          channels = 3,
          num_classes = None,
          post_emb_norm = False,
@@ -1348,7 +1347,6 @@ class ViTransformerWrapper(nn.Module):
          emb_dropout = 0.
      ):
          super().__init__()
-         assert isinstance(attn_layers, Encoder), 'attention layers must be an Encoder'
          assert divisible_by(image_size, patch_size), 'image dimensions must be divisible by the patch size'
          dim = attn_layers.dim
          num_patches = (image_size // patch_size) ** 2
@@ -1414,7 +1412,7 @@ class TransformerWrapper(nn.Module):
          *,
          num_tokens,
          max_seq_len,
-         attn_layers,
+         attn_layers: AttentionLayers,
          emb_dim = None,
          max_mem_len = 0,
          shift_mem_down = 0,
@@ -1431,7 +1429,6 @@ class TransformerWrapper(nn.Module):
          attn_z_loss_weight = 1e-4,
      ):
          super().__init__()
-         assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder'

          dim = attn_layers.dim
          emb_dim = default(emb_dim, dim)
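In both `ViTransformerWrapper` and `TransformerWrapper`, the runtime `isinstance` asserts are replaced by parameter annotations (`attn_layers: Encoder`, `attn_layers: AttentionLayers`). Annotations document the expectation and help static checkers, but are not enforced at call time; a hypothetical caller-side guard, if strict checking is still wanted:

from x_transformers import TransformerWrapper
from x_transformers.x_transformers import AttentionLayers

def build_wrapper(attn_layers):
    # reinstates the constructor check that 1.23.6 removed
    assert isinstance(attn_layers, AttentionLayers), 'attn_layers must be an Encoder or Decoder'
    return TransformerWrapper(num_tokens = 20000, max_seq_len = 1024, attn_layers = attn_layers)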
@@ -1615,124 +1612,6 @@ class TransformerWrapper(nn.Module):

          return out

- class ContinuousTransformerWrapper(nn.Module):
-     def __init__(
-         self,
-         *,
-         max_seq_len,
-         attn_layers,
-         dim_in = None,
-         dim_out = None,
-         emb_dim = None,
-         max_mem_len = 0,
-         num_memory_tokens = None,
-         post_emb_norm = False,
-         emb_dropout = 0.,
-         use_abs_pos_emb = True,
-         scaled_sinu_pos_emb = False
-     ):
-         super().__init__()
-         assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder'
-
-         dim = attn_layers.dim
-
-         self.max_seq_len = max_seq_len
-
-         self.max_mem_len = max_mem_len
-
-         if not (use_abs_pos_emb and not attn_layers.has_pos_emb):
-             self.pos_emb = always(0)
-         elif scaled_sinu_pos_emb:
-             self.pos_emb = ScaledSinusoidalEmbedding(dim)
-         else:
-             self.pos_emb = AbsolutePositionalEmbedding(dim, max_seq_len)
-
-         self.post_emb_norm = nn.LayerNorm(dim) if post_emb_norm else nn.Identity()
-         self.emb_dropout = nn.Dropout(emb_dropout)
-
-         # memory tokens
-
-         num_memory_tokens = default(num_memory_tokens, 0)
-         self.has_memory_tokens = num_memory_tokens > 0
-
-         if num_memory_tokens > 0:
-             self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim))
-
-         # attention layers
-
-         self.attn_layers = attn_layers
-
-         # project in and out
-
-         self.project_in = nn.Linear(dim_in, dim) if exists(dim_in) else nn.Identity()
-         self.project_out = nn.Linear(dim, dim_out) if exists(dim_out) else nn.Identity()
-
-     def forward(
-         self,
-         x,
-         return_embeddings = False,
-         return_intermediates = False,
-         return_mems = False,
-         mask = None,
-         return_attn = False,
-         mems = None,
-         pos = None,
-         prepend_embeds = None,
-         **kwargs
-     ):
-         batch = x.shape[0]
-
-         x = self.project_in(x)
-         x = x + self.pos_emb(x, pos = pos)
-
-         x = self.post_emb_norm(x)
-
-         # memory tokens
-
-         if self.has_memory_tokens:
-             m = repeat(self.memory_tokens, 'm d -> b m d', b = batch)
-             x, mem_ps = pack([m, x], 'b * d')
-
-             if exists(mask):
-                 num_mems = m.shape[-2]
-                 mask = pad_at_dim(mask, (num_mems, 0), dim = -1, value = True)
-
-         # whether to append embeds, as in PaLI, for image embeddings
-
-         if exists(prepend_embeds):
-             _, prepend_dim = prepend_embeds.shape[1:]
-             assert prepend_dim == x.shape[-1], 'prepended embeddings need to have same dimensions as model dimensions'
-
-             x = torch.cat((prepend_embeds, x), dim = -2)
-
-         x = self.emb_dropout(x)
-
-         # attention layers
-
-         x, intermediates = self.attn_layers(x, mask = mask, mems = mems, return_hiddens = True, **kwargs)
-
-         # splice out memory tokens
-
-         if self.has_memory_tokens:
-             m, x = unpack(x, mem_ps, 'b * d')
-             intermediates.memory_tokens = m
-
-         out = self.project_out(x) if not return_embeddings else x
-
-         if return_intermediates:
-             return out, intermediates
-
-         if return_mems:
-             hiddens = intermediates.hiddens
-             new_mems = list(map(lambda t: t[..., -self.max_mem_len:, :].detach(), hiddens))
-             return out, new_mems
-
-         if return_attn:
-             attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
-             return out, attn_maps
-
-         return out
-
  class XTransformer(nn.Module):
      def __init__(
          self,
--- a/x_transformers-1.23.5.dist-info/METADATA
+++ b/x_transformers-1.23.6.dist-info/METADATA
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: x-transformers
- Version: 1.23.5
+ Version: 1.23.6
  Summary: X-Transformers - Pytorch
  Home-page: https://github.com/lucidrains/x-transformers
  Author: Phil Wang
--- /dev/null
+++ b/x_transformers-1.23.6.dist-info/RECORD
@@ -0,0 +1,12 @@
+ x_transformers/__init__.py,sha256=LLbGkiUKu4nR4YOISh-1gPEZYs8R3I2KMwMZHMQ2YkU,538
+ x_transformers/attend.py,sha256=hZcz_iijzbEqbXp2_BPEVL-1LoHXmYaHE6e6Oy-7hFE,11263
+ x_transformers/autoregressive_wrapper.py,sha256=f2u0usjUfAlXwgTz87O8J8XjGTbsbrx2XEP6K2beSNI,8944
+ x_transformers/continuous_autoregressive_wrapper.py,sha256=lGqE5vFaDeuLFc7b-dAQ0hx3H4dHFK_yD4-tZQZ7vqQ,5337
+ x_transformers/nonautoregressive_wrapper.py,sha256=AQLE4rA_Kh8VNoe9OzpwyeWson34sRkhks4dn4seNjI,10414
+ x_transformers/x_transformers.py,sha256=8zfU1iqrd6AwT-L23jkK7tpunxfjLi4HKgbZ6evBWKU,58416
+ x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
+ x_transformers-1.23.6.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+ x_transformers-1.23.6.dist-info/METADATA,sha256=UnmCziZTo4xYZJjAdgtRyWLytboEAR0JmyK2i0hwMvc,661
+ x_transformers-1.23.6.dist-info/WHEEL,sha256=yQN5g4mg4AybRjkgi-9yy4iQEFibGQmlz78Pik5Or-A,92
+ x_transformers-1.23.6.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+ x_transformers-1.23.6.dist-info/RECORD,,
--- a/x_transformers-1.23.5.dist-info/RECORD
+++ /dev/null
@@ -1,12 +0,0 @@
- x_transformers/__init__.py,sha256=FDb654rUx8FpXRd76B8q0diH8I7q-ZjTWEtEJ4UM21Y,701
- x_transformers/attend.py,sha256=hZcz_iijzbEqbXp2_BPEVL-1LoHXmYaHE6e6Oy-7hFE,11263
- x_transformers/autoregressive_wrapper.py,sha256=f2u0usjUfAlXwgTz87O8J8XjGTbsbrx2XEP6K2beSNI,8944
- x_transformers/continuous_autoregressive_wrapper.py,sha256=pTiDqu6JRUlnQJQp_xHATYHy0lgSd6ERLqyiFO3pC-4,1575
- x_transformers/nonautoregressive_wrapper.py,sha256=AQLE4rA_Kh8VNoe9OzpwyeWson34sRkhks4dn4seNjI,10414
- x_transformers/x_transformers.py,sha256=xnHu4himfvXeH8HVnFm3JTAGcs22D8H-kzLN9Qrm9-c,62143
- x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
- x_transformers-1.23.5.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
- x_transformers-1.23.5.dist-info/METADATA,sha256=F1WPg9cX9USpJmgvN_NzOqp0dsNn3WcchI-cDgoGKfk,661
- x_transformers-1.23.5.dist-info/WHEEL,sha256=yQN5g4mg4AybRjkgi-9yy4iQEFibGQmlz78Pik5Or-A,92
- x_transformers-1.23.5.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
- x_transformers-1.23.5.dist-info/RECORD,,
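Each RECORD line has the form path,sha256=<urlsafe base64 digest, unpadded>,<size in bytes>, which is why the content changes above surface here as new hashes and sizes for __init__.py, continuous_autoregressive_wrapper.py, and x_transformers.py. A minimal sketch of checking one entry against an unpacked wheel (file paths assumed):

import base64, hashlib

def record_digest(path):
    # RECORD stores sha256 digests urlsafe-base64 encoded with '=' padding stripped
    with open(path, 'rb') as f:
        digest = hashlib.sha256(f.read()).digest()
    return 'sha256=' + base64.urlsafe_b64encode(digest).rstrip(b'=').decode()

# e.g. record_digest('x_transformers/__init__.py') should print
# 'sha256=LLbGkiUKu4nR4YOISh-1gPEZYs8R3I2KMwMZHMQ2YkU' for the 1.23.6 wheel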