natten-mps 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40):
  1. natten_mps/__init__.py +53 -0
  2. natten_mps/_core/__init__.py +3 -0
  3. natten_mps/_core/_metal_shaders.py +7286 -0
  4. natten_mps/_core/inverse_maps.py +428 -0
  5. natten_mps/_core/metal.py +1605 -0
  6. natten_mps/_core/ops.py +159 -0
  7. natten_mps/_core/pure.py +696 -0
  8. natten_mps/_torch_ops.py +763 -0
  9. natten_mps/autograd/__init__.py +15 -0
  10. natten_mps/autograd/_factory.py +186 -0
  11. natten_mps/autograd/na1d.py +9 -0
  12. natten_mps/autograd/na2d.py +9 -0
  13. natten_mps/autograd/na3d.py +9 -0
  14. natten_mps/compat/__init__.py +26 -0
  15. natten_mps/compat/v014.py +205 -0
  16. natten_mps/compat/v015.py +97 -0
  17. natten_mps/compat/v017.py +1 -0
  18. natten_mps/compat/v020.py +37 -0
  19. natten_mps/extras/__init__.py +0 -0
  20. natten_mps/extras/allin1/__init__.py +127 -0
  21. natten_mps/extras/allin1/_metal_shaders.py +803 -0
  22. natten_mps/extras/allin1/functional.py +570 -0
  23. natten_mps/extras/allin1/metal.py +331 -0
  24. natten_mps/extras/allin1/reference_impl.py +41 -0
  25. natten_mps/functional.py +1033 -0
  26. natten_mps/merge.py +159 -0
  27. natten_mps/nn/__init__.py +5 -0
  28. natten_mps/nn/na1d.py +130 -0
  29. natten_mps/nn/na2d.py +127 -0
  30. natten_mps/nn/na3d.py +127 -0
  31. natten_mps/support_matrix.py +37 -0
  32. natten_mps/utils/__init__.py +27 -0
  33. natten_mps/utils/params.py +82 -0
  34. natten_mps/utils/window.py +135 -0
  35. natten_mps/version.py +1 -0
  36. natten_mps-0.3.0.dist-info/METADATA +331 -0
  37. natten_mps-0.3.0.dist-info/RECORD +40 -0
  38. natten_mps-0.3.0.dist-info/WHEEL +5 -0
  39. natten_mps-0.3.0.dist-info/licenses/LICENSE +21 -0
  40. natten_mps-0.3.0.dist-info/top_level.txt +1 -0
natten_mps/__init__.py ADDED
@@ -0,0 +1,53 @@
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+
5
+ from natten_mps._core import ops
6
+ from natten_mps._core import metal as _metal
7
+ import natten_mps._torch_ops as _torch_ops # noqa: F401 — register custom ops
8
+ from natten_mps.functional import na1d, na1d_av, na1d_qk, na1d_varlen, na2d, na2d_av, na2d_qk, na2d_varlen, na3d, na3d_av, na3d_qk, na3d_varlen
9
+ from natten_mps.merge import merge_attentions
10
+ from natten_mps.nn import NeighborhoodAttention1D, NeighborhoodAttention2D, NeighborhoodAttention3D
11
+ from natten_mps.support_matrix import get_support_matrix
12
+ from natten_mps.version import __version__
13
+
14
+
def has_mps() -> bool:
    """Return ``True`` when PyTorch exposes an MPS backend that reports itself available.

    ``torch.backends.mps`` is fetched with ``getattr`` and a ``None`` default,
    so a missing attribute yields ``False`` instead of raising.
    """
    mps_backend = getattr(torch.backends, "mps", None)
    if not mps_backend:
        # No (or falsy) MPS backend attribute: report unavailable.
        return False
    return bool(mps_backend.is_available())
17
+
18
+
def has_metal() -> bool:
    """Report whether the ``natten_mps._core.metal`` backend says it is usable.

    Whatever ``_metal.is_available()`` returns is reduced to a plain ``bool``
    by its truthiness.
    """
    return True if _metal.is_available() else False
21
+
22
+
def get_backend() -> str:
    """Return the name of the currently selected backend.

    This is a passthrough to ``natten_mps._core.ops.get_backend``.
    """
    current = ops.get_backend()
    return current
25
+
26
+
def set_backend(name: str) -> None:
    """Select the backend by ``name``.

    This delegates to ``natten_mps._core.ops.set_backend``. Any validation of
    ``name`` (if performed) happens there, not here.
    """
    ops.set_backend(name)
    return None
29
+
__all__ = [
    # Imported from natten_mps.functional.
    "na1d", "na1d_varlen",
    "na2d", "na2d_varlen",
    "na3d", "na3d_varlen",
    "na1d_qk", "na1d_av",
    "na2d_qk", "na2d_av",
    "na3d_qk", "na3d_av",
    # Imported from natten_mps.merge.
    "merge_attentions",
    # Imported from natten_mps.nn.
    "NeighborhoodAttention1D",
    "NeighborhoodAttention2D",
    "NeighborhoodAttention3D",
    # Defined in this module.
    "has_mps", "has_metal",
    "get_backend", "set_backend",
    # Imported from natten_mps.support_matrix.
    "get_support_matrix",
    # Imported from natten_mps.version.
    "__version__",
]
natten_mps/_core/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .ops import get_backend, set_backend
2
+
3
__all__ = [
    "get_backend",  # re-exported from .ops
    "set_backend",  # re-exported from .ops
]