x-transformers 1.35.3 → 1.37.0 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
x_transformers/autoregressive_wrapper.py

@@ -317,7 +317,9 @@ class AutoregressiveWrapper(Module):
             **kwargs
         )

-        loss = F.cross_entropy(
+        loss_fn = F.cross_entropy if not self.net.is_log_prob else F.nll_loss
+
+        loss = loss_fn(
             rearrange(logits, 'b n c -> b c n'),
             target,
             ignore_index = ignore_index
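The loss switch above pairs with the new `mixture_of_softmax` option on `TransformerWrapper` (added later in this diff): when that option is on, the wrapped net already returns log-probabilities (`is_log_prob = True`), so the wrapper must use `F.nll_loss` instead of `F.cross_entropy`. A minimal usage sketch, assuming the library's usual `TransformerWrapper` / `Decoder` / `AutoregressiveWrapper` setup; the dimensions and token counts here are arbitrary:

```python
import torch
from x_transformers import TransformerWrapper, Decoder, AutoregressiveWrapper

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    mixture_of_softmax = True,   # net now outputs log-probabilities, so is_log_prob is True
    attn_layers = Decoder(dim = 512, depth = 6, heads = 8)
)

wrapper = AutoregressiveWrapper(model)

seq = torch.randint(0, 20000, (1, 1024))
loss = wrapper(seq)              # internally picks F.nll_loss instead of F.cross_entropy
loss.backward()
```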
x_transformers/x_transformers.py

@@ -1,7 +1,8 @@
 from __future__ import annotations
+from typing import Callable

 import math
-from random import random
+from random import random, randrange
 from packaging import version

 import torch
@@ -12,8 +13,8 @@ from torch.amp import autocast

 from functools import partial, wraps
 from collections import namedtuple
+from contextlib import nullcontext
 from dataclasses import dataclass
-from typing import List, Dict, Tuple, Callable

 from einops import rearrange, repeat, reduce, pack, unpack
 from einops.layers.torch import Rearrange
@@ -27,14 +28,16 @@ DEFAULT_DIM_HEAD = 64

 @dataclass
 class LayerIntermediates:
-    hiddens: List[Tensor] | None = None  # all hiddens, before the final norm (in pre-norm architecture)
+    hiddens: list[Tensor] | None = None  # all hiddens, before the final norm (in pre-norm architecture)
     last_hidden: Tensor | None = None    # very last hidden after all attention layers, after the final norm
-    attn_intermediates: List[Intermediates] | None = None
-    layer_hiddens: List[Tensor] | None = None
+    attn_intermediates: list[Intermediates] | None = None
+    layer_hiddens: list[Tensor] | None = None
     attn_z_loss: Tensor | None = None
     mems: Tensor | None = None
     memory_tokens: Tensor | None = None

+LinearNoBias = partial(nn.Linear, bias = False)
+
 # helpers

 def exists(val):
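`LinearNoBias` is simply a `functools.partial` over `nn.Linear` with `bias = False`, introduced so the many bias-free projections replaced throughout this diff read more succinctly. A quick illustration of the equivalence (shapes here are arbitrary):

```python
import torch
from torch import nn
from functools import partial

LinearNoBias = partial(nn.Linear, bias = False)   # same helper as defined above

proj = LinearNoBias(512, 256)

assert proj.bias is None                          # the bias term is gone
assert proj.weight.shape == (256, 512)            # identical weight layout to nn.Linear(512, 256, bias = False)
```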
@@ -91,6 +94,9 @@ def Sequential(*modules):

 # tensor helpers

+def log(t, eps = 1e-20):
+    return t.clamp(min = eps).log()
+
 def max_neg_value(tensor):
     return -torch.finfo(tensor.dtype).max

@@ -113,7 +119,7 @@ def masked_mean(t, mask = None, dim = 1):
     den = mask.sum(dim = dim).clamp(min = 1.)
     return num / den

-def pad_at_dim(t, pad: Tuple[int, int], dim = -1, value = 0.):
+def pad_at_dim(t, pad: tuple[int, int], dim = -1, value = 0.):
     if pad == (0, 0):
         return t

@@ -130,7 +136,7 @@ def or_reduce(masks):
 # auxiliary loss helpers

 def calc_z_loss(
-    pre_softmax_attns: List[Tensor],
+    pre_softmax_attns: list[Tensor],
     mask = None,
     weight = 1.
 ):
@@ -610,7 +616,7 @@ class AdaptiveLayerNorm(Module):
         dim_condition = default(dim_condition, dim)

         self.ln = nn.LayerNorm(dim, elementwise_affine = False)
-        self.to_gamma = nn.Linear(dim_condition, dim, bias = False)
+        self.to_gamma = LinearNoBias(dim_condition, dim)
         nn.init.zeros_(self.to_gamma.weight)

     def forward(self, x, *, condition):
@@ -665,7 +671,7 @@ class AdaptiveRMSNorm(Module):
         self.scale = dim ** 0.5
         dim_condition = default(dim_condition, dim)

-        self.to_gamma = nn.Linear(dim_condition, dim, bias = False)
+        self.to_gamma = LinearNoBias(dim_condition, dim)
         nn.init.zeros_(self.to_gamma.weight)

     def forward(self, x, *, condition):
@@ -748,7 +754,7 @@ class ShiftTokens(Module):
         feats_per_shift = x.shape[-1] // segments
         splitted = x.split(feats_per_shift, dim = -1)
         segments_to_shift, rest = splitted[:segments], splitted[segments:]
-        segments_to_shift = list(map(lambda args: shift(*args, mask = mask), zip(segments_to_shift, shifts)))
+        segments_to_shift = [shift(*args, mask = mask) for args in zip(segments_to_shift, shifts)]
         x = torch.cat((*segments_to_shift, *rest), dim = -1)
         return self.fn(x, **kwargs)

@@ -816,7 +822,7 @@ class ConcatCombine(Module):
     def __init__(self, dim, prev_layer_ind):
         super().__init__()
         self.prev_layer_ind = prev_layer_ind
-        self.combine = nn.Linear(dim * 2, dim, bias = False)
+        self.combine = LinearNoBias(dim * 2, dim)

     def forward(self, x, prev_layers: list[Tensor]):
         skip = prev_layers[self.prev_layer_ind]
@@ -956,17 +962,17 @@ class Attention(Module):
         v_dim = value_dim_head * kv_heads
         out_dim = value_dim_head * heads

-        self.to_q = nn.Linear(dim, q_dim, bias = False)
-        self.to_k = nn.Linear(dim_kv, k_dim, bias = False)
+        self.to_q = LinearNoBias(dim, q_dim)
+        self.to_k = LinearNoBias(dim_kv, k_dim)

         # shared key / values, for further memory savings during inference

         assert not (shared_kv and value_dim_head != dim_head), 'key and value head dimensions must be equal for shared key / values'
-        self.to_v = nn.Linear(dim_kv, v_dim, bias = False) if not shared_kv else None
+        self.to_v = LinearNoBias(dim_kv, v_dim) if not shared_kv else None

         # relations projection from tp-attention

-        self.to_r = nn.Linear(dim, v_dim, bias = False) if tensor_product else None
+        self.to_r = LinearNoBias(dim, v_dim) if tensor_product else None

         # add GLU gating for aggregated values, from alphafold2

@@ -1062,7 +1068,7 @@ class Attention(Module):
         # output dimension by default same as input, but can be overridden

         dim_out = default(dim_out, dim)
-        self.to_out = nn.Sequential(nn.Linear(out_dim, dim_out * 2, bias = False), nn.GLU()) if on_attn else nn.Linear(out_dim, dim_out, bias = False)
+        self.to_out = nn.Sequential(LinearNoBias(out_dim, dim_out * 2), nn.GLU()) if on_attn else LinearNoBias(out_dim, dim_out)

         # whether to rotate positions into values, for absolute positions in addition to relative

@@ -1108,7 +1114,7 @@ class Attention(Module):

         q = rearrange(q, 'b n (h d) -> b h n d', h = h)

-        k, v, r = map(lambda t: maybe(rearrange)(t, 'b n (h d) -> b h n d', h = kv_h), (k, v, r))
+        k, v, r = tuple(maybe(rearrange)(t, 'b n (h d) -> b h n d', h = kv_h) for t in (k, v, r))

         if exists(cache):
             ck, cv = cache.cached_kv
@@ -1163,12 +1169,12 @@ class Attention(Module):

         # i, j determined for relative positional bias, excluding memory key / values

-        i, j = map(lambda t: t.shape[-2], (q, k))
+        i, j = tuple(t.shape[-2] for t in (q, k))

         # maybe append memory key / values

         if num_mem_kv > 0:
-            mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b = b), (self.mem_k, self.mem_v))
+            mem_k, mem_v = tuple(repeat(t, 'h n d -> b h n d', b = b) for t in (self.mem_k, self.mem_v))

             if self.qk_norm:
                 mem_k = l2norm(mem_k)
@@ -1301,8 +1307,8 @@ class AttentionLayers(Module):
         rotary_xpos_scale_base = 512,
         rotary_base_rescale_factor = 1.,
         weight_tie_layers = False,
-        custom_layers: Tuple[str, ...] | None = None,
-        layers_execute_order: Tuple[int, ...] | None = None,
+        custom_layers: tuple[str, ...] | None = None,
+        layers_execute_order: tuple[int, ...] | None = None,
         sandwich_coef = None,
         par_ratio = None,
         residual_attn = False,
@@ -1463,7 +1469,7 @@ class AttentionLayers(Module):

         if self.need_condition and adaptive_condition_mlp:
             self.adaptive_mlp = nn.Sequential(
-                nn.Linear(dim_condition, dim_condition * dim_condition_mult, bias = False),
+                LinearNoBias(dim_condition, dim_condition * dim_condition_mult),
                 nn.SiLU()
             )

@@ -1634,7 +1640,7 @@ class AttentionLayers(Module):
         return_hiddens = False,
         rotary_pos_emb = None,
         condition = None,
-        layers_execute_order: Tuple[int, ...] | None = None
+        layers_execute_order: tuple[int, ...] | None = None
     ):
         assert not (self.cross_attend ^ exists(context)), 'context must be passed in if cross_attend is set to True'
         assert not (exists(condition) ^ self.need_condition), 'condition needs to be passed in if using adaptive layernorm or vice versa'
@@ -1972,7 +1978,7 @@ class TransformerWrapper(Module):
         num_tokens,
         max_seq_len,
         attn_layers: AttentionLayers,
-        embed_num_tokens: Dict[str, int] = dict(),
+        embed_num_tokens: dict[str, int] = dict(),
         emb_dim = None,
         max_mem_len = 0,
         shift_mem_down = 0,
@@ -1987,12 +1993,16 @@ class TransformerWrapper(Module):
         use_abs_pos_emb = True,
         scaled_sinu_pos_emb = False,
         l2norm_embed = False,
-        emb_frac_gradient = 1., # GLM-130B and Cogview successfully used this, set at 0.1
+        recycling = False,              # from Jumper et al. - Alphafold2
+        train_max_recycle_steps = 4,    # saw a benefit for language modeling up to 3 recycling steps, so let's default this to 4
+        emb_frac_gradient = 1.,         # GLM-130B and Cogview successfully used this, set at 0.1
         attn_z_loss_weight = 1e-4,
         average_pool_embed = False,
         use_cls_token = False,
         squeeze_out_last_dim = False,
         token_emb: TokenEmbedding | None = None,
+        mixture_of_softmax = False,
+        mixture_of_softmax_k = 4,
     ):
         super().__init__()

@@ -2044,6 +2054,13 @@ class TransformerWrapper(Module):

         assert at_most_one_of(average_pool_embed, use_cls_token)

+        # maybe recycling
+
+        self.recycling = recycling
+        self.recycled_proj = LinearNoBias(dim, dim) if recycling else None
+
+        self.train_max_recycle_steps = train_max_recycle_steps
+
         # classic cls token from the bert days

         self.cls_token = None
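The new `recycling` option follows the recycling idea from Jumper et al. (AlphaFold2): the attention layers are run several times, each pass after the first receiving a projection of the previous pass's (detached) output added back onto the embedded input, with gradients flowing only through the final pass. A minimal usage sketch under the usual `TransformerWrapper` / `Decoder` setup (dimensions arbitrary): during training the number of passes is sampled from `1..train_max_recycle_steps`, while at inference `recycle_steps` must be passed explicitly.

```python
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    recycling = True,               # enable AlphaFold2-style recycling of the hidden states
    train_max_recycle_steps = 4,    # training samples a step count in [1, 4] each forward
    attn_layers = Decoder(dim = 512, depth = 6, heads = 8)
)

x = torch.randint(0, 256, (2, 1024))

logits = model(x)                       # training mode: recycle step count chosen at random

model.eval()
logits = model(x, recycle_steps = 3)    # eval mode: the step count must be given explicitly
```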
@@ -2056,21 +2073,37 @@ class TransformerWrapper(Module):

         self.average_pool_embed = average_pool_embed

+        # output type
+
+        self.is_log_prob = mixture_of_softmax
+
+        self.to_mixture = None
+        self.combine_mixture = None
+
+        if mixture_of_softmax:
+            assert num_output_heads == 1
+
+            self.to_mixture = Sequential(
+                LinearNoBias(dim, dim * mixture_of_softmax_k),
+                Rearrange('... (k d) -> ... k d', k = mixture_of_softmax_k)
+            )
+
+            self.combine_mixture = LinearNoBias(dim, mixture_of_softmax_k)
+
         # output head, usually to logits of num_tokens

         logits_dim = default(logits_dim, num_tokens)

-        self.has_multiple_heads = False
+        self.has_multiple_heads = num_output_heads > 1

         if return_only_embed:
             self.to_logits = None
         elif tie_embedding:
             self.to_logits = lambda t: t @ self.token_emb.emb.weight.t()
         elif num_output_heads > 1:
-            self.has_multiple_heads = True
-            self.to_logits = ModuleList([nn.Linear(dim, logits_dim, bias = False) for _ in range(num_output_heads)])
+            self.to_logits = ModuleList([LinearNoBias(dim, logits_dim) for _ in range(num_output_heads)])
         else:
-            self.to_logits = nn.Linear(dim, logits_dim, bias = False)
+            self.to_logits = LinearNoBias(dim, logits_dim)

         # memory tokens (like [cls]) from Memory Transformers paper

@@ -2087,7 +2120,7 @@ class TransformerWrapper(Module):

         # whether can do cached kv decoding

-        self.can_cache_kv = self.num_memory_tokens == 0
+        self.can_cache_kv = self.num_memory_tokens == 0 and not recycling
         self.can_cache_kv_outside_max_seq_len = no_abs_pos_emb

     def init_(self):
@@ -2110,10 +2143,11 @@ class TransformerWrapper(Module):
         return_attn = False,
         mems = None,
         mem_masks = None,
+        recycle_steps = None,
         pos = None,
         prepend_embeds = None,
         prepend_mask = None,
-        embed_ids: Dict[str, Tensor] = dict(),
+        embed_ids: dict[str, Tensor] = dict(),
         sum_embeds = None,
         return_attn_z_loss = False,
         attn_z_loss_weight = 1e-4,
@@ -2215,11 +2249,37 @@ class TransformerWrapper(Module):
         if exists(mem_every):
             x = rearrange(x, '(b n) m d -> b (n m) d', b = b)

+        # handle maybe shifting of memories
+
         if self.shift_mem_down and exists(mems):
             mems_l, mems_r = mems[:self.shift_mem_down], mems[self.shift_mem_down:]
             mems = [*mems_r, *mems_l]

-        x, intermediates = self.attn_layers(x, mask = mask, mems = mems, mem_masks = mem_masks, cache = cache, return_hiddens = True, seq_start_pos = seq_start_pos, **kwargs)
+        # attention layers
+
+        if not self.recycling:
+            # regular
+
+            attended, intermediates = self.attn_layers(x, mask = mask, mems = mems, mem_masks = mem_masks, cache = cache, return_hiddens = True, seq_start_pos = seq_start_pos, **kwargs)
+
+        else:
+            # recycling
+
+            recycle_steps = default(recycle_steps, (randrange(self.train_max_recycle_steps) + 1) if self.training else None)
+            assert exists(recycle_steps) and recycle_steps > 0, '`recycle_steps` must be provided on forward if recycling is turned on and not training'
+
+            for i in range(recycle_steps):
+                first_step = i == 0
+                last_step = i == (recycle_steps - 1)
+
+                context = nullcontext if last_step else torch.no_grad
+
+                with context():
+                    maybe_recycled = self.recycled_proj(attended.detach()) if not first_step else 0.
+
+                    attended, intermediates = self.attn_layers(x + maybe_recycled, mask = mask, mems = mems, mem_masks = mem_masks, cache = cache, return_hiddens = True, seq_start_pos = seq_start_pos, **kwargs)
+
+        x = attended

         # handle memories post-attention

@@ -2244,6 +2304,14 @@ class TransformerWrapper(Module):
         if exists(self.cls_token):
             x, _ = unpack(x, cls_packed_shape, 'b * d')

+        # handle expansion to mixture if needed (for mixture of softmax)
+
+        combine_mixture = None
+
+        if exists(self.to_mixture):
+            combine_mixture = self.combine_mixture(x).softmax(dim = -1)
+            x = self.to_mixture(x)
+
         # projecting to logits

         if not return_embeddings:
@@ -2252,6 +2320,14 @@ class TransformerWrapper(Module):
         else:
             logits = self.to_logits(x)

+        # handle maybe combine mixture
+
+        if exists(combine_mixture):
+            with autocast('cuda', enabled = False):
+                prob = logits.softmax(dim = -1)
+                mos = einsum('... k d, ... k -> ... d', prob, combine_mixture)
+                logits = log(mos)
+
         # maybe squeeze out last dimension of logits

         if self.squeeze_out_last_dim:
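The block above implements the mixture-of-softmaxes output: `to_mixture` expands each position into `k` component states, every component gets its own softmax over the vocabulary, and `combine_mixture` supplies position-dependent mixture weights; the final output is the log of the weighted sum, which is why the autoregressive wrapper switches to `nll_loss`. A standalone sketch of just that combination step, with made-up shapes (the `log` helper mirrors the one added earlier in this diff, clamping before the log so exact zeros stay finite):

```python
import torch
from torch import einsum

def log(t, eps = 1e-20):
    # clamp keeps a probability that underflowed to 0 from becoming -inf
    return t.clamp(min = eps).log()

b, n, k, num_tokens = 2, 8, 4, 100

logits = torch.randn(b, n, k, num_tokens)           # per-component logits: (batch, seq, k, vocab)
combine = torch.randn(b, n, k).softmax(dim = -1)    # mixture weights:      (batch, seq, k)

prob = logits.softmax(dim = -1)                               # one softmax per component
mixed = einsum('... k d, ... k -> ... d', prob, combine)      # weighted sum over the k components
log_probs = log(mixed)                                        # returned in place of raw logits

# the mixture is still a proper distribution over the vocabulary
assert torch.allclose(log_probs.exp().sum(dim = -1), torch.ones(b, n), atol = 1e-5)
```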
@@ -2272,14 +2348,14 @@ class TransformerWrapper(Module):
         # aux loss

         if return_attn_z_loss:
-            pre_softmax_attns = list(map(lambda t: t.pre_softmax_attn, intermediates.attn_intermediates))
+            pre_softmax_attns = [t.pre_softmax_attn for t in intermediates.attn_intermediates]
             intermediates.attn_z_loss = calc_z_loss(pre_softmax_attns, weight = attn_z_loss_weight)
             return_intermediates = True

         if return_mems:
             hiddens = intermediates.hiddens
-            new_mems = list(map(lambda pair: torch.cat(pair, dim = -2), zip(mems, hiddens))) if exists(mems) else hiddens
-            new_mems = list(map(lambda t: t[..., -self.max_mem_len:, :].detach(), new_mems))
+            new_mems = [torch.cat(pair, dim = -2) for pair in zip(mems, hiddens)] if exists(mems) else hiddens
+            new_mems = [t[..., -self.max_mem_len:, :].detach() for t in new_mems]

             if not return_intermediates:
                 return out, new_mems
@@ -2290,7 +2366,7 @@ class TransformerWrapper(Module):
             return out, intermediates

         if return_attn:
-            attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
+            attn_maps = [t.post_softmax_attn for t in intermediates.attn_intermediates]
             return out, attn_maps

         return out
x_transformers-1.37.0.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.35.3
+Version: 1.37.0
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
x_transformers-1.37.0.dist-info/RECORD

@@ -1,15 +1,15 @@
 x_transformers/__init__.py,sha256=-MkQrSc37cTVDX7AOykxunYnqVtFlQ7lb0Cse5dsGWU,793
 x_transformers/attend.py,sha256=7q996VGYHGIsc0FQnN8WNiwHn3xny3i1biRwx7yW5vg,12090
-x_transformers/autoregressive_wrapper.py,sha256=pDymmnPgWQoH7wwHKskI_gktsdQX-LysnQtIozodYGU,10422
+x_transformers/autoregressive_wrapper.py,sha256=2FN4ZobFcdDGDGWEnUof_geb16dRGSJycZGwG899Pa4,10493
 x_transformers/continuous.py,sha256=cIVEdhfei258__ziV7kQBrJMxCel54bExBTDrO9rfCI,6450
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
-x_transformers/x_transformers.py,sha256=ma5_LbZf5UvfKYJUJcqceUdFG8THFVzER9ZrDXKVV7Y,80780
+x_transformers/x_transformers.py,sha256=ztP6nNncVoPONR-al5lHIphAJQqNcE0mrT6tFWsnyPk,83281
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.35.3.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.35.3.dist-info/METADATA,sha256=YEiRJvu5g17ZVT3saNBhrmpNeRLqPXyN0cBdajt3psM,661
-x_transformers-1.35.3.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
-x_transformers-1.35.3.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.35.3.dist-info/RECORD,,
+x_transformers-1.37.0.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.37.0.dist-info/METADATA,sha256=S8fQ4scePXn4pMl1_01cyWU8_3UXXBlLczibRSFuOoM,661
+x_transformers-1.37.0.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
+x_transformers-1.37.0.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.37.0.dist-info/RECORD,,