natten-mps 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- natten_mps/__init__.py +53 -0
- natten_mps/_core/__init__.py +3 -0
- natten_mps/_core/_metal_shaders.py +7286 -0
- natten_mps/_core/inverse_maps.py +428 -0
- natten_mps/_core/metal.py +1605 -0
- natten_mps/_core/ops.py +159 -0
- natten_mps/_core/pure.py +696 -0
- natten_mps/_torch_ops.py +763 -0
- natten_mps/autograd/__init__.py +15 -0
- natten_mps/autograd/_factory.py +186 -0
- natten_mps/autograd/na1d.py +9 -0
- natten_mps/autograd/na2d.py +9 -0
- natten_mps/autograd/na3d.py +9 -0
- natten_mps/compat/__init__.py +26 -0
- natten_mps/compat/v014.py +205 -0
- natten_mps/compat/v015.py +97 -0
- natten_mps/compat/v017.py +1 -0
- natten_mps/compat/v020.py +37 -0
- natten_mps/extras/__init__.py +0 -0
- natten_mps/extras/allin1/__init__.py +127 -0
- natten_mps/extras/allin1/_metal_shaders.py +803 -0
- natten_mps/extras/allin1/functional.py +570 -0
- natten_mps/extras/allin1/metal.py +331 -0
- natten_mps/extras/allin1/reference_impl.py +41 -0
- natten_mps/functional.py +1033 -0
- natten_mps/merge.py +159 -0
- natten_mps/nn/__init__.py +5 -0
- natten_mps/nn/na1d.py +130 -0
- natten_mps/nn/na2d.py +127 -0
- natten_mps/nn/na3d.py +127 -0
- natten_mps/support_matrix.py +37 -0
- natten_mps/utils/__init__.py +27 -0
- natten_mps/utils/params.py +82 -0
- natten_mps/utils/window.py +135 -0
- natten_mps/version.py +1 -0
- natten_mps-0.3.0.dist-info/METADATA +331 -0
- natten_mps-0.3.0.dist-info/RECORD +40 -0
- natten_mps-0.3.0.dist-info/WHEEL +5 -0
- natten_mps-0.3.0.dist-info/licenses/LICENSE +21 -0
- natten_mps-0.3.0.dist-info/top_level.txt +1 -0
natten_mps/__init__.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
|
|
5
|
+
from natten_mps._core import ops
|
|
6
|
+
from natten_mps._core import metal as _metal
|
|
7
|
+
import natten_mps._torch_ops as _torch_ops # noqa: F401 — register custom ops
|
|
8
|
+
from natten_mps.functional import na1d, na1d_av, na1d_qk, na1d_varlen, na2d, na2d_av, na2d_qk, na2d_varlen, na3d, na3d_av, na3d_qk, na3d_varlen
|
|
9
|
+
from natten_mps.merge import merge_attentions
|
|
10
|
+
from natten_mps.nn import NeighborhoodAttention1D, NeighborhoodAttention2D, NeighborhoodAttention3D
|
|
11
|
+
from natten_mps.support_matrix import get_support_matrix
|
|
12
|
+
from natten_mps.version import __version__
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def has_mps() -> bool:
    """Report whether PyTorch's MPS backend is present and available.

    Older PyTorch builds may not expose ``torch.backends.mps`` at all, so the
    attribute is looked up defensively before ``is_available()`` is called.

    Returns:
        ``True`` if ``torch.backends.mps`` exists and its ``is_available()``
        is truthy, otherwise ``False``.
    """
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is None:
        # Attribute missing on this torch build: MPS cannot be used.
        return False
    return bool(mps_backend.is_available())
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def has_metal() -> bool:
    """Report whether the Metal backend module says it is available.

    Delegates to ``is_available()`` in ``natten_mps._core.metal`` and coerces
    the answer to a plain ``bool``.
    """
    metal_ready = _metal.is_available()
    return bool(metal_ready)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def get_backend() -> str:
    """Return the name of the currently selected backend.

    This is a thin public wrapper around ``natten_mps._core.ops.get_backend``.
    """
    current_backend = ops.get_backend()
    return current_backend
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def set_backend(name: str) -> None:
    """Select the backend to use for subsequent operations.

    Forwards *name* unchanged to ``natten_mps._core.ops.set_backend``. The
    set of accepted names and any error raised for an unknown name are
    decided there (not visible in this module — TODO confirm for docs).

    Args:
        name: Backend identifier passed through to ``ops.set_backend``.
    """
    ops.set_backend(name)
|
|
29
|
+
|
|
30
|
+
__all__ = [
    # Fused neighborhood-attention functions from natten_mps.functional,
    # each paired with its variable-length counterpart.
    "na1d",
    "na1d_varlen",
    "na2d",
    "na2d_varlen",
    "na3d",
    "na3d_varlen",
    # Separate *_qk / *_av functions, also from natten_mps.functional.
    "na1d_qk",
    "na1d_av",
    "na2d_qk",
    "na2d_av",
    "na3d_qk",
    "na3d_av",
    # From natten_mps.merge.
    "merge_attentions",
    # Module classes from natten_mps.nn.
    "NeighborhoodAttention1D",
    "NeighborhoodAttention2D",
    "NeighborhoodAttention3D",
    # Availability / backend helpers defined in this module.
    "has_mps",
    "has_metal",
    "get_backend",
    "set_backend",
    # Package metadata.
    "get_support_matrix",
    "__version__",
]
|