x-transformers 1.37.3__py3-none-any.whl → 1.37.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
x_transformers/attend.py CHANGED
@@ -13,7 +13,7 @@ from functools import wraps
 from packaging import version
 from dataclasses import dataclass

-from einops import rearrange, repeat
+from einops import rearrange, repeat, pack, unpack

 # constants

@@ -39,9 +39,16 @@ def default(val, d):
 def compact(arr):
     return [*filter(exists, arr)]

-def softclamp(t, value):
+@torch.jit.script
+def softclamp(t: Tensor, value: float):
     return (t / value).tanh() * value

+def pack_one(t, pattern):
+    return pack([t], pattern)
+
+def unpack_one(t, ps, pattern):
+    return unpack(t, ps, pattern)[0]
+
 def once(fn):
     called = False
     @wraps(fn)
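For context on the new helpers: pack_one collapses every dimension matched by '*' into a single leading dimension and returns the packed shape, and unpack_one reverses the operation. A minimal sketch of the round trip on an illustrative (batch, heads, seq, dim) tensor (the shapes are assumptions, not taken from the library):

import torch
from einops import pack, unpack

def pack_one(t, pattern):
    return pack([t], pattern)

def unpack_one(t, ps, pattern):
    return unpack(t, ps, pattern)[0]

q = torch.randn(2, 8, 16, 64)              # (batch, heads, seq, dim), illustrative shapes
flat, ps = pack_one(q, '* i d')            # '*' merges batch and heads -> shape (16, 16, 64)
restored = unpack_one(flat, ps, '* i d')   # unpacks back to (2, 8, 16, 64)
assert restored.shape == q.shape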
@@ -55,6 +62,18 @@ def once(fn):

 print_once = once(print)

+# alternative distance functions
+
+def qk_l2_distance(q, k):
+    if k.ndim == 3:
+        k = repeat(k, 'b j d -> b h j d', h = q.shape[1])
+
+    q, packed_shape = pack_one(q, '* i d')
+    k, _ = pack_one(k, '* j d')
+
+    distance = torch.cdist(q, k)
+    return unpack_one(distance, packed_shape, '* i j')
+
 # functions for creating causal mask
 # need a special one for onnx cpu (no support for .triu)

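qk_l2_distance computes the pairwise Euclidean distance between every query and key; packing batch and heads together lets torch.cdist handle the leading dimensions in one call. A hedged sanity check of that equivalence against a direct broadcasted computation (shapes are illustrative):

import torch
from einops import pack, unpack

q = torch.randn(2, 8, 16, 64)   # (b, h, i, d), illustrative shapes
k = torch.randn(2, 8, 32, 64)   # (b, h, j, d)

packed_q, ps = pack([q], '* i d')                                # (b * h, i, d)
packed_k, _ = pack([k], '* j d')
dist = unpack(torch.cdist(packed_q, packed_k), ps, '* i j')[0]   # (b, h, i, j)

manual = (q.unsqueeze(-2) - k.unsqueeze(-3)).norm(dim = -1)      # broadcasted pairwise L2
assert torch.allclose(dist, manual, atol = 1e-4)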
@@ -80,6 +99,7 @@ class Attend(Module):
         sparse_topk = None,
         scale = None,
         qk_norm = False,
+        l2_distance = False,
         flash = False,
         softclamp_logits = False,
         logit_softclamp_value = 50.,
@@ -123,6 +143,11 @@ class Attend(Module):
         assert not (flash and sigsoftmax), 'sigsoftmax not available for flash attention'
         self.sigsoftmax = sigsoftmax

+        # l2 distance attention
+
+        assert not (flash and l2_distance), 'l2 distance attention does not work with flash attention just yet'
+        self.l2_distance = l2_distance
+
         # add a key / value token composed of zeros
         # in case this helps controlling outliers, proposed by https://www.evanmiller.org/attention-is-off-by-one.html

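A minimal usage sketch of the new flag, assuming Attend.forward accepts (q, k, v) shaped (batch, heads, seq, dim_head) and returns the attended output along with intermediates; per the assert above, flash must remain off:

import torch
from x_transformers.attend import Attend

attend = Attend(l2_distance = True)   # flash = False, as required by the assert above

q = torch.randn(1, 8, 128, 64)        # (batch, heads, seq, dim_head), illustrative shapes
k = torch.randn(1, 8, 128, 64)
v = torch.randn(1, 8, 128, 64)

out, _ = attend(q, k, v)
print(out.shape)                      # expected: torch.Size([1, 8, 128, 64])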
@@ -325,7 +350,12 @@ class Attend(Module):

         kv_einsum_eq = 'b j d' if k.ndim == 3 else 'b h j d'

-        sim = einsum(f'b h i d, {kv_einsum_eq} -> b h i j', q, k) * scale
+        if not self.l2_distance:
+            sim = einsum(f'b h i d, {kv_einsum_eq} -> b h i j', q, k)
+        else:
+            sim = -qk_l2_distance(q, k)
+
+        sim = sim * scale

         if exists(prev_attn):
             sim = sim + prev_attn
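A small sketch of the two similarity branches on the same inputs: the default branch scores by scaled dot product, while the new branch scores by negated Euclidean distance, so nearby keys get the largest (least negative) scores and hence the most weight after softmax. qk_l2_distance below is the module-level helper added above; the scale value is an assumption of the usual dim_head ** -0.5 default:

import torch
from x_transformers.attend import qk_l2_distance

q = torch.randn(1, 8, 4, 64)                                          # (batch, heads, queries, dim_head)
k = torch.randn(1, 8, 6, 64)                                          # (batch, heads, keys, dim_head)
scale = 64 ** -0.5                                                    # assumed default scale

dot_sim = torch.einsum('b h i d, b h j d -> b h i j', q, k) * scale   # default branch
l2_sim  = -qk_l2_distance(q, k) * scale                               # new branch

attn = l2_sim.softmax(dim = -1)
print(attn.sum(dim = -1))                                             # each query's weights sum to 1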
x_transformers/autoregressive_wrapper.py CHANGED
@@ -317,7 +317,7 @@ class AutoregressiveWrapper(Module):
             **kwargs
         )

-        loss_fn = F.cross_entropy if not self.net.is_log_prob else F.nll_loss
+        loss_fn = F.cross_entropy if not self.net.output_is_log_prob else F.nll_loss

         loss = loss_fn(
             rearrange(logits, 'b n c -> b c n'),
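Why the loss function is switched: F.cross_entropy applies log_softmax internally, so when the network already emits log probabilities (output_is_log_prob), F.nll_loss on those log probabilities is the matching objective. A small sketch of the equivalence:

import torch
import torch.nn.functional as F

logits = torch.randn(2, 5, 10)             # (batch, seq, num_tokens), illustrative shapes
labels = torch.randint(0, 10, (2, 5))

log_probs = logits.log_softmax(dim = -1)   # what a log-prob head would emit

a = F.cross_entropy(logits.transpose(1, 2), labels)    # expects (batch, classes, seq)
b = F.nll_loss(log_probs.transpose(1, 2), labels)
assert torch.allclose(a, b, atol = 1e-6)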
x_transformers/nonautoregressive_wrapper.py CHANGED
@@ -309,9 +309,11 @@ class NonAutoregressiveWrapper(nn.Module):
         with context():
             logits = self.net(masked, **kwargs)

+        loss_fn = F.cross_entropy if not self.net.output_is_log_prob else F.nll_loss
+
         # cross entropy loss

-        loss = F.cross_entropy(
+        loss = loss_fn(
             logits[mask],
             orig_seq[mask]
         )
x_transformers/x_transformers.py CHANGED
@@ -923,6 +923,7 @@ class Attention(Module):
         qk_norm_groups = 1,
         qk_norm_scale = 10,
         qk_norm_dim_scale = False,
+        l2_distance = False,
         one_kv_head = False,
         kv_heads = None,
         shared_kv = False,
@@ -1037,6 +1038,7 @@ class Attention(Module):
             sparse_topk = sparse_topk,
             qk_norm = qk_norm,
             scale = qk_norm_scale if qk_norm else self.scale,
+            l2_distance = l2_distance,
             add_zero_kv = add_zero_kv,
             flash = flash,
             softclamp_logits = softclamp_logits,
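With the flag threaded through Attention into Attend, it should also be reachable from the higher-level API. A hedged usage sketch, assuming the library's usual convention of routing attn_-prefixed kwargs from Decoder down to Attention:

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        attn_l2_distance = True   # assumed kwarg routing to Attention(l2_distance = True)
    )
)

x = torch.randint(0, 256, (1, 1024))
logits = model(x)                 # (1, 1024, 256)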
@@ -2078,7 +2080,7 @@ class TransformerWrapper(Module):

         # output type

-        self.is_log_prob = mixture_of_softmax
+        self.output_is_log_prob = mixture_of_softmax

         self.to_mixture = None
         self.combine_mixture = None
x_transformers/xl_autoregressive_wrapper.py CHANGED
@@ -40,7 +40,7 @@ class XLAutoregressiveWrapper(nn.Module):
         eos_token = None,
         temperature = 1.,
         filter_logits_fn = top_k,
-        filter_thres = 0.9,
+        filter_kwargs: dict = dict(),
         mems = None,
         **kwargs
     ):
@@ -88,7 +88,7 @@ class XLAutoregressiveWrapper(nn.Module):
             mems = cache.mems

             logits = logits[:, -1]
-            filtered_logits = filter_logits_fn(logits, thres = filter_thres)
+            filtered_logits = filter_logits_fn(logits, **filter_kwargs)
             probs = F.softmax(filtered_logits / temperature, dim=-1)

             sample = torch.multinomial(probs, 1)
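The sampling change replaces the hard-coded filter_thres with a filter_kwargs dict that is forwarded verbatim to filter_logits_fn, so callers pass whatever arguments their chosen filter accepts. A hedged sketch of that forwarding with a hypothetical stand-in filter (the real filter functions and their keyword names live in x_transformers and may differ):

import torch
import torch.nn.functional as F

# hypothetical stand-in filter; keeps only the top k logits per row
def keep_top_k(logits, k = 50):
    values, indices = logits.topk(k, dim = -1)
    filtered = torch.full_like(logits, float('-inf'))
    return filtered.scatter(-1, indices, values)

logits = torch.randn(2, 256)
filter_kwargs = dict(k = 10)                    # whatever the chosen filter accepts

filtered = keep_top_k(logits, **filter_kwargs)  # mirrors filter_logits_fn(logits, **filter_kwargs)
probs = F.softmax(filtered / 1., dim = -1)
sample = torch.multinomial(probs, 1)
print(sample.shape)                             # torch.Size([2, 1])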
@@ -131,7 +131,9 @@ class XLAutoregressiveWrapper(nn.Module):

         split_x = x.split(max_seq_len, dim = -1)
         split_labels = labels.split(max_seq_len, dim = -1)
-        loss_weights = tuple(map(lambda t: t.shape[-1] / seq_len, split_x))
+        loss_weights = tuple((t.shape[-1] / seq_len) for t in split_x)
+
+        loss_fn = F.cross_entropy if not self.net.output_is_log_prob else F.nll_loss

         # go through each chunk and derive weighted losses

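The rewritten loss_weights line keeps the same weighting scheme: each chunk's loss is scaled by the fraction of the full sequence it covers, so the weighted sum behaves like a per-token average over the whole sequence (ignoring any ignored labels). A small numeric sketch of that weighting, with made-up chunk sizes and loss values:

seq_len = 10
chunk_lens   = [4, 4, 2]          # e.g. a sequence split with max_seq_len = 4
chunk_losses = [2.0, 3.0, 5.0]    # hypothetical per-chunk cross entropy values

loss_weights = tuple((n / seq_len) for n in chunk_lens)
total = sum(w * l for w, l in zip(loss_weights, chunk_losses))
print(total)                      # 0.4 * 2.0 + 0.4 * 3.0 + 0.2 * 5.0 = 3.0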
137
139
 
@@ -146,7 +148,7 @@ class XLAutoregressiveWrapper(nn.Module):
                 **kwargs
             )

-            loss = F.cross_entropy(
+            loss = loss_fn(
                 rearrange(logits, 'b n c -> b c n'),
                 chunk_labels,
                 ignore_index = ignore_index
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.37.3
+Version: 1.37.5
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
@@ -0,0 +1,15 @@
+x_transformers/__init__.py,sha256=-MkQrSc37cTVDX7AOykxunYnqVtFlQ7lb0Cse5dsGWU,793
+x_transformers/attend.py,sha256=4RnX1yhWZIf8holucqnYXTIP7U1m40UpP58RZNT_2sM,13128
+x_transformers/autoregressive_wrapper.py,sha256=DOJJCMMDOqDYKWy_IaG5IyKsXD3AW6amzfUgdAADOLY,10500
+x_transformers/continuous.py,sha256=cIVEdhfei258__ziV7kQBrJMxCel54bExBTDrO9rfCI,6450
+x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
+x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
+x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
+x_transformers/x_transformers.py,sha256=-2fj6QcDSfMI5lJA_fzOW2mdzdS1C1LD6jMBtGQY48E,83752
+x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
+x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
+x_transformers-1.37.5.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.37.5.dist-info/METADATA,sha256=zHUhvP1bQjFbMtxnVO9iDESgXpGOQxuBCsm4b6K1w44,661
+x_transformers-1.37.5.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
+x_transformers-1.37.5.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.37.5.dist-info/RECORD,,
@@ -1,15 +0,0 @@
-x_transformers/__init__.py,sha256=-MkQrSc37cTVDX7AOykxunYnqVtFlQ7lb0Cse5dsGWU,793
-x_transformers/attend.py,sha256=mV7duZ7ON2puS3-k4ctBifb2rq-jTJqrMbof7tI5jR4,12326
-x_transformers/autoregressive_wrapper.py,sha256=2FN4ZobFcdDGDGWEnUof_geb16dRGSJycZGwG899Pa4,10493
-x_transformers/continuous.py,sha256=cIVEdhfei258__ziV7kQBrJMxCel54bExBTDrO9rfCI,6450
-x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
-x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
-x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
-x_transformers/x_transformers.py,sha256=gOJBZzOJMu5RkIsxw9TZtde4Sx--D18yX8LjrYIsPbE,83677
-x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
-x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.37.3.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.37.3.dist-info/METADATA,sha256=SIGTCQMrLkyq_aksJAst0iXw9VfFT6QWlGvtUElbTMg,661
-x_transformers-1.37.3.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
-x_transformers-1.37.3.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.37.3.dist-info/RECORD,,