x-transformers 1.29.2__py3-none-any.whl → 1.30.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
x_transformers/attend.py CHANGED
@@ -1,5 +1,7 @@
+ from __future__ import annotations
+
  from functools import partial
- from typing import Optional, Tuple
+ from typing import Tuple

  import torch
  from torch import nn, einsum, Tensor
@@ -16,10 +18,10 @@ from einops import rearrange, repeat

  @dataclass
  class Intermediates:
-     qk_similarities: Optional[Tensor] = None
-     pre_softmax_attn: Optional[Tensor] = None
-     post_softmax_attn: Optional[Tensor] = None
-     cached_kv: Optional[Tuple[Tensor, Tensor]] = None
+     qk_similarities: Tensor | None = None
+     pre_softmax_attn: Tensor | None = None
+     post_softmax_attn: Tensor | None = None
+     cached_kv: Tuple[Tensor, Tensor] | None = None

      def to_tuple(self):
          return (self.qk_similarities, self.pre_softmax_attn, self.post_softmax_attn)
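Throughout this release the `Optional[...]`/`Union[...]` annotations are replaced with PEP 604 `X | None` unions, and every touched module gains `from __future__ import annotations`. A minimal sketch (not from the package) of why the two changes go together: with postponed evaluation of annotations (PEP 563), the `|` syntax inside annotations also works on interpreters older than Python 3.10.

```python
# Illustrative only; assumes Python >= 3.7.
from __future__ import annotations  # annotations become lazily evaluated strings (PEP 563)

from dataclasses import dataclass

@dataclass
class Example:
    value: int | None = None  # PEP 604 union in an annotation; no TypeError on 3.8/3.9

print(Example())                 # Example(value=None)
print(Example.__annotations__)   # {'value': 'int | None'}
```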
x_transformers/autoregressive_wrapper.py CHANGED
@@ -1,5 +1,7 @@
+ from __future__ import annotations
+
  from math import ceil, log
- from typing import Optional, Union, Tuple, Callable
+ from typing import Tuple, Callable

  import torch
  from torch import nn, Tensor
@@ -133,12 +135,12 @@ class AutoregressiveWrapper(Module):
          seq_len,
          eos_token = None,
          temperature = 1.,
-         prompt_lens: Optional[Tensor] = None,
+         prompt_lens: Tensor | None = None,
          filter_logits_fn: Callable = top_k,
          restrict_to_max_seq_len = True,
-         amateur_model: Optional[Union[Module, Tuple[Module]]] = None,
+         amateur_model: Module | Tuple[Module] | None = None,
          filter_kwargs: dict = dict(),
-         contrastive_decode_kwargs: Union[dict, Tuple[dict]] = dict(
+         contrastive_decode_kwargs: dict | Tuple[dict] = dict(
              beta = 0.5,
              alpha = 0.1
          ),
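The `generate` signature above now annotates `prompt_lens`, `amateur_model`, and `contrastive_decode_kwargs` with the union syntax; behaviour is unchanged. A hedged, README-style usage sketch — the model sizes, token counts, and prompt lengths are arbitrary assumptions, not taken from the diff:

```python
import torch
from x_transformers import TransformerWrapper, Decoder, AutoregressiveWrapper

model = AutoregressiveWrapper(TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 4, heads = 8)
))

prompts = torch.randint(0, 256, (2, 8))   # (batch, prompt length)

# prompt_lens: Tensor | None lets prompts of different lengths share one padded batch
sampled = model.generate(
    prompts,
    seq_len = 64,
    temperature = 1.,
    prompt_lens = torch.tensor([8, 5]),
)
# sampled: (batch, seq_len) of newly generated token ids (prompt not included)
```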
x_transformers/continuous.py CHANGED
@@ -143,11 +143,11 @@ class ContinuousTransformerWrapper(nn.Module):

          if return_mems:
              hiddens = intermediates.hiddens
-             new_mems = list(map(lambda t: t[..., -self.max_mem_len:, :].detach(), hiddens))
+             new_mems = tuple(t[..., -self.max_mem_len:, :].detach() for t in hiddens)
              return out, new_mems

          if return_attn:
-             attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
+             attn_maps = tuple(t.post_softmax_attn for t in intermediates.attn_intermediates)
              return out, attn_maps

          return out
x_transformers/x_transformers.py CHANGED
@@ -1,3 +1,5 @@
+ from __future__ import annotations
+
  import math
  from random import random
  from packaging import version
@@ -11,7 +13,7 @@ from torch.cuda.amp import autocast
  from functools import partial, wraps
  from collections import namedtuple
  from dataclasses import dataclass
- from typing import List, Dict, Tuple, Callable, Optional, Union
+ from typing import List, Dict, Tuple, Callable

  from einops import rearrange, repeat, reduce, pack, unpack
  from einops.layers.torch import Rearrange
@@ -25,13 +27,13 @@ DEFAULT_DIM_HEAD = 64

  @dataclass
  class LayerIntermediates:
-     hiddens: Optional[List[Tensor]] = None             # all hiddens, before the final norm (in pre-norm architecture)
-     last_hidden: Optional[Tensor] = None               # very last hidden after all attention layers, after the final norm
-     attn_intermediates: Optional[List[Intermediates]] = None
-     layer_hiddens: Optional[List[Tensor]] = None
-     attn_z_loss: Optional[Tensor] = None
-     mems: Optional[Tensor] = None
-     memory_tokens: Optional[Tensor] = None
+     hiddens: List[Tensor] | None = None                # all hiddens, before the final norm (in pre-norm architecture)
+     last_hidden: Tensor | None = None                  # very last hidden after all attention layers, after the final norm
+     attn_intermediates: List[Intermediates] | None = None
+     layer_hiddens: List[Tensor] | None = None
+     attn_z_loss: Tensor | None = None
+     mems: Tensor | None = None
+     memory_tokens: Tensor | None = None

  # helpers

@@ -140,7 +142,7 @@ def init_zero_(layer):
  # keyword argument helpers

  def pick_and_pop(keys, d):
-     values = list(map(lambda key: d.pop(key), keys))
+     values = tuple(d.pop(key) for key in keys)
      return dict(zip(keys, values))

  def group_dict_by_key(cond, d):
@@ -149,7 +151,7 @@ def group_dict_by_key(cond, d):
          match = bool(cond(key))
          ind = int(not match)
          return_val[ind][key] = d[key]
-     return (*return_val,)
+     return tuple(return_val)

  def string_begins_with(prefix, str):
      return str.startswith(prefix)
@@ -159,7 +161,8 @@ def group_by_key_prefix(prefix, d):

  def groupby_prefix_and_trim(prefix, d):
      kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
-     kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
+     prefix_len = len(prefix)
+     kwargs_without_prefix = {key[prefix_len:]: value for key, value in kwargs_with_prefix.items()}
      return kwargs_without_prefix, kwargs

  # structured dropout, more effective than traditional attention dropouts
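The keyword-argument helpers above are refactored from `map`/`lambda` to comprehensions with no behavioural change. A self-contained sketch of the new `groupby_prefix_and_trim`; the two small helpers are reconstructed from the surrounding hunks and may differ cosmetically from the package source:

```python
from functools import partial

def string_begins_with(prefix, s):
    return s.startswith(prefix)

def group_dict_by_key(cond, d):
    # split d into (matching, non-matching) dicts, as in the hunk above
    return_val = [dict(), dict()]
    for key in d.keys():
        match = bool(cond(key))
        ind = int(not match)
        return_val[ind][key] = d[key]
    return tuple(return_val)

def groupby_prefix_and_trim(prefix, d):
    kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
    prefix_len = len(prefix)
    kwargs_without_prefix = {key[prefix_len:]: value for key, value in kwargs_with_prefix.items()}
    return kwargs_without_prefix, kwargs

attn_kwargs, rest = groupby_prefix_and_trim('attn_', dict(attn_dropout = 0.1, ff_mult = 4))
print(attn_kwargs)  # {'dropout': 0.1}
print(rest)         # {'ff_mult': 4}
```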
@@ -441,25 +444,27 @@ class RotaryEmbedding(Module):

      @autocast(enabled = False)
      def forward(self, t):
-         max_pos = t.max()+1
+         max_pos = t.max() + 1

          freqs = torch.einsum('i , j -> i j', t.type_as(self.inv_freq), self.inv_freq) / self.interpolation_factor
-         freqs = torch.cat((freqs, freqs), dim = -1)
+         freqs = torch.stack((freqs, freqs), dim = -1)
+         freqs = rearrange(freqs, '... d r -> ... (d r)')

          if not exists(self.scale):
              return freqs, 1.

          power = (t - (max_pos // 2)) / self.scale_base
          scale = self.scale ** rearrange(power, 'n -> n 1')
-         scale = torch.cat((scale, scale), dim = -1)
+         scale = torch.stack((scale, scale), dim = -1)
+         scale = rearrange(scale, '... d r -> ... (d r)')

          return freqs, scale

-
  def rotate_half(x):
-     x = rearrange(x, '... (j d) -> ... j d', j = 2)
-     x1, x2 = x.unbind(dim = -2)
-     return torch.cat((-x2, x1), dim = -1)
+     x = rearrange(x, '... (d r) -> ... d r', r = 2)
+     x1, x2 = x.unbind(dim = -1)
+     x = torch.stack((-x2, x1), dim = -1)
+     return rearrange(x, '... d r -> ... (d r)')

  @autocast(enabled = False)
  def apply_rotary_pos_emb(t, freqs, scale = 1):
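The rotary embedding now builds its duplicated frequencies (and xpos scales) by stacking and flattening adjacent pairs instead of concatenating two half-blocks, and `rotate_half` changes to match: it rotates adjacent `(x1, x2)` pairs rather than the two halves of the head dimension. Because the frequency layout and the rotation pairing change together in the same hunk, the scheme stays internally consistent; only which dimensions get paired differs. A small sketch (not from the package) contrasting the two pairings:

```python
# Old: split the head dimension into two halves and rotate (x1, x2) -> (-x2, x1).
# New: treat adjacent elements as (real, imag) pairs and rotate each pair in place.
import torch
from einops import rearrange

def rotate_half_old(x):
    x = rearrange(x, '... (j d) -> ... j d', j = 2)
    x1, x2 = x.unbind(dim = -2)
    return torch.cat((-x2, x1), dim = -1)

def rotate_half_new(x):
    x = rearrange(x, '... (d r) -> ... d r', r = 2)
    x1, x2 = x.unbind(dim = -1)
    x = torch.stack((-x2, x1), dim = -1)
    return rearrange(x, '... d r -> ... (d r)')

x = torch.arange(6.)
print(rotate_half_old(x))  # tensor([-3., -4., -5.,  0.,  1.,  2.])
print(rotate_half_new(x))  # tensor([-1.,  0., -3.,  2., -5.,  4.])
```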
@@ -572,8 +577,8 @@ class GRUGating(Module):
  def shift(t, amount, mask = None):
      if amount == 0:
          return t
-     else:
-         amount = min(amount, t.shape[1])
+
+     amount = min(amount, t.shape[1])

      if exists(mask):
          t = t.masked_fill(~mask[..., None], 0.)
@@ -597,6 +602,23 @@ class ShiftTokens(Module):
          x = torch.cat((*segments_to_shift, *rest), dim = -1)
          return self.fn(x, **kwargs)

+ # post branch operator
+
+ class LayerScale(Module):
+     def __init__(self, fn: Module, dim, init_value = 0.):
+         super().__init__()
+         self.fn = fn
+         self.gamma = nn.Parameter(torch.ones(dim) * init_value)
+
+     def forward(self, x, **kwargs):
+         out = self.fn(x, **kwargs)
+
+         if isinstance(out, Tensor):
+             return out * self.gamma
+
+         out, *rest = out
+         return out * self.gamma, *rest
+
  # feedforward

  class GLU(Module):
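The new `LayerScale` wrapper multiplies whatever its wrapped branch returns by a learned per-channel `gamma` (initialized at `init_value`, 0 by default) and passes any extra return values through untouched — the same idea as LayerScale in CaiT (Touvron et al., 2021). A standalone sketch; the class is copied from the hunk above and the `nn.Linear` is only a stand-in for an attention or feedforward branch:

```python
import torch
from torch import nn, Tensor
from torch.nn import Module

class LayerScale(Module):
    def __init__(self, fn: Module, dim, init_value = 0.):
        super().__init__()
        self.fn = fn
        self.gamma = nn.Parameter(torch.ones(dim) * init_value)

    def forward(self, x, **kwargs):
        out = self.fn(x, **kwargs)

        if isinstance(out, Tensor):
            return out * self.gamma

        out, *rest = out
        return out * self.gamma, *rest

branch = nn.Linear(8, 8)                          # stand-in for an attention / feedforward block
wrapped = LayerScale(branch, dim = 8, init_value = 0.)

x = torch.randn(2, 4, 8)
print(wrapped(x).abs().sum())   # tensor(0., ...) -- with init_value = 0. the branch starts silent
```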
@@ -817,7 +839,7 @@ class Attention(Module):
          mem = None,
          mem_mask = None,
          return_intermediates = False,
-         cache: Optional[Intermediates] = None,
+         cache: Intermediates | None = None,
      ):
          b, n, h, kv_h, head_scale, num_mem_kv, device, has_context = x.shape[0], x.shape[1], self.heads, self.kv_heads, self.head_scale, self.num_mem_kv, x.device, exists(context)

@@ -1024,11 +1046,11 @@ class AttentionLayers(Module):
          rotary_interpolation_factor = 1.,
          rotary_xpos_scale_base = 512,
          rotary_base_rescale_factor = 1.,
-         custom_layers = None,
+         weight_tie_layers = False,
+         custom_layers: Tuple[str] | None = None,
+         layers_execute_order: Tuple[int] | None = None,
          sandwich_coef = None,
          par_ratio = None,
-         weight_tie_layers = False,   # Albert - https://arxiv.org/abs/1909.11942
-         layers_execute_order = None, # generalizes weight tying, can do arbitrary layer execution orders
          residual_attn = False,
          cross_residual_attn = False,
          macaron = False,
@@ -1045,6 +1067,8 @@ class AttentionLayers(Module):
          layer_dropout = 0.,
          cross_attn_tokens_dropout = 0.,
          disable_abs_pos_emb = None,
+         use_layerscale = False,
+         layerscale_init_value = 0.,
          **kwargs
      ):
          super().__init__()
@@ -1108,6 +1132,8 @@ class AttentionLayers(Module):

          self.cross_attend = cross_attend

+         # determine norm
+
          assert (int(use_scalenorm) + int(use_rmsnorm) + int(use_simple_rmsnorm)) <= 1, 'you can only use either scalenorm, rmsnorm, or simple rmsnorm'

          if use_scalenorm:
@@ -1121,6 +1147,8 @@ class AttentionLayers(Module):

          norm_fn = partial(norm_class, dim)

+         # determine default block layer type order
+
          if cross_attend and not only_cross:
              default_block = ('a', 'c', 'f')
          elif cross_attend and only_cross:
@@ -1131,6 +1159,13 @@ class AttentionLayers(Module):
          if macaron:
              default_block = ('f',) + default_block

+         # determine post branch wrapper
+
+         post_branch_fn = None
+
+         if use_layerscale:
+             post_branch_fn = partial(LayerScale, dim = dim, init_value = layerscale_init_value)
+
          # zero init

          if zero_init_branch_output:
@@ -1219,6 +1254,9 @@ class AttentionLayers(Module):
                  shift_range_lower = -layer_shift_tokens if not causal else 0
                  layer = ShiftTokens(range(shift_range_lower, shift_range_upper), layer)

+             if exists(post_branch_fn):
+                 layer = post_branch_fn(layer)
+
              residual_fn = GRUGating if gate_residual else Residual
              residual = residual_fn(dim, scale_residual = scale_residual, scale_residual_constant = scale_residual_constant)

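With the wiring above, passing `use_layerscale = True` to `AttentionLayers` wraps every attention and feedforward block in a `LayerScale` before the residual is applied. A hedged usage sketch; the dimensions are arbitrary, and the forwarding of these kwargs through `Decoder` is assumed from the package's usual pattern rather than shown in this diff:

```python
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        use_layerscale = True,          # new flag from this release range
        layerscale_init_value = 0.      # default; branches start with zero contribution
    )
)

x = torch.randint(0, 256, (1, 128))
logits = model(x)
print(logits.shape)   # torch.Size([1, 128, 256])
```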
@@ -1248,8 +1286,8 @@ class AttentionLayers(Module):
          self_attn_kv_mask = None,
          mems = None,
          mem_masks = None,
-         seq_start_pos: Optional[Tensor] = None,
-         cache: Optional[LayerIntermediates] = None,
+         seq_start_pos: Tensor | None = None,
+         cache: LayerIntermediates | None = None,
          cache_age = 1,
          return_hiddens = False,
          rotary_pos_emb = None
@@ -1641,7 +1679,7 @@ class TransformerWrapper(Module):
          return_attn_z_loss = False,
          attn_z_loss_weight = 1e-4,
          seq_start_pos = None,
-         cache: Optional[LayerIntermediates] = None,
+         cache: LayerIntermediates | None = None,
          **kwargs
      ):
          b, n, device, num_mems, has_memory_tokens, emb_frac_gradient = x.shape[0], x.shape[1], x.device, self.num_memory_tokens, self.num_memory_tokens > 0, self.emb_frac_gradient
x_transformers/xval.py CHANGED
@@ -176,11 +176,11 @@ class XValTransformerWrapper(nn.Module):

          if return_mems:
              hiddens = intermediates.hiddens
-             new_mems = list(map(lambda t: t[..., -self.max_mem_len:, :].detach(), hiddens))
+             new_mems = tuple(t[..., -self.max_mem_len:, :].detach() for t in hiddens)
              return out, new_mems

          if return_attn:
-             attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
+             attn_maps = tuple(t.post_softmax_attn for t in intermediates.attn_intermediates)
              return out, attn_maps

          return out
x_transformers-1.29.2.dist-info/METADATA → x_transformers-1.30.1.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: x-transformers
- Version: 1.29.2
+ Version: 1.30.1
  Summary: X-Transformers - Pytorch
  Home-page: https://github.com/lucidrains/x-transformers
  Author: Phil Wang
@@ -14,6 +14,6 @@ Classifier: License :: OSI Approved :: MIT License
  Classifier: Programming Language :: Python :: 3.6
  Description-Content-Type: text/markdown
  License-File: LICENSE
- Requires-Dist: torch >=1.6
- Requires-Dist: einops >=0.7.0
+ Requires-Dist: torch >=2.0
+ Requires-Dist: einops >=0.8.0

x_transformers-1.30.1.dist-info/RECORD ADDED
@@ -0,0 +1,14 @@
+ x_transformers/__init__.py,sha256=8LQl-dNL6vj8VHRx5LMSOlRDTXQvYOuM21PDXz8WdiI,703
+ x_transformers/attend.py,sha256=Y9eE26I7BM8rGveabhiRhzw_xq9TY61Sp10QC1hX2O8,10192
+ x_transformers/autoregressive_wrapper.py,sha256=uX8Mb0zLsQrZECt_9UGt35g7tC05Rk3nPqO6xp2FFCc,9619
+ x_transformers/continuous.py,sha256=WO52n9lFAXv5-SGadi2cApGF8dkouN8QSTEOuC7erj8,6180
+ x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
+ x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
+ x_transformers/x_transformers.py,sha256=EEfqwI-NANzrQf10Tc_bRSdjWOIEJdhxOfzeKY4osyI,66137
+ x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
+ x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
+ x_transformers-1.30.1.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+ x_transformers-1.30.1.dist-info/METADATA,sha256=gkmRLAvk0l9_vkrTVBIWLnFq_cEtCYrI8oI3B07d9B8,661
+ x_transformers-1.30.1.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+ x_transformers-1.30.1.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+ x_transformers-1.30.1.dist-info/RECORD,,
@@ -1,14 +0,0 @@
1
- x_transformers/__init__.py,sha256=8LQl-dNL6vj8VHRx5LMSOlRDTXQvYOuM21PDXz8WdiI,703
2
- x_transformers/attend.py,sha256=L7vctHJ0PnECohu4cUu8yvY8cUrVyJxHmMFR0RGL0z4,10163
3
- x_transformers/autoregressive_wrapper.py,sha256=gYKIN5Rm8dMYSTX5yHpg9sPYyZf9rsRTJCNrYRdJ-Ww,9618
4
- x_transformers/continuous.py,sha256=dpHK4NSMDQAJQ_N3Uj9rip0fYGXyu0QCCO_OfEdbRGs,6192
5
- x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
6
- x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
7
- x_transformers/x_transformers.py,sha256=vPt5x0Pg03xGf8t2rZGW0zPd8xP0uvGLQvROFlmmOao,65200
8
- x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
9
- x_transformers/xval.py,sha256=EN3hxxleTRGYeAz6i4x3U_PrOm9TjxMF3eDhMKGx59E,8575
10
- x_transformers-1.29.2.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
11
- x_transformers-1.29.2.dist-info/METADATA,sha256=0_ON52HHs50Dcwp4PMfGhLWhDGKC9Rd4V3QAvmqxGyo,661
12
- x_transformers-1.29.2.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
13
- x_transformers-1.29.2.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
14
- x_transformers-1.29.2.dist-info/RECORD,,