x-transformers 1.37.9__py3-none-any.whl → 1.37.10__py3-none-any.whl

This diff shows the content changes between two package versions that were publicly released to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in their public registry.
x_transformers/attend.py CHANGED
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 from functools import partial
-from typing import Tuple
+from typing import Tuple, Callable
 
 import torch
 from torch.nn import Module
@@ -104,6 +104,7 @@ class Attend(Module):
         qk_norm = False,
         l2_distance = False,
         sigmoid = False,
+        custom_attn_fn: Callable | None = None,
         flash = False,
         softclamp_logits = False,
         logit_softclamp_value = 50.,
@@ -132,7 +133,9 @@ class Attend(Module):
 
         self.sigmoid = sigmoid
 
-        if not sigmoid:
+        if exists(custom_attn_fn):
+            self.attn_fn = custom_attn_fn
+        elif not sigmoid:
             softmax_fn = partial(F.softmax, dim = -1)
             self.attn_fn = partial(softmax_fn, dtype = torch.float32) if not qk_norm else softmax_fn
         else:
@@ -404,11 +407,6 @@ class Attend(Module):
 
         mask_value = -torch.finfo(sim.dtype).max
 
-        if exists(self.sparse_topk) and self.sparse_topk < j:
-            top_values, _ = sim.topk(self.sparse_topk, dim = -1)
-            sparse_topk_mask = sim < top_values[..., -1:]
-            mask = (mask & sparse_topk_mask) if exists(mask) else sparse_topk_mask
-
         if exists(mask):
             sim = sim.masked_fill(~mask, mask_value)
 
@@ -416,6 +414,11 @@ class Attend(Module):
             causal_mask = self.create_causal_mask(i, j, device = device)
             sim = sim.masked_fill(causal_mask, mask_value)
 
+        if exists(self.sparse_topk):
+            top_values, _ = sim.topk(self.sparse_topk, dim = -1)
+            sparse_topk_mask = (sim >= top_values[..., -1:]) & (sim > mask_value)
+            sim = sim.masked_fill(~sparse_topk_mask, mask_value)
+
         row_is_entirely_masked = None
 
         if exists(mask):
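
With this change, the sparse_topk filtering runs after the padding and causal masks and is applied to the scores themselves rather than being folded into the boolean mask, so positions that are already masked out can no longer occupy top-k slots. A minimal sketch of the relocated step on a single toy score row (the numbers are invented for illustration):

import torch

mask_value = -torch.finfo(torch.float32).max
# one query row over four keys; key 1 is already masked out (e.g. by the causal mask)
sim = torch.tensor([[3.0, mask_value, 1.0, 2.0]])

sparse_topk = 2
top_values, _ = sim.topk(sparse_topk, dim = -1)
sparse_topk_mask = (sim >= top_values[..., -1:]) & (sim > mask_value)
sim = sim.masked_fill(~sparse_topk_mask, mask_value)

# only the two largest unmasked scores (3.0 and 2.0) survive; everything else stays at mask_value
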
x_transformers/x_transformers.py CHANGED
@@ -925,6 +925,7 @@ class Attention(Module):
         qk_norm_dim_scale = False,
         l2_distance = False,
         sigmoid = False,
+        custom_attn_fn: Callable | None = None,
         one_kv_head = False,
         kv_heads = None,
         shared_kv = False,
@@ -1041,6 +1042,7 @@ class Attention(Module):
             scale = qk_norm_scale if qk_norm else self.scale,
             l2_distance = l2_distance,
             sigmoid = sigmoid,
+            custom_attn_fn = custom_attn_fn,
             add_zero_kv = add_zero_kv,
             flash = flash,
             softclamp_logits = softclamp_logits,
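
Taken together, these hunks expose a new custom_attn_fn argument that, when provided, replaces the built-in softmax (or sigmoid) normalizer in the non-flash attention path. A minimal sketch of how the hook might be exercised against Attend directly; the ReLU-based normalizer and the tensor shapes are illustrative assumptions, not part of what this diff shows:

import torch
from x_transformers.attend import Attend

# hypothetical normalizer: ReLU scores renormalized over the key dimension,
# standing in for softmax purely for illustration
def relu_attn(sim):
    attn = torch.relu(sim)
    return attn / attn.sum(dim = -1, keepdim = True).clamp(min = 1e-6)

attend = Attend(causal = True, custom_attn_fn = relu_attn)

# (batch, heads, seq_len, dim_head) — shapes made up for the sketch
q = torch.randn(1, 8, 128, 64)
k = torch.randn(1, 8, 128, 64)
v = torch.randn(1, 8, 128, 64)

out, _ = attend(q, k, v)
print(out.shape)  # torch.Size([1, 8, 128, 64])

Since Attention now forwards the same argument to Attend, it should presumably also be reachable from the higher-level layers, though that wiring is not shown here.
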
x_transformers-1.37.10.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.37.9
+Version: 1.37.10
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
x_transformers-1.37.10.dist-info/RECORD CHANGED
@@ -1,15 +1,15 @@
 x_transformers/__init__.py,sha256=-MkQrSc37cTVDX7AOykxunYnqVtFlQ7lb0Cse5dsGWU,793
-x_transformers/attend.py,sha256=EtaTN6ahgRlFLkwfHA31RNL_bQyAHwhNBpGU1NIHJ-c,13894
+x_transformers/attend.py,sha256=MJmMLIt0rHx-JNNsc2auUsCjsB-69NewufaRV32ADmA,14012
 x_transformers/autoregressive_wrapper.py,sha256=DOJJCMMDOqDYKWy_IaG5IyKsXD3AW6amzfUgdAADOLY,10500
 x_transformers/continuous.py,sha256=cIVEdhfei258__ziV7kQBrJMxCel54bExBTDrO9rfCI,6450
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=Y77_ZPWSKTJ-oYk4bHjhwMEkgoMaq_LyxcmCkvOPZ9g,83808
+x_transformers/x_transformers.py,sha256=q6I6rvyYUWLgwtKOxPwF12UL1HzcIlauI8YrM8gvZac,83901
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.37.9.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.37.9.dist-info/METADATA,sha256=9JF40JYlW1y_AOqdD1pwYZJpJWZn63SC0K0VX8IA2JU,661
-x_transformers-1.37.9.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
-x_transformers-1.37.9.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.37.9.dist-info/RECORD,,
+x_transformers-1.37.10.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.37.10.dist-info/METADATA,sha256=qA7YAj5ZeaesnGamBR-cPSR_0HSwquwPBptaUmi7c3c,662
+x_transformers-1.37.10.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
+x_transformers-1.37.10.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.37.10.dist-info/RECORD,,