x-transformers 1.29.1__py3-none-any.whl → 1.30.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to their public registry. It is provided for informational purposes only.
x_transformers/attend.py CHANGED
@@ -1,5 +1,7 @@
+ from __future__ import annotations
+
  from functools import partial
- from typing import Optional, Tuple
+ from typing import Tuple

  import torch
  from torch import nn, einsum, Tensor
@@ -16,10 +18,10 @@ from einops import rearrange, repeat

  @dataclass
  class Intermediates:
-     qk_similarities: Optional[Tensor] = None
-     pre_softmax_attn: Optional[Tensor] = None
-     post_softmax_attn: Optional[Tensor] = None
-     cached_kv: Optional[Tuple[Tensor, Tensor]] = None
+     qk_similarities: Tensor | None = None
+     pre_softmax_attn: Tensor | None = None
+     post_softmax_attn: Tensor | None = None
+     cached_kv: Tuple[Tensor, Tensor] | None = None

      def to_tuple(self):
          return (self.qk_similarities, self.pre_softmax_attn, self.post_softmax_attn)
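
The `Optional[...]` annotations are replaced with PEP 604 `X | None` unions, and the new `from __future__ import annotations` keeps them importable on Pythons older than 3.10 by postponing annotation evaluation. A minimal standalone illustration of that interaction (not library code):

    from __future__ import annotations

    from dataclasses import dataclass
    from torch import Tensor

    @dataclass
    class Example:
        value: Tensor | None = None   # kept as a string, never evaluated, so this also works on Python 3.8/3.9

    print(Example())                  # Example(value=None)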
x_transformers/autoregressive_wrapper.py CHANGED
@@ -1,5 +1,7 @@
+ from __future__ import annotations
+
  from math import ceil, log
- from typing import Optional, Union, Tuple, Callable
+ from typing import Tuple, Callable

  import torch
  from torch import nn, Tensor
@@ -133,12 +135,12 @@ class AutoregressiveWrapper(Module):
          seq_len,
          eos_token = None,
          temperature = 1.,
-         prompt_lens: Optional[Tensor] = None,
+         prompt_lens: Tensor | None = None,
          filter_logits_fn: Callable = top_k,
          restrict_to_max_seq_len = True,
-         amateur_model: Optional[Union[Module, Tuple[Module]]] = None,
+         amateur_model: Module | Tuple[Module] | None = None,
          filter_kwargs: dict = dict(),
-         contrastive_decode_kwargs: Union[dict, Tuple[dict]] = dict(
+         contrastive_decode_kwargs: dict | Tuple[dict] = dict(
              beta = 0.5,
              alpha = 0.1
          ),
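
For orientation, a hedged sketch of how the retyped `generate` keywords are used for contrastive decoding; the expert/amateur pair below is a hypothetical setup, not something taken from this diff:

    import torch
    from x_transformers import TransformerWrapper, Decoder, AutoregressiveWrapper

    # a large "expert" and a small "amateur" sharing the same vocabulary
    expert  = TransformerWrapper(num_tokens = 256, max_seq_len = 1024, attn_layers = Decoder(dim = 512, depth = 8, heads = 8))
    amateur = TransformerWrapper(num_tokens = 256, max_seq_len = 1024, attn_layers = Decoder(dim = 256, depth = 2, heads = 4))

    wrapper = AutoregressiveWrapper(expert)
    prompt = torch.randint(0, 256, (1, 4))

    generated = wrapper.generate(
        prompt, 64,
        amateur_model = amateur,
        contrastive_decode_kwargs = dict(beta = 0.5, alpha = 0.1)
    )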
x_transformers/continuous.py CHANGED
@@ -143,11 +143,11 @@ class ContinuousTransformerWrapper(nn.Module):

          if return_mems:
              hiddens = intermediates.hiddens
-             new_mems = list(map(lambda t: t[..., -self.max_mem_len:, :].detach(), hiddens))
+             new_mems = tuple(t[..., -self.max_mem_len:, :].detach() for t in hiddens)
              return out, new_mems

          if return_attn:
-             attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
+             attn_maps = tuple(t.post_softmax_attn for t in intermediates.attn_intermediates)
              return out, attn_maps

          return out
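
A hedged sketch of the `return_mems` path touched above, which now returns a tuple of detached, `max_mem_len`-truncated memories, one per layer; the constructor values below are illustrative assumptions:

    import torch
    from x_transformers import ContinuousTransformerWrapper, Decoder

    model = ContinuousTransformerWrapper(
        dim_in = 32,
        dim_out = 32,
        max_seq_len = 1024,
        max_mem_len = 256,
        attn_layers = Decoder(dim = 512, depth = 4, heads = 8)
    )

    x = torch.randn(1, 128, 32)
    out, mems = model(x, return_mems = True)

    assert isinstance(mems, tuple)
    assert all(m.shape[-2] <= 256 for m in mems)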
x_transformers/x_transformers.py CHANGED
@@ -1,3 +1,5 @@
+ from __future__ import annotations
+
  import math
  from random import random
  from packaging import version
@@ -11,7 +13,7 @@ from torch.cuda.amp import autocast
  from functools import partial, wraps
  from collections import namedtuple
  from dataclasses import dataclass
- from typing import List, Dict, Tuple, Callable, Optional, Union
+ from typing import List, Dict, Tuple, Callable

  from einops import rearrange, repeat, reduce, pack, unpack
  from einops.layers.torch import Rearrange
@@ -25,13 +27,13 @@ DEFAULT_DIM_HEAD = 64

  @dataclass
  class LayerIntermediates:
-     hiddens: Optional[List[Tensor]] = None            # all hiddens, before the final norm (in pre-norm architecture)
-     last_hidden: Optional[Tensor] = None              # very last hidden after all attention layers, after the final norm
-     attn_intermediates: Optional[List[Intermediates]] = None
-     layer_hiddens: Optional[List[Tensor]] = None
-     attn_z_loss: Optional[Tensor] = None
-     mems: Optional[Tensor] = None
-     memory_tokens: Optional[Tensor] = None
+     hiddens: List[Tensor] | None = None               # all hiddens, before the final norm (in pre-norm architecture)
+     last_hidden: Tensor | None = None                 # very last hidden after all attention layers, after the final norm
+     attn_intermediates: List[Intermediates] | None = None
+     layer_hiddens: List[Tensor] | None = None
+     attn_z_loss: Tensor | None = None
+     mems: Tensor | None = None
+     memory_tokens: Tensor | None = None

  # helpers

@@ -140,7 +142,7 @@ def init_zero_(layer):
  # keyword argument helpers

  def pick_and_pop(keys, d):
-     values = list(map(lambda key: d.pop(key), keys))
+     values = tuple(d.pop(key) for key in keys)
      return dict(zip(keys, values))

  def group_dict_by_key(cond, d):
@@ -149,7 +151,7 @@ def group_dict_by_key(cond, d):
          match = bool(cond(key))
          ind = int(not match)
          return_val[ind][key] = d[key]
-     return (*return_val,)
+     return tuple(return_val)

  def string_begins_with(prefix, str):
      return str.startswith(prefix)
@@ -159,7 +161,8 @@ def group_by_key_prefix(prefix, d):

  def groupby_prefix_and_trim(prefix, d):
      kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
-     kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
+     prefix_len = len(prefix)
+     kwargs_without_prefix = {key[prefix_len:]: value for key, value in kwargs_with_prefix.items()}
      return kwargs_without_prefix, kwargs

  # structured dropout, more effective than traditional attention dropouts
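
These helpers route keyword arguments by prefix (for example `attn_` or `ff_`) to the right sub-module. A self-contained sketch of the same idea, using a hypothetical `split_by_prefix` rather than the library's private helpers:

    def split_by_prefix(prefix, kwargs):
        # separate kwargs that start with `prefix` (stripping it) from the rest
        prefix_len = len(prefix)
        with_prefix = {k[prefix_len:]: v for k, v in kwargs.items() if k.startswith(prefix)}
        rest = {k: v for k, v in kwargs.items() if not k.startswith(prefix)}
        return with_prefix, rest

    attn_kwargs, rest = split_by_prefix('attn_', dict(attn_dim_head = 64, ff_mult = 4))
    assert attn_kwargs == dict(dim_head = 64)
    assert rest == dict(ff_mult = 4)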
@@ -455,7 +458,6 @@ class RotaryEmbedding(Module):

          return freqs, scale

-
  def rotate_half(x):
      x = rearrange(x, '... (j d) -> ... j d', j = 2)
      x1, x2 = x.unbind(dim = -2)
@@ -572,8 +574,8 @@ class GRUGating(Module):
  def shift(t, amount, mask = None):
      if amount == 0:
          return t
-     else:
-         amount = min(amount, t.shape[1])
+
+     amount = min(amount, t.shape[1])

      if exists(mask):
          t = t.masked_fill(~mask[..., None], 0.)
@@ -597,6 +599,23 @@ class ShiftTokens(Module):
          x = torch.cat((*segments_to_shift, *rest), dim = -1)
          return self.fn(x, **kwargs)

+ # post branch operator
+
+ class LayerScale(Module):
+     def __init__(self, fn: Module, dim, init_value = 0.):
+         super().__init__()
+         self.fn = fn
+         self.gamma = nn.Parameter(torch.ones(dim) * init_value)
+
+     def forward(self, x, **kwargs):
+         out = self.fn(x, **kwargs)
+
+         if isinstance(out, Tensor):
+             return out * self.gamma
+
+         out, *rest = out
+         return out * self.gamma, *rest
+
  # feedforward

  class GLU(Module):
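
The new `LayerScale` multiplies a branch's output by a learned per-channel `gamma`, initialized to zero so each branch starts out contributing nothing to the residual stream. A minimal standalone sketch of that effect (not library code):

    import torch
    from torch import nn

    dim = 8
    branch = nn.Linear(dim, dim)               # stands in for an attention or feedforward block
    gamma = nn.Parameter(torch.zeros(dim))     # init_value = 0.

    x = torch.randn(2, 4, dim)
    out = x + gamma * branch(x)                # residual plus the scaled branch output

    assert torch.allclose(out, x)              # at init the branch is gated off entirely

In `AttentionLayers` this wrapper is applied to every branch when `use_layerscale = True` (see the hunks further down).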
@@ -817,7 +836,7 @@ class Attention(Module):
          mem = None,
          mem_mask = None,
          return_intermediates = False,
-         cache: Optional[Intermediates] = None,
+         cache: Intermediates | None = None,
      ):
          b, n, h, kv_h, head_scale, num_mem_kv, device, has_context = x.shape[0], x.shape[1], self.heads, self.kv_heads, self.head_scale, self.num_mem_kv, x.device, exists(context)

@@ -1024,11 +1043,11 @@ class AttentionLayers(Module):
          rotary_interpolation_factor = 1.,
          rotary_xpos_scale_base = 512,
          rotary_base_rescale_factor = 1.,
-         custom_layers = None,
+         weight_tie_layers = False,
+         custom_layers: Tuple[str] | None = None,
+         layers_execute_order: Tuple[int] | None = None,
          sandwich_coef = None,
          par_ratio = None,
-         weight_tie_layers = False, # Albert - https://arxiv.org/abs/1909.11942
-         layers_execute_order = None, # generalizes weight tying, can do arbitrary layer execution orders
          residual_attn = False,
          cross_residual_attn = False,
          macaron = False,
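
`weight_tie_layers` (Albert-style sharing, https://arxiv.org/abs/1909.11942) and the newly typed `custom_layers` / `layers_execute_order` remain ordinary `AttentionLayers` keywords. A hedged usage sketch, assuming as elsewhere that `Decoder` forwards unrecognized keywords to `AttentionLayers`:

    import torch
    from x_transformers import TransformerWrapper, Decoder

    model = TransformerWrapper(
        num_tokens = 256,
        max_seq_len = 1024,
        attn_layers = Decoder(
            dim = 512,
            depth = 12,                 # executed depth...
            heads = 8,
            weight_tie_layers = True    # ...with a single shared set of layer weights
        )
    )

    logits = model(torch.randint(0, 256, (1, 128)))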
@@ -1045,6 +1064,8 @@ class AttentionLayers(Module):
          layer_dropout = 0.,
          cross_attn_tokens_dropout = 0.,
          disable_abs_pos_emb = None,
+         use_layerscale = False,
+         layerscale_init_value = 0.,
          **kwargs
      ):
          super().__init__()
@@ -1108,6 +1129,8 @@ class AttentionLayers(Module):

          self.cross_attend = cross_attend

+         # determine norm
+
          assert (int(use_scalenorm) + int(use_rmsnorm) + int(use_simple_rmsnorm)) <= 1, 'you can only use either scalenorm, rmsnorm, or simple rmsnorm'

          if use_scalenorm:
@@ -1121,6 +1144,8 @@ class AttentionLayers(Module):

          norm_fn = partial(norm_class, dim)

+         # determine default block layer type order
+
          if cross_attend and not only_cross:
              default_block = ('a', 'c', 'f')
          elif cross_attend and only_cross:
@@ -1131,6 +1156,13 @@ class AttentionLayers(Module):
          if macaron:
              default_block = ('f',) + default_block

+         # determine post branch wrapper
+
+         post_branch_fn = None
+
+         if use_layerscale:
+             post_branch_fn = partial(LayerScale, dim = dim, init_value = layerscale_init_value)
+
          # zero init

          if zero_init_branch_output:
@@ -1178,11 +1210,9 @@ class AttentionLayers(Module):

          self.num_attn_layers = len(list(filter(equals('a'), layer_types)))

-         # validate and set the depth
+         # set the depth

          depth = default(depth, len(self.layers_execute_order))
-         assert depth == len(self.layers_execute_order)
-
          self.depth = depth

          # stochastic depth
@@ -1221,6 +1251,9 @@ class AttentionLayers(Module):
                  shift_range_lower = -layer_shift_tokens if not causal else 0
                  layer = ShiftTokens(range(shift_range_lower, shift_range_upper), layer)

+             if exists(post_branch_fn):
+                 layer = post_branch_fn(layer)
+
              residual_fn = GRUGating if gate_residual else Residual
              residual = residual_fn(dim, scale_residual = scale_residual, scale_residual_constant = scale_residual_constant)

@@ -1250,8 +1283,8 @@ class AttentionLayers(Module):
          self_attn_kv_mask = None,
          mems = None,
          mem_masks = None,
-         seq_start_pos: Optional[Tensor] = None,
-         cache: Optional[LayerIntermediates] = None,
+         seq_start_pos: Tensor | None = None,
+         cache: LayerIntermediates | None = None,
          cache_age = 1,
          return_hiddens = False,
          rotary_pos_emb = None
@@ -1643,7 +1676,7 @@ class TransformerWrapper(Module):
          return_attn_z_loss = False,
          attn_z_loss_weight = 1e-4,
          seq_start_pos = None,
-         cache: Optional[LayerIntermediates] = None,
+         cache: LayerIntermediates | None = None,
          **kwargs
      ):
          b, n, device, num_mems, has_memory_tokens, emb_frac_gradient = x.shape[0], x.shape[1], x.device, self.num_memory_tokens, self.num_memory_tokens > 0, self.emb_frac_gradient
x_transformers/xval.py CHANGED
@@ -176,11 +176,11 @@ class XValTransformerWrapper(nn.Module):

          if return_mems:
              hiddens = intermediates.hiddens
-             new_mems = list(map(lambda t: t[..., -self.max_mem_len:, :].detach(), hiddens))
+             new_mems = tuple(t[..., -self.max_mem_len:, :].detach() for t in hiddens)
              return out, new_mems

          if return_attn:
-             attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
+             attn_maps = tuple(t.post_softmax_attn for t in intermediates.attn_intermediates)
              return out, attn_maps

          return out
x_transformers-1.29.1.dist-info/METADATA → x_transformers-1.30.0.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: x-transformers
- Version: 1.29.1
+ Version: 1.30.0
  Summary: X-Transformers - Pytorch
  Home-page: https://github.com/lucidrains/x-transformers
  Author: Phil Wang
@@ -14,6 +14,6 @@ Classifier: License :: OSI Approved :: MIT License
  Classifier: Programming Language :: Python :: 3.6
  Description-Content-Type: text/markdown
  License-File: LICENSE
- Requires-Dist: torch >=1.6
- Requires-Dist: einops >=0.7.0
+ Requires-Dist: torch >=2.0
+ Requires-Dist: einops >=0.8.0

x_transformers-1.29.1.dist-info/RECORD → x_transformers-1.30.0.dist-info/RECORD RENAMED
@@ -0,0 +1,14 @@
+ x_transformers/__init__.py,sha256=8LQl-dNL6vj8VHRx5LMSOlRDTXQvYOuM21PDXz8WdiI,703
+ x_transformers/attend.py,sha256=Y9eE26I7BM8rGveabhiRhzw_xq9TY61Sp10QC1hX2O8,10192
+ x_transformers/autoregressive_wrapper.py,sha256=uX8Mb0zLsQrZECt_9UGt35g7tC05Rk3nPqO6xp2FFCc,9619
+ x_transformers/continuous.py,sha256=WO52n9lFAXv5-SGadi2cApGF8dkouN8QSTEOuC7erj8,6180
+ x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
+ x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
+ x_transformers/x_transformers.py,sha256=O-W-xnDhh4t3CKWDhmMJNobbIPQ2Vg24CJC8JBDou2M,65970
+ x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
+ x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
+ x_transformers-1.30.0.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+ x_transformers-1.30.0.dist-info/METADATA,sha256=Pd-ClX7tybO-h7mfYo2B8ncQM3jU3nQjf2yG1QBpORw,661
+ x_transformers-1.30.0.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+ x_transformers-1.30.0.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+ x_transformers-1.30.0.dist-info/RECORD,,
@@ -1,14 +0,0 @@
- x_transformers/__init__.py,sha256=8LQl-dNL6vj8VHRx5LMSOlRDTXQvYOuM21PDXz8WdiI,703
- x_transformers/attend.py,sha256=L7vctHJ0PnECohu4cUu8yvY8cUrVyJxHmMFR0RGL0z4,10163
- x_transformers/autoregressive_wrapper.py,sha256=gYKIN5Rm8dMYSTX5yHpg9sPYyZf9rsRTJCNrYRdJ-Ww,9618
- x_transformers/continuous.py,sha256=dpHK4NSMDQAJQ_N3Uj9rip0fYGXyu0QCCO_OfEdbRGs,6192
- x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
- x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
- x_transformers/x_transformers.py,sha256=jj87ALpQpHGgvG1oHn4Z6UDmc1pqkoO6dY7YtY038w8,65269
- x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
- x_transformers/xval.py,sha256=EN3hxxleTRGYeAz6i4x3U_PrOm9TjxMF3eDhMKGx59E,8575
- x_transformers-1.29.1.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
- x_transformers-1.29.1.dist-info/METADATA,sha256=4Nnxc5THUI-d21Szj2mPLTlZYF0A9xVjHN4laFiLCIE,661
- x_transformers-1.29.1.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
- x_transformers-1.29.1.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
- x_transformers-1.29.1.dist-info/RECORD,,