x-transformers 1.36.0__py3-none-any.whl → 1.37.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as published in their public registry.
@@ -317,7 +317,9 @@ class AutoregressiveWrapper(Module):
             **kwargs
         )

-        loss = F.cross_entropy(
+        loss_fn = F.cross_entropy if not self.net.is_log_prob else F.nll_loss
+
+        loss = loss_fn(
             rearrange(logits, 'b n c -> b c n'),
             target,
             ignore_index = ignore_index
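The loss switch tracks the new `is_log_prob` flag on the wrapped network: with mixture of softmax enabled, `TransformerWrapper` already returns log-probabilities, so `F.cross_entropy` (which applies its own log-softmax) would normalize twice, whereas `F.nll_loss` consumes log-probabilities directly. A minimal standalone check of that equivalence (shapes are illustrative, not taken from the package):

import torch
import torch.nn.functional as F

logits = torch.randn(2, 5, 10)            # (batch, seq, num_tokens)
target = torch.randint(0, 10, (2, 5))

# cross_entropy on raw logits ...
loss_a = F.cross_entropy(logits.transpose(1, 2), target)

# ... matches nll_loss on explicit log-probabilities
loss_b = F.nll_loss(logits.log_softmax(dim = -1).transpose(1, 2), target)

assert torch.allclose(loss_a, loss_b)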
@@ -1,4 +1,5 @@
 from __future__ import annotations
+from typing import Callable

 import math
 from random import random, randrange
@@ -14,7 +15,6 @@ from functools import partial, wraps
 from collections import namedtuple
 from contextlib import nullcontext
 from dataclasses import dataclass
-from typing import List, Dict, Tuple, Callable

 from einops import rearrange, repeat, reduce, pack, unpack
 from einops.layers.torch import Rearrange
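These two import hunks trade the `typing` generics (`List`, `Dict`, `Tuple`) for the builtin ones (`list`, `dict`, `tuple`) throughout the module; only `Callable` is still needed from `typing`, so it moves to its own import at the top. Because the file already uses `from __future__ import annotations`, annotations are never evaluated at runtime, which makes the builtin generics safe even on interpreters that predate PEP 585 subscripting. A small sketch of the pattern, reusing a signature that appears later in this diff:

from __future__ import annotations   # PEP 563: annotations are stored as strings, never evaluated

# with lazy annotations, builtin generics are fine in signatures
# even where tuple[int, int] is not subscriptable at runtime
def pad_at_dim(t, pad: tuple[int, int], dim = -1, value = 0.):
    ...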
@@ -28,14 +28,16 @@ DEFAULT_DIM_HEAD = 64

 @dataclass
 class LayerIntermediates:
-    hiddens: List[Tensor] | None = None                # all hiddens, before the final norm (in pre-norm architecture)
+    hiddens: list[Tensor] | None = None                # all hiddens, before the final norm (in pre-norm architecture)
     last_hidden: Tensor | None = None                  # very last hidden after all attention layers, after the final norm
-    attn_intermediates: List[Intermediates] | None = None
-    layer_hiddens: List[Tensor] | None = None
+    attn_intermediates: list[Intermediates] | None = None
+    layer_hiddens: list[Tensor] | None = None
     attn_z_loss: Tensor | None = None
     mems: Tensor | None = None
     memory_tokens: Tensor | None = None

+LinearNoBias = partial(nn.Linear, bias = False)
+
 # helpers

 def exists(val):
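`LinearNoBias` is a module-level `functools.partial` over `nn.Linear` with `bias = False`; the many `nn.Linear(..., bias = False)` call sites below are rewritten to use it, with no behavioral change. A quick sketch (sizes are illustrative):

from functools import partial

import torch
from torch import nn

LinearNoBias = partial(nn.Linear, bias = False)

proj = LinearNoBias(512, 256)            # equivalent to nn.Linear(512, 256, bias = False)
assert proj.bias is None
print(proj(torch.randn(4, 512)).shape)   # torch.Size([4, 256])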
@@ -92,6 +94,9 @@ def Sequential(*modules):

 # tensor helpers

+def log(t, eps = 1e-20):
+    return t.clamp(min = eps).log()
+
 def max_neg_value(tensor):
     return -torch.finfo(tensor.dtype).max

@@ -114,7 +119,7 @@ def masked_mean(t, mask = None, dim = 1):
     den = mask.sum(dim = dim).clamp(min = 1.)
     return num / den

-def pad_at_dim(t, pad: Tuple[int, int], dim = -1, value = 0.):
+def pad_at_dim(t, pad: tuple[int, int], dim = -1, value = 0.):
     if pad == (0, 0):
         return t

@@ -131,7 +136,7 @@ def or_reduce(masks):
 # auxiliary loss helpers

 def calc_z_loss(
-    pre_softmax_attns: List[Tensor],
+    pre_softmax_attns: list[Tensor],
     mask = None,
     weight = 1.
 ):
@@ -611,7 +616,7 @@ class AdaptiveLayerNorm(Module):
         dim_condition = default(dim_condition, dim)

         self.ln = nn.LayerNorm(dim, elementwise_affine = False)
-        self.to_gamma = nn.Linear(dim_condition, dim, bias = False)
+        self.to_gamma = LinearNoBias(dim_condition, dim)
         nn.init.zeros_(self.to_gamma.weight)

     def forward(self, x, *, condition):
@@ -666,7 +671,7 @@ class AdaptiveRMSNorm(Module):
         self.scale = dim ** 0.5
         dim_condition = default(dim_condition, dim)

-        self.to_gamma = nn.Linear(dim_condition, dim, bias = False)
+        self.to_gamma = LinearNoBias(dim_condition, dim)
         nn.init.zeros_(self.to_gamma.weight)

     def forward(self, x, *, condition):
@@ -749,7 +754,7 @@ class ShiftTokens(Module):
         feats_per_shift = x.shape[-1] // segments
         splitted = x.split(feats_per_shift, dim = -1)
         segments_to_shift, rest = splitted[:segments], splitted[segments:]
-        segments_to_shift = list(map(lambda args: shift(*args, mask = mask), zip(segments_to_shift, shifts)))
+        segments_to_shift = [shift(*args, mask = mask) for args in zip(segments_to_shift, shifts)]
         x = torch.cat((*segments_to_shift, *rest), dim = -1)
         return self.fn(x, **kwargs)

@@ -817,7 +822,7 @@ class ConcatCombine(Module):
     def __init__(self, dim, prev_layer_ind):
         super().__init__()
         self.prev_layer_ind = prev_layer_ind
-        self.combine = nn.Linear(dim * 2, dim, bias = False)
+        self.combine = LinearNoBias(dim * 2, dim)

     def forward(self, x, prev_layers: list[Tensor]):
         skip = prev_layers[self.prev_layer_ind]
@@ -957,17 +962,17 @@ class Attention(Module):
         v_dim = value_dim_head * kv_heads
         out_dim = value_dim_head * heads

-        self.to_q = nn.Linear(dim, q_dim, bias = False)
-        self.to_k = nn.Linear(dim_kv, k_dim, bias = False)
+        self.to_q = LinearNoBias(dim, q_dim)
+        self.to_k = LinearNoBias(dim_kv, k_dim)

         # shared key / values, for further memory savings during inference

         assert not (shared_kv and value_dim_head != dim_head), 'key and value head dimensions must be equal for shared key / values'
-        self.to_v = nn.Linear(dim_kv, v_dim, bias = False) if not shared_kv else None
+        self.to_v = LinearNoBias(dim_kv, v_dim) if not shared_kv else None

         # relations projection from tp-attention

-        self.to_r = nn.Linear(dim, v_dim, bias = False) if tensor_product else None
+        self.to_r = LinearNoBias(dim, v_dim) if tensor_product else None

         # add GLU gating for aggregated values, from alphafold2

@@ -1063,7 +1068,7 @@ class Attention(Module):
         # output dimension by default same as input, but can be overridden

         dim_out = default(dim_out, dim)
-        self.to_out = nn.Sequential(nn.Linear(out_dim, dim_out * 2, bias = False), nn.GLU()) if on_attn else nn.Linear(out_dim, dim_out, bias = False)
+        self.to_out = nn.Sequential(LinearNoBias(out_dim, dim_out * 2), nn.GLU()) if on_attn else LinearNoBias(out_dim, dim_out)

         # whether to rotate positions into values, for absolute positions in addition to relative

@@ -1109,7 +1114,7 @@ class Attention(Module):

         q = rearrange(q, 'b n (h d) -> b h n d', h = h)

-        k, v, r = map(lambda t: maybe(rearrange)(t, 'b n (h d) -> b h n d', h = kv_h), (k, v, r))
+        k, v, r = tuple(maybe(rearrange)(t, 'b n (h d) -> b h n d', h = kv_h) for t in (k, v, r))

         if exists(cache):
             ck, cv = cache.cached_kv
@@ -1164,12 +1169,12 @@ class Attention(Module):

         # i, j determined for relative positional bias, excluding memory key / values

-        i, j = map(lambda t: t.shape[-2], (q, k))
+        i, j = tuple(t.shape[-2] for t in (q, k))

         # maybe append memory key / values

         if num_mem_kv > 0:
-            mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b = b), (self.mem_k, self.mem_v))
+            mem_k, mem_v = tuple(repeat(t, 'h n d -> b h n d', b = b) for t in (self.mem_k, self.mem_v))

             if self.qk_norm:
                 mem_k = l2norm(mem_k)
@@ -1302,8 +1307,8 @@ class AttentionLayers(Module):
         rotary_xpos_scale_base = 512,
         rotary_base_rescale_factor = 1.,
         weight_tie_layers = False,
-        custom_layers: Tuple[str, ...] | None = None,
-        layers_execute_order: Tuple[int, ...] | None = None,
+        custom_layers: tuple[str, ...] | None = None,
+        layers_execute_order: tuple[int, ...] | None = None,
         sandwich_coef = None,
         par_ratio = None,
         residual_attn = False,
@@ -1464,7 +1469,7 @@ class AttentionLayers(Module):

         if self.need_condition and adaptive_condition_mlp:
             self.adaptive_mlp = nn.Sequential(
-                nn.Linear(dim_condition, dim_condition * dim_condition_mult, bias = False),
+                LinearNoBias(dim_condition, dim_condition * dim_condition_mult),
                 nn.SiLU()
             )

@@ -1635,7 +1640,7 @@ class AttentionLayers(Module):
         return_hiddens = False,
         rotary_pos_emb = None,
         condition = None,
-        layers_execute_order: Tuple[int, ...] | None = None
+        layers_execute_order: tuple[int, ...] | None = None
     ):
         assert not (self.cross_attend ^ exists(context)), 'context must be passed in if cross_attend is set to True'
         assert not (exists(condition) ^ self.need_condition), 'condition needs to be passed in if using adaptive layernorm or vice versa'
@@ -1973,7 +1978,7 @@ class TransformerWrapper(Module):
         num_tokens,
         max_seq_len,
         attn_layers: AttentionLayers,
-        embed_num_tokens: Dict[str, int] = dict(),
+        embed_num_tokens: dict[str, int] = dict(),
         emb_dim = None,
         max_mem_len = 0,
         shift_mem_down = 0,
@@ -1996,6 +2001,8 @@ class TransformerWrapper(Module):
         use_cls_token = False,
         squeeze_out_last_dim = False,
         token_emb: TokenEmbedding | None = None,
+        mixture_of_softmax = False,
+        mixture_of_softmax_k = 4,
     ):
         super().__init__()

@@ -2050,7 +2057,7 @@ class TransformerWrapper(Module):
         # maybe recycling

         self.recycling = recycling
-        self.recycled_proj = nn.Linear(dim, dim, bias = False) if recycling else None
+        self.recycled_proj = LinearNoBias(dim, dim) if recycling else None

         self.train_max_recycle_steps = train_max_recycle_steps

@@ -2066,21 +2073,37 @@ class TransformerWrapper(Module):

         self.average_pool_embed = average_pool_embed

+        # output type
+
+        self.is_log_prob = mixture_of_softmax
+
+        self.to_mixture = None
+        self.combine_mixture = None
+
+        if mixture_of_softmax:
+            assert num_output_heads == 1
+
+            self.to_mixture = Sequential(
+                LinearNoBias(dim, dim * mixture_of_softmax_k),
+                Rearrange('... (k d) -> ... k d', k = mixture_of_softmax_k)
+            )
+
+            self.combine_mixture = LinearNoBias(dim, mixture_of_softmax_k)
+
         # output head, usually to logits of num_tokens

         logits_dim = default(logits_dim, num_tokens)

-        self.has_multiple_heads = False
+        self.has_multiple_heads = num_output_heads > 1

         if return_only_embed:
             self.to_logits = None
         elif tie_embedding:
             self.to_logits = lambda t: t @ self.token_emb.emb.weight.t()
         elif num_output_heads > 1:
-            self.has_multiple_heads = True
-            self.to_logits = ModuleList([nn.Linear(dim, logits_dim, bias = False) for _ in range(num_output_heads)])
+            self.to_logits = ModuleList([LinearNoBias(dim, logits_dim) for _ in range(num_output_heads)])
         else:
-            self.to_logits = nn.Linear(dim, logits_dim, bias = False)
+            self.to_logits = LinearNoBias(dim, logits_dim)

         # memory tokens (like [cls]) from Memory Transformers paper

@@ -2124,7 +2147,7 @@ class TransformerWrapper(Module):
         pos = None,
         prepend_embeds = None,
         prepend_mask = None,
-        embed_ids: Dict[str, Tensor] = dict(),
+        embed_ids: dict[str, Tensor] = dict(),
         sum_embeds = None,
         return_attn_z_loss = False,
         attn_z_loss_weight = 1e-4,
@@ -2235,6 +2258,8 @@ class TransformerWrapper(Module):
         # attention layers

         if not self.recycling:
+            assert recycle_steps == 1, 'you did not train with recycling'
+
             # regular

             attended, intermediates = self.attn_layers(x, mask = mask, mems = mems, mem_masks = mem_masks, cache = cache, return_hiddens = True, seq_start_pos = seq_start_pos, **kwargs)
@@ -2281,6 +2306,14 @@ class TransformerWrapper(Module):
         if exists(self.cls_token):
             x, _ = unpack(x, cls_packed_shape, 'b * d')

+        # handle expansion to mixture if needed (for mixture of softmax)
+
+        combine_mixture = None
+
+        if exists(self.to_mixture):
+            combine_mixture = self.combine_mixture(x).softmax(dim = -1)
+            x = self.to_mixture(x)
+
         # projecting to logits

         if not return_embeddings:
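This hunk and the next implement the mixture-of-softmaxes output (cf. Yang et al., "Breaking the Softmax Bottleneck"): the final hidden state is expanded into `k` component states, each component is projected through the shared `to_logits` to its own distribution over the vocabulary, and a separate projection produces the weights that average the `k` distributions. A self-contained sketch of the shape flow, mirroring the modules added in the constructor hunk above (sizes are illustrative):

import torch
from torch import nn
from einops.layers.torch import Rearrange

dim, num_tokens, k = 512, 1000, 4     # illustrative sizes

to_mixture = nn.Sequential(
    nn.Linear(dim, dim * k, bias = False),
    Rearrange('... (k d) -> ... k d', k = k)        # expand hidden into k component states
)
combine_mixture = nn.Linear(dim, k, bias = False)   # produces the k mixture weights
to_logits = nn.Linear(dim, num_tokens, bias = False)

x = torch.randn(2, 16, dim)                         # (batch, seq, dim) final hidden states
weights = combine_mixture(x).softmax(dim = -1)      # (batch, seq, k)
components = to_logits(to_mixture(x))               # (batch, seq, k, num_tokens)

probs = components.softmax(dim = -1)                # k softmaxes over the vocabulary
mixed = torch.einsum('... k d, ... k -> ... d', probs, weights)   # (batch, seq, num_tokens)
print(mixed.shape)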
@@ -2289,6 +2322,14 @@ class TransformerWrapper(Module):
         else:
             logits = self.to_logits(x)

+        # handle maybe combine mixture
+
+        if exists(combine_mixture):
+            with autocast('cuda', enabled = False):
+                prob = logits.softmax(dim = -1)
+                mos = einsum('... k d, ... k -> ... d', prob, combine_mixture)
+                logits = log(mos)
+
         # maybe squeeze out last dimension of logits

         if self.squeeze_out_last_dim:
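The combination step runs with autocast disabled and converts the mixed probabilities back to log space with the `log` helper added near the top of the file; a plausible reading is that this keeps the softmax-average-log chain in full precision and that the clamp guards against `-inf` when a probability underflows to zero. A tiny illustration of the helper (values are made up):

import torch

def log(t, eps = 1e-20):
    # the clamp keeps the argument strictly positive, so the log never returns -inf
    return t.clamp(min = eps).log()

mixed = torch.tensor([0.0, 1e-30, 0.5])   # mixture probabilities, possibly underflowed
print(log(mixed))                         # finite everywhere: roughly [-46.05, -46.05, -0.69]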
@@ -2309,14 +2350,14 @@ class TransformerWrapper(Module):
         # aux loss

         if return_attn_z_loss:
-            pre_softmax_attns = list(map(lambda t: t.pre_softmax_attn, intermediates.attn_intermediates))
+            pre_softmax_attns = [t.pre_softmax_attn for t in intermediates.attn_intermediates]
             intermediates.attn_z_loss = calc_z_loss(pre_softmax_attns, weight = attn_z_loss_weight)
             return_intermediates = True

         if return_mems:
             hiddens = intermediates.hiddens
-            new_mems = list(map(lambda pair: torch.cat(pair, dim = -2), zip(mems, hiddens))) if exists(mems) else hiddens
-            new_mems = list(map(lambda t: t[..., -self.max_mem_len:, :].detach(), new_mems))
+            new_mems = [torch.cat(pair, dim = -2) for pair in zip(mems, hiddens)] if exists(mems) else hiddens
+            new_mems = [t[..., -self.max_mem_len:, :].detach() for t in new_mems]

             if not return_intermediates:
                 return out, new_mems
@@ -2327,7 +2368,7 @@ class TransformerWrapper(Module):
             return out, intermediates

         if return_attn:
-            attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
+            attn_maps = [t.post_softmax_attn for t in intermediates.attn_intermediates]
             return out, attn_maps

         return out
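Putting the pieces together, the feature is driven entirely from the `TransformerWrapper` constructor, and `AutoregressiveWrapper` picks the matching loss automatically via `is_log_prob`. A hedged usage sketch (hyperparameters are illustrative, not recommendations):

import torch
from x_transformers import TransformerWrapper, Decoder, AutoregressiveWrapper

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 6, heads = 8),
    mixture_of_softmax = True,       # output becomes log-probabilities
    mixture_of_softmax_k = 4
)

wrapped = AutoregressiveWrapper(model)

seq = torch.randint(0, 256, (2, 1025))
loss = wrapped(seq)                  # uses F.nll_loss internally because model.is_log_prob is True
loss.backward()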
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.36.0
+Version: 1.37.1
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
@@ -1,15 +1,15 @@
 x_transformers/__init__.py,sha256=-MkQrSc37cTVDX7AOykxunYnqVtFlQ7lb0Cse5dsGWU,793
 x_transformers/attend.py,sha256=7q996VGYHGIsc0FQnN8WNiwHn3xny3i1biRwx7yW5vg,12090
-x_transformers/autoregressive_wrapper.py,sha256=pDymmnPgWQoH7wwHKskI_gktsdQX-LysnQtIozodYGU,10422
+x_transformers/autoregressive_wrapper.py,sha256=2FN4ZobFcdDGDGWEnUof_geb16dRGSJycZGwG899Pa4,10493
 x_transformers/continuous.py,sha256=cIVEdhfei258__ziV7kQBrJMxCel54bExBTDrO9rfCI,6450
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
-x_transformers/x_transformers.py,sha256=iib15Squ9VE7tLpb4Z4_Hq_hi7dZhPNR_xPtC9BzMrE,82321
+x_transformers/x_transformers.py,sha256=9lk6wtz0vNigyLoMWleo442Q0mhce-BCxEhazhSHuvI,83356
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.36.0.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.36.0.dist-info/METADATA,sha256=YKcnT5T0UkZxwpP72cPfx9RN0SVoBYy0e6Xo581YCE0,661
-x_transformers-1.36.0.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
-x_transformers-1.36.0.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.36.0.dist-info/RECORD,,
+x_transformers-1.37.1.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.37.1.dist-info/METADATA,sha256=ik8UKwzq_pW9zdxCl6pt7POrjRC7_GwIi6gAnY7Fck0,661
+x_transformers-1.37.1.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
+x_transformers-1.37.1.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.37.1.dist-info/RECORD,,