rxnn 0.1.33__py3-none-any.whl → 0.1.34__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
rxnn/transformers/attention.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
import torch
|
2
2
|
import torch.nn as nn
|
3
3
|
import torch.nn.functional as F
|
4
|
-
from torch.
|
4
|
+
from torch.nn.attention import sdpa_kernel, SDPBackend
|
5
5
|
import math
|
6
6
|
from .positional import RotaryPositionalEmbedding, RelativePositionalEmbedding
|
7
7
|
|
@@ -102,7 +102,7 @@ class MultiHeadAttention(nn.Module):
|
|
102
102
|
|
103
103
|
def _flash_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, b: int, t: int, d: int,
|
104
104
|
mask: torch.Tensor = None, enable_gqa: bool = False):
|
105
|
-
with
|
105
|
+
with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
|
106
106
|
attn_output = F.scaled_dot_product_attention(
|
107
107
|
q, k, v,
|
108
108
|
attn_mask=mask if not self.is_causal else None,
|
@@ -16,7 +16,7 @@ rxnn/training/dataset.py,sha256=vQ5mDF3bA0HXya474n4D4iL8Mn3AEpJukgzFNVkxjGU,5106
|
|
16
16
|
rxnn/training/scheduler.py,sha256=ow6oALzWjWQmHSpcJEjv6tg4g4CDMvr73TypxfcefMc,712
|
17
17
|
rxnn/training/tokenizer.py,sha256=4Y41f07uo2KPA_7bp3FCcwGKbXoS2hsckOoXUsXfQxY,8052
|
18
18
|
rxnn/transformers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
19
|
-
rxnn/transformers/attention.py,sha256=
|
19
|
+
rxnn/transformers/attention.py,sha256=J5tK5mxWHEhgl9js_DRz1f7S2uaCHCnFJaX_KIi1aAQ,14332
|
20
20
|
rxnn/transformers/ff.py,sha256=jJnuBDsnnX5uYC_WZH8cXAYrMnz0P-iX7MwcPivjRtI,2533
|
21
21
|
rxnn/transformers/layers.py,sha256=n_jZTqEF_vLkF31AkB5XGErfm2sQFd9CRqJUHKRFkKI,6956
|
22
22
|
rxnn/transformers/mask.py,sha256=J0cfLVLt3SzS2ra3KcY4khrkhI975Dw4CjpUi3Sn25s,419
|
@@ -25,7 +25,7 @@ rxnn/transformers/moe.py,sha256=6Cffyo0QjmEWc4rK1ncOmLRCQbY0OpQJ4D7xH_4nTN4,4738
|
|
25
25
|
rxnn/transformers/positional.py,sha256=2l38RS0Dini3f6Z3LUHr3XwWzg1UK7fO2C6wazWDAYU,4292
|
26
26
|
rxnn/transformers/sampler.py,sha256=poWBpxg1iuK5gEJtxHkk5VVfS9V48hs2Olqdhy_Gw8c,6548
|
27
27
|
rxnn/utils.py,sha256=d5U8i5ukovgDyqiycc2AoxObTz_eF_bgo2MKvdtJ98s,467
|
28
|
-
rxnn-0.1.
|
29
|
-
rxnn-0.1.
|
30
|
-
rxnn-0.1.
|
31
|
-
rxnn-0.1.
|
28
|
+
rxnn-0.1.34.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
|
29
|
+
rxnn-0.1.34.dist-info/METADATA,sha256=VZOZ6XCTGWy9vEXZaxUC49-gFeidnY6urLHnbT3HRbo,16627
|
30
|
+
rxnn-0.1.34.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
|
31
|
+
rxnn-0.1.34.dist-info/RECORD,,
|
File without changes
|
File without changes
|