x-transformers 1.37.8__py3-none-any.whl → 1.37.9__py3-none-any.whl

This diff reflects the changes between two publicly released versions of the package, as they appear in their public registry, and is provided for informational purposes only.
x_transformers/attend.py CHANGED
@@ -36,6 +36,9 @@ def exists(val):
 def default(val, d):
     return val if exists(val) else d
 
+def at_most_one_of(*bools):
+    return sum([*map(int, bools)]) <= 1
+
 def compact(arr):
     return [*filter(exists, arr)]
 
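The new at_most_one_of helper enforces that mutually exclusive attention flags are not enabled together. A minimal sketch of its behavior, reusing the helper exactly as added above:

    def at_most_one_of(*bools):
        # counts the truthy flags; passes when zero or one is set
        return sum([*map(int, bools)]) <= 1

    assert at_most_one_of(False, False)      # no flag set
    assert at_most_one_of(True, False)       # exactly one flag set
    assert not at_most_one_of(True, True)    # two flags set together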
@@ -100,6 +103,7 @@ class Attend(Module):
         scale = None,
         qk_norm = False,
         l2_distance = False,
+        sigmoid = False,
         flash = False,
         softclamp_logits = False,
         logit_softclamp_value = 50.,
@@ -116,10 +120,25 @@ class Attend(Module):
         super().__init__()
         self.scale = scale
 
+        # causal related
+
         self.causal = causal
         self.create_causal_mask = onnx_create_causal_mask if onnxable else create_causal_mask
 
-        self.attn_fn = partial(F.softmax, dtype = torch.float32) if not qk_norm else F.softmax
+        # attention type
+
+        assert not (flash and sigmoid), 'sigmoid attention not available for flash'
+        assert at_most_one_of(sigmoid, l2_distance)
+
+        self.sigmoid = sigmoid
+
+        if not sigmoid:
+            softmax_fn = partial(F.softmax, dim = -1)
+            self.attn_fn = partial(softmax_fn, dtype = torch.float32) if not qk_norm else softmax_fn
+        else:
+            self.attn_fn = F.sigmoid
+
+        # dropouts
 
         self.dropout = dropout
         self.attn_dropout = nn.Dropout(dropout)
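The restructured constructor picks one of two attention activations: the usual softmax over the key dimension, or an elementwise sigmoid, which is why flash is asserted off (PyTorch's fused kernel hard-codes softmax). A sketch of the difference on raw logits, assuming scores of shape (batch, heads, query_len, key_len):

    import torch
    import torch.nn.functional as F
    from functools import partial

    sim = torch.randn(1, 8, 4, 4)  # hypothetical attention logits (b, h, i, j)

    # softmax path: each query row normalizes to 1 across keys
    softmax_fn = partial(F.softmax, dim = -1)
    attn = softmax_fn(sim, dtype = torch.float32)
    assert torch.allclose(attn.sum(dim = -1), torch.ones(1, 8, 4))

    # sigmoid path: every logit is squashed independently; rows need not sum to 1
    attn = F.sigmoid(sim)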
@@ -410,7 +429,7 @@ class Attend(Module):
         if self.sigsoftmax:
             sim = sim + sim.sigmoid().log()
 
-        attn = self.attn_fn(sim, dim = -1)
+        attn = self.attn_fn(sim)
         attn = attn.type(dtype)
 
         post_softmax_attn = attn
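Because dim = -1 is now baked into the softmax partial, both branches of self.attn_fn are plain one-argument callables (F.sigmoid is elementwise and accepts no dim keyword), so the call site drops its dim argument. A small sketch of the unified call shape:

    import torch
    import torch.nn.functional as F
    from functools import partial

    sim = torch.randn(2, 8, 16, 16)  # hypothetical logits (b, h, i, j)

    for attn_fn in (partial(F.softmax, dim = -1), F.sigmoid):
        attn = attn_fn(sim)          # identical call for both activations
        assert attn.shape == sim.shape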
x_transformers/x_transformers.py CHANGED
@@ -924,6 +924,7 @@ class Attention(Module):
         qk_norm_scale = 10,
         qk_norm_dim_scale = False,
         l2_distance = False,
+        sigmoid = False,
         one_kv_head = False,
         kv_heads = None,
         shared_kv = False,
@@ -1039,6 +1040,7 @@ class Attention(Module):
             qk_norm = qk_norm,
             scale = qk_norm_scale if qk_norm else self.scale,
             l2_distance = l2_distance,
+            sigmoid = sigmoid,
             add_zero_kv = add_zero_kv,
             flash = flash,
             softclamp_logits = softclamp_logits,
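Attention only threads the new flag through to Attend, so sigmoid attention should be reachable from the high-level wrappers via the usual kwarg forwarding. A hedged usage sketch; the attn_sigmoid spelling assumes x-transformers' attn_ prefix routing into Attention and is not confirmed by this diff:

    import torch
    from x_transformers import TransformerWrapper, Decoder

    model = TransformerWrapper(
        num_tokens = 256,
        max_seq_len = 1024,
        attn_layers = Decoder(
            dim = 512,
            depth = 6,
            heads = 8,
            attn_sigmoid = True  # assumed routing to Attention(sigmoid = True)
        )
    )

    x = torch.randint(0, 256, (1, 1024))
    logits = model(x)  # flash must stay off; the diff asserts sigmoid + flash is invalid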
{x_transformers-1.37.8.dist-info → x_transformers-1.37.9.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.37.8
+Version: 1.37.9
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
{x_transformers-1.37.8.dist-info → x_transformers-1.37.9.dist-info}/RECORD RENAMED
@@ -1,15 +1,15 @@
 x_transformers/__init__.py,sha256=-MkQrSc37cTVDX7AOykxunYnqVtFlQ7lb0Cse5dsGWU,793
-x_transformers/attend.py,sha256=we7mkwVCD7_ColUD8_Fj0HM5jjOaa3wbstllp_XXK4k,13434
+x_transformers/attend.py,sha256=EtaTN6ahgRlFLkwfHA31RNL_bQyAHwhNBpGU1NIHJ-c,13894
 x_transformers/autoregressive_wrapper.py,sha256=DOJJCMMDOqDYKWy_IaG5IyKsXD3AW6amzfUgdAADOLY,10500
 x_transformers/continuous.py,sha256=cIVEdhfei258__ziV7kQBrJMxCel54bExBTDrO9rfCI,6450
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=-2fj6QcDSfMI5lJA_fzOW2mdzdS1C1LD6jMBtGQY48E,83752
+x_transformers/x_transformers.py,sha256=Y77_ZPWSKTJ-oYk4bHjhwMEkgoMaq_LyxcmCkvOPZ9g,83808
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.37.8.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.37.8.dist-info/METADATA,sha256=fiT94VbrxWL-8jJBjxvFloWsH6n6reOGitRSlpAhvWs,661
-x_transformers-1.37.8.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
-x_transformers-1.37.8.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.37.8.dist-info/RECORD,,
+x_transformers-1.37.9.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.37.9.dist-info/METADATA,sha256=9JF40JYlW1y_AOqdD1pwYZJpJWZn63SC0K0VX8IA2JU,661
+x_transformers-1.37.9.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
+x_transformers-1.37.9.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.37.9.dist-info/RECORD,,