rxnn-0.1.32-py3-none-any.whl → rxnn-0.1.33-py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
rxnn/transformers/attention.py
CHANGED
@@ -1,6 +1,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torch.backends.cuda import sdp_kernel, SDPBackend
 import math
 from .positional import RotaryPositionalEmbedding, RelativePositionalEmbedding
 
@@ -17,7 +18,7 @@ class MultiHeadAttention(nn.Module):
             rope_only_for_query: bool = False,
             use_relative_embeddings: bool = False,
             max_seq_len: int = 1024,
-            use_flash_attention: bool =
+            use_flash_attention: bool = True,
             is_causal: bool = False,
             use_bias: bool = False,
             *args,
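With this change, flash attention becomes opt-out rather than opt-in. A minimal usage sketch follows, assuming the module's leading positional parameters are the embedding size and head count; only the keyword arguments visible in this hunk are confirmed by the diff:

    from rxnn.transformers.attention import MultiHeadAttention

    # Hypothetical values: 512 and 8 stand in for the assumed embed_dim
    # and num_heads parameters, which are not shown in this hunk.
    attn = MultiHeadAttention(
        512, 8,
        max_seq_len=1024,
        is_causal=True,
        # use_flash_attention now defaults to True as of 0.1.33;
        # pass use_flash_attention=False to keep the non-flash path.
    )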
@@ -101,13 +102,14 @@ class MultiHeadAttention(nn.Module):
 
     def _flash_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, b: int, t: int, d: int,
                          mask: torch.Tensor = None, enable_gqa: bool = False):
-
-
-
-
-
-
-
+        with sdp_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
+            attn_output = F.scaled_dot_product_attention(
+                q, k, v,
+                attn_mask=mask if not self.is_causal else None,
+                dropout_p=self.dropout.p if self.training else 0.0,
+                is_causal=self.is_causal,
+                enable_gqa=enable_gqa,
+            )
         return self._transpose_output(attn_output, b, t, d)
 
     def _calculate_flash_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, b: int, t: int, d: int,
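The rewritten _flash_attention pins PyTorch's fused scaled_dot_product_attention to the flash backend. Below is a standalone sketch of the same call, written against torch.nn.attention.sdpa_kernel, the current replacement for the deprecated torch.backends.cuda.sdp_kernel context manager imported in the diff; shapes and dtypes are illustrative, and the flash kernel itself requires CUDA with fp16/bf16 tensors:

    import torch
    import torch.nn.functional as F
    from torch.nn.attention import sdpa_kernel, SDPBackend

    # Toy tensors in the (batch, heads, seq_len, head_dim) layout SDPA expects.
    q = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
    k = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
    v = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)

    # Restrict SDPA to the flash kernel for this block; outside the context
    # manager PyTorch is free to pick any available backend.
    with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

    print(out.shape)  # torch.Size([2, 8, 128, 64])

The diff's attn_mask=mask if not self.is_causal else None guard matters because scaled_dot_product_attention raises an error when an explicit mask is combined with is_causal=True.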
rxnn-0.1.32.dist-info/RECORD → rxnn-0.1.33.dist-info/RECORD
RENAMED
@@ -16,7 +16,7 @@ rxnn/training/dataset.py,sha256=vQ5mDF3bA0HXya474n4D4iL8Mn3AEpJukgzFNVkxjGU,5106
 rxnn/training/scheduler.py,sha256=ow6oALzWjWQmHSpcJEjv6tg4g4CDMvr73TypxfcefMc,712
 rxnn/training/tokenizer.py,sha256=4Y41f07uo2KPA_7bp3FCcwGKbXoS2hsckOoXUsXfQxY,8052
 rxnn/transformers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/transformers/attention.py,sha256=
+rxnn/transformers/attention.py,sha256=Nox986BH9qq4rDYLiYmfj1DeMeULF3akexIl99MPccM,14331
 rxnn/transformers/ff.py,sha256=jJnuBDsnnX5uYC_WZH8cXAYrMnz0P-iX7MwcPivjRtI,2533
 rxnn/transformers/layers.py,sha256=n_jZTqEF_vLkF31AkB5XGErfm2sQFd9CRqJUHKRFkKI,6956
 rxnn/transformers/mask.py,sha256=J0cfLVLt3SzS2ra3KcY4khrkhI975Dw4CjpUi3Sn25s,419
@@ -25,7 +25,7 @@ rxnn/transformers/moe.py,sha256=6Cffyo0QjmEWc4rK1ncOmLRCQbY0OpQJ4D7xH_4nTN4,4738
 rxnn/transformers/positional.py,sha256=2l38RS0Dini3f6Z3LUHr3XwWzg1UK7fO2C6wazWDAYU,4292
 rxnn/transformers/sampler.py,sha256=poWBpxg1iuK5gEJtxHkk5VVfS9V48hs2Olqdhy_Gw8c,6548
 rxnn/utils.py,sha256=d5U8i5ukovgDyqiycc2AoxObTz_eF_bgo2MKvdtJ98s,467
-rxnn-0.1.
-rxnn-0.1.
-rxnn-0.1.
-rxnn-0.1.
+rxnn-0.1.33.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.1.33.dist-info/METADATA,sha256=m3DWDnTu7Lx1kHYPIAQCdKU8t4QZBdqG0QcSIFvB924,16627
+rxnn-0.1.33.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
+rxnn-0.1.33.dist-info/RECORD,,
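Each RECORD line follows the wheel spec's path,sha256=digest,size format, where the digest is the urlsafe-base64 SHA-256 of the file with trailing '=' padding stripped; the RECORD file lists itself with empty hash and size fields, hence the trailing ,, above. A small sketch for reproducing an entry (record_entry is a hypothetical helper, not part of rxnn):

    import base64
    import hashlib
    from pathlib import Path

    def record_entry(path: str) -> str:
        # Wheel RECORD digests are urlsafe base64 of the raw SHA-256 bytes,
        # with '=' padding removed (PEP 376 / the binary wheel spec).
        data = Path(path).read_bytes()
        digest = base64.urlsafe_b64encode(hashlib.sha256(data).digest())
        return f"{path},sha256={digest.rstrip(b'=').decode('ascii')},{len(data)}"

    print(record_entry("rxnn/utils.py"))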