rxnn 0.1.36__py3-none-any.whl → 0.1.37__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
rxnn/transformers/attention.py
CHANGED
@@ -102,8 +102,12 @@ class MultiHeadAttention(nn.Module):
|
|
102
102
|
|
103
103
|
def _flash_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, b: int, t: int, d: int,
                     mask: torch.Tensor = None, enable_gqa: bool = False):
    """Compute attention with the ``flash_attn`` package's ``flash_attn_func`` kernel.

    The PyTorch SDPA flash backend (``sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION])``
    wrapping ``self._torch_attention``) could not be made to work here, so the
    flash-attn kernel is called directly instead.

    Args:
        q, k, v: query/key/value tensors passed straight to ``flash_attn_func``.
            NOTE(review): flash-attn documents its inputs as
            (batch, seqlen, nheads, headdim) in fp16/bf16 on CUDA. Confirm the
            caller's layout, and ``self._transpose_output``'s expectations, match.
        b, t, d: forwarded unchanged to ``self._transpose_output`` to reshape the result.
        mask: currently ignored. ``flash_attn_func`` accepts no arbitrary attention
            mask; only causal masking (driven by ``self.is_causal``) is applied.
            TODO(review): decide whether a non-None mask should fall back to
            ``self._torch_attention``.
        enable_gqa: unused. Per the flash-attn docs, grouped-query attention is
            handled natively when k/v carry fewer heads than q.

    Returns:
        The result of ``self._transpose_output`` applied to the attention output.
    """
    # Imported inside the method, so flash-attn is only needed when this path runs.
    from flash_attn import flash_attn_func

    attn_output = flash_attn_func(
        q,
        k,
        v,
        # Apply dropout only in training mode; use 0.0 for evaluation.
        dropout_p=self.dropout.p if self.training else 0.0,
        # flash_attn_func names this keyword ``causal``. ``is_causal`` is the
        # PyTorch SDPA spelling and raises TypeError when passed here.
        causal=self.is_causal,
    )
    return self._transpose_output(attn_output, b, t, d)
|
107
111
|
|
108
112
|
def _torch_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, b: int, t: int, d: int,
|
109
113
|
mask: torch.Tensor = None, enable_gqa: bool = False):
|
@@ -16,7 +16,7 @@ rxnn/training/dataset.py,sha256=vQ5mDF3bA0HXya474n4D4iL8Mn3AEpJukgzFNVkxjGU,5106
|
|
16
16
|
rxnn/training/scheduler.py,sha256=ow6oALzWjWQmHSpcJEjv6tg4g4CDMvr73TypxfcefMc,712
|
17
17
|
rxnn/training/tokenizer.py,sha256=4Y41f07uo2KPA_7bp3FCcwGKbXoS2hsckOoXUsXfQxY,8052
|
18
18
|
rxnn/transformers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
19
|
-
rxnn/transformers/attention.py,sha256=
|
19
|
+
rxnn/transformers/attention.py,sha256=bsuAXCKR0WbOxgu-IkJHgn7jUu2CK4hqNw60IZbGTEE,15698
|
20
20
|
rxnn/transformers/ff.py,sha256=jJnuBDsnnX5uYC_WZH8cXAYrMnz0P-iX7MwcPivjRtI,2533
|
21
21
|
rxnn/transformers/layers.py,sha256=n_jZTqEF_vLkF31AkB5XGErfm2sQFd9CRqJUHKRFkKI,6956
|
22
22
|
rxnn/transformers/mask.py,sha256=J0cfLVLt3SzS2ra3KcY4khrkhI975Dw4CjpUi3Sn25s,419
|
@@ -25,7 +25,7 @@ rxnn/transformers/moe.py,sha256=6Cffyo0QjmEWc4rK1ncOmLRCQbY0OpQJ4D7xH_4nTN4,4738
|
|
25
25
|
rxnn/transformers/positional.py,sha256=2l38RS0Dini3f6Z3LUHr3XwWzg1UK7fO2C6wazWDAYU,4292
|
26
26
|
rxnn/transformers/sampler.py,sha256=poWBpxg1iuK5gEJtxHkk5VVfS9V48hs2Olqdhy_Gw8c,6548
|
27
27
|
rxnn/utils.py,sha256=d5U8i5ukovgDyqiycc2AoxObTz_eF_bgo2MKvdtJ98s,467
|
28
|
-
rxnn-0.1.
|
29
|
-
rxnn-0.1.
|
30
|
-
rxnn-0.1.
|
31
|
-
rxnn-0.1.
|
28
|
+
rxnn-0.1.37.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
|
29
|
+
rxnn-0.1.37.dist-info/METADATA,sha256=9cKmCtODY8tw_VimbuCN9787asAvNbNylZBK2gOBzLE,16627
|
30
|
+
rxnn-0.1.37.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
|
31
|
+
rxnn-0.1.37.dist-info/RECORD,,
|
File without changes
|
File without changes
|