jax-mlx-plugin 0.0.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jax_mlx_plugin-0.0.1/CMakeLists.txt +94 -0
- jax_mlx_plugin-0.0.1/LICENSE +21 -0
- jax_mlx_plugin-0.0.1/MANIFEST.in +15 -0
- jax_mlx_plugin-0.0.1/PKG-INFO +119 -0
- jax_mlx_plugin-0.0.1/README.md +87 -0
- jax_mlx_plugin-0.0.1/pyproject.toml +41 -0
- jax_mlx_plugin-0.0.1/setup.cfg +4 -0
- jax_mlx_plugin-0.0.1/setup.py +83 -0
- jax_mlx_plugin-0.0.1/src/jax_mlx/__init__.py +6 -0
- jax_mlx_plugin-0.0.1/src/jax_mlx/parser.py +575 -0
- jax_mlx_plugin-0.0.1/src/jax_mlx/plugin.py +106 -0
- jax_mlx_plugin-0.0.1/src/jax_mlx_pjrt.cpp +5853 -0
- jax_mlx_plugin-0.0.1/src/jax_mlx_plugin.egg-info/PKG-INFO +119 -0
- jax_mlx_plugin-0.0.1/src/jax_mlx_plugin.egg-info/SOURCES.txt +21 -0
- jax_mlx_plugin-0.0.1/src/jax_mlx_plugin.egg-info/dependency_links.txt +1 -0
- jax_mlx_plugin-0.0.1/src/jax_mlx_plugin.egg-info/entry_points.txt +2 -0
- jax_mlx_plugin-0.0.1/src/jax_mlx_plugin.egg-info/not-zip-safe +1 -0
- jax_mlx_plugin-0.0.1/src/jax_mlx_plugin.egg-info/requires.txt +3 -0
- jax_mlx_plugin-0.0.1/src/jax_mlx_plugin.egg-info/top_level.txt +1 -0
- jax_mlx_plugin-0.0.1/src/mlx_mlir_parser.h +924 -0
- jax_mlx_plugin-0.0.1/src/mlx_pjrt_types.h +358 -0
- jax_mlx_plugin-0.0.1/tests/test_exhaustive.py +766 -0
- jax_mlx_plugin-0.0.1/third_party/xla/pjrt/c/pjrt_c_api.h +2925 -0
|
@@ -0,0 +1,94 @@
|
|
|
1
|
+
cmake_minimum_required(VERSION 3.15)

# Minimum macOS version for the built plugin. CMake documents that this must be
# set before the first project() call, because the Apple platform/toolchain
# setup performed by project() reads it. Stored as a (non-FORCE) cache entry so
# a builder can still override it with -DCMAKE_OSX_DEPLOYMENT_TARGET=<version>.
set(CMAKE_OSX_DEPLOYMENT_TARGET "26.0" CACHE STRING "Minimum macOS version for the jax_mlx plugin")

# Only C++ sources are compiled (src/jax_mlx_pjrt.cpp).
project(jax_mlx LANGUAGES CXX)

# Build as C++17 and fail at configure time rather than silently falling back
# to an older standard if the compiler cannot provide it.
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
|
6
|
+
|
|
7
|
+
# --- Dependencies ---

# 1. Find Python and pybind11
find_package(Python3 REQUIRED COMPONENTS Interpreter Development)

# Ask the selected interpreter where its pip-installed pybind11 keeps its CMake
# package files. stderr is captured so that a failing import is shown in the
# error message instead of being reported only as "not found".
execute_process(
  COMMAND "${Python3_EXECUTABLE}" -c "import pybind11; print(pybind11.get_cmake_dir())"
  OUTPUT_VARIABLE PYBIND11_CMAKE_DIR
  ERROR_VARIABLE PYBIND11_FIND_ERROR
  OUTPUT_STRIP_TRAILING_WHITESPACE
  RESULT_VARIABLE PYBIND11_FIND_RESULT
)
if(NOT PYBIND11_FIND_RESULT EQUAL 0 OR "${PYBIND11_CMAKE_DIR}" STREQUAL "")
  message(FATAL_ERROR
    "Could not find pybind11. Install with: pip install pybind11\n${PYBIND11_FIND_ERROR}")
endif()

# Point this one lookup at the reported directory via HINTS rather than
# appending to the global CMAKE_PREFIX_PATH, which would leak into every later
# find_package() call.
find_package(pybind11 CONFIG REQUIRED HINTS "${PYBIND11_CMAKE_DIR}")
|
|
25
|
+
|
|
26
|
+
# 2. Find MLX from pip installation
# The mlx package installs headers in site-packages/mlx/include
# and libmlx.dylib in site-packages/mlx/lib.
# The directory of mlx.core's module file is the site-packages/mlx package dir.
execute_process(
  COMMAND "${Python3_EXECUTABLE}" -c "import mlx.core; import os; print(os.path.dirname(mlx.core.__file__))"
  OUTPUT_VARIABLE MLX_PYTHON_PATH
  ERROR_VARIABLE MLX_FIND_ERROR
  OUTPUT_STRIP_TRAILING_WHITESPACE
  RESULT_VARIABLE MLX_FIND_RESULT
)

# Fail on a non-zero exit OR an empty answer; include Python's stderr so an
# import error other than "not installed" is visible to the builder.
if(NOT MLX_FIND_RESULT EQUAL 0 OR "${MLX_PYTHON_PATH}" STREQUAL "")
  message(FATAL_ERROR
    "Could not find MLX Python package. Install with: pip install mlx\n${MLX_FIND_ERROR}")
endif()

# Consumed by the plugin target definition further down.
set(MLX_INCLUDE_DIR "${MLX_PYTHON_PATH}/include")
set(MLX_LIB_DIR "${MLX_PYTHON_PATH}/lib")

message(STATUS "Found MLX at: ${MLX_PYTHON_PATH}")
message(STATUS "  Include: ${MLX_INCLUDE_DIR}")
message(STATUS "  Lib: ${MLX_LIB_DIR}")

# Verify paths exist
if(NOT EXISTS "${MLX_INCLUDE_DIR}")
  message(FATAL_ERROR "MLX include directory not found: ${MLX_INCLUDE_DIR}")
endif()
if(NOT EXISTS "${MLX_LIB_DIR}/libmlx.dylib")
  message(FATAL_ERROR "MLX library not found: ${MLX_LIB_DIR}/libmlx.dylib")
endif()
|
|
54
|
+
|
|
55
|
+
# 3. OpenXLA Headers (for PJRT C API)
# The PJRT C API header is vendored under third_party/ and is attached to the
# plugin target below instead of via directory-wide include_directories().

# Wrap the pip-installed libmlx.dylib in an imported target. The library and its
# headers then travel together as one link item, and a misspelled target name is
# a configure error rather than a bogus linker argument.
add_library(jax_mlx::mlx SHARED IMPORTED)
set_target_properties(jax_mlx::mlx PROPERTIES
  IMPORTED_LOCATION "${MLX_LIB_DIR}/libmlx.dylib"
  INTERFACE_INCLUDE_DIRECTORIES "${MLX_INCLUDE_DIR}"
)

# --- The Plugin Library ---
add_library(mlx_pjrt_plugin SHARED
  src/jax_mlx_pjrt.cpp
)

# Include directories (MLX headers come from jax_mlx::mlx)
target_include_directories(mlx_pjrt_plugin PRIVATE
  "${CMAKE_CURRENT_SOURCE_DIR}/third_party"
  ${Python3_INCLUDE_DIRS}
)

# Link against MLX and pybind11
target_link_libraries(mlx_pjrt_plugin PRIVATE
  jax_mlx::mlx
  pybind11::pybind11
)

# Set RPATH to find libmlx.dylib at runtime
# Plugin installs to site-packages/jax_mlx/
# MLX is at site-packages/mlx/lib/
# So relative path is @loader_path/../mlx/lib
set_target_properties(mlx_pjrt_plugin PROPERTIES
  INSTALL_RPATH "@loader_path;@loader_path/../mlx/lib"
  BUILD_WITH_INSTALL_RPATH TRUE
)

# macOS specific: Allow undefined symbols (Python provides them at runtime).
# The LINKER: prefix passes "-undefined dynamic_lookup" as a single item, so
# CMake's link-option de-duplication can never separate the flag from its value.
if(APPLE)
  target_link_options(mlx_pjrt_plugin PRIVATE "LINKER:-undefined,dynamic_lookup")
endif()

# Install target
install(TARGETS mlx_pjrt_plugin
  LIBRARY DESTINATION jax_mlx
  RUNTIME DESTINATION jax_mlx
)
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2026 Thomas Summe
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
# Include CMake and C++ source files for building the extension
|
|
2
|
+
include CMakeLists.txt
|
|
3
|
+
recursive-include src *.cpp *.h
|
|
4
|
+
recursive-include third_party *.h
|
|
5
|
+
|
|
6
|
+
# Include documentation
|
|
7
|
+
include README.md
|
|
8
|
+
include LICENSE
|
|
9
|
+
|
|
10
|
+
# Exclude build artifacts
|
|
11
|
+
global-exclude __pycache__
|
|
12
|
+
global-exclude *.py[cod]
|
|
13
|
+
global-exclude *.so
|
|
14
|
+
global-exclude *.dylib
|
|
15
|
+
prune benchmarks/build
|
|
@@ -0,0 +1,119 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: jax-mlx-plugin
|
|
3
|
+
Version: 0.0.1
|
|
4
|
+
Summary: JAX PJRT plugin for Apple Silicon using MLX
|
|
5
|
+
Home-page: https://github.com/tsumme1/jax-mlx
|
|
6
|
+
Author: Thomas Summe
|
|
7
|
+
License: MIT
|
|
8
|
+
Project-URL: Homepage, https://github.com/tsumme1/jax-mlx
|
|
9
|
+
Project-URL: Repository, https://github.com/tsumme1/jax-mlx
|
|
10
|
+
Keywords: jax,mlx,apple-silicon,machine-learning,pjrt
|
|
11
|
+
Classifier: Development Status :: 3 - Alpha
|
|
12
|
+
Classifier: Environment :: GPU
|
|
13
|
+
Classifier: Environment :: MacOS X
|
|
14
|
+
Classifier: Intended Audience :: Developers
|
|
15
|
+
Classifier: Intended Audience :: Science/Research
|
|
16
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
17
|
+
Classifier: Operating System :: MacOS :: MacOS X
|
|
18
|
+
Classifier: Programming Language :: Python :: 3
|
|
19
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
20
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
21
|
+
Classifier: Programming Language :: Python :: 3.13
|
|
22
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
23
|
+
Requires-Python: >=3.11
|
|
24
|
+
Description-Content-Type: text/markdown
|
|
25
|
+
License-File: LICENSE
|
|
26
|
+
Requires-Dist: jax>=0.5.0
|
|
27
|
+
Requires-Dist: jaxlib>=0.5.0
|
|
28
|
+
Requires-Dist: mlx
|
|
29
|
+
Dynamic: home-page
|
|
30
|
+
Dynamic: license-file
|
|
31
|
+
Dynamic: requires-python
|
|
32
|
+
|
|
33
|
+
# JAX MLX Plugin
|
|
34
|
+
|
|
35
|
+
A PJRT plugin enabling JAX to use Apple's MLX framework as a backend on Apple Silicon Macs.
|
|
36
|
+
|
|
37
|
+
## Status
|
|
38
|
+
|
|
39
|
+
✅ **362 ops tested and passing**
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
## Requirements
|
|
43
|
+
|
|
44
|
+
- **Apple Silicon Mac** (M1/M2/M3/M4)
|
|
45
|
+
- **Python:** 3.11+
|
|
46
|
+
- **Dependencies:** jax, jaxlib, mlx
|
|
47
|
+
|
|
48
|
+
## Installation
|
|
49
|
+
|
|
50
|
+
```bash
|
|
51
|
+
# Install build dependencies
|
|
52
|
+
pip install mlx jaxlib jax
|
|
53
|
+
|
|
54
|
+
# Install the plugin
|
|
55
|
+
pip install .
|
|
56
|
+
```
|
|
57
|
+
|
|
58
|
+
## Usage
|
|
59
|
+
|
|
60
|
+
```python
|
|
61
|
+
import jax
|
|
62
|
+
import jax.numpy as jnp
|
|
63
|
+
|
|
64
|
+
# List available devices
|
|
65
|
+
print(jax.devices()) # [mlx:0]
|
|
66
|
+
|
|
67
|
+
# Use MLX as default device
|
|
68
|
+
mlx = jax.devices('mlx')[0]
|
|
69
|
+
with jax.default_device(mlx):
|
|
70
|
+
x = jnp.array([1.0, 2.0, 3.0])
|
|
71
|
+
y = jnp.sin(x) + jnp.cos(x)
|
|
72
|
+
print(y)
|
|
73
|
+
```
|
|
74
|
+
|
|
75
|
+
## Features
|
|
76
|
+
|
|
77
|
+
- ✅ All core JAX operations (362 tested)
|
|
78
|
+
- ✅ Full autodiff support (`jax.grad`, `jax.value_and_grad`)
|
|
79
|
+
- ✅ JIT compilation with `mx.compile()` kernel fusion
|
|
80
|
+
- ✅ Vectorization (`jax.vmap`)
|
|
81
|
+
- ✅ Control flow (`lax.cond`, `lax.while_loop`, `lax.scan`)
|
|
82
|
+
- ✅ Linear algebra, FFT, convolutions
|
|
83
|
+
- ✅ Neural network training (Flax, Optax)
|
|
84
|
+
|
|
85
|
+
## Environment Variables
|
|
86
|
+
|
|
87
|
+
| Variable | Description |
|
|
88
|
+
|----------|-------------|
|
|
89
|
+
| `MLX_PJRT_DEBUG=1` | Enable verbose debug logging |
|
|
90
|
+
| `MLX_NO_COMPILE=1` | Disable mx.compile() kernel fusion |
|
|
91
|
+
| `MLX_TIMING=1` | Enable timing output |
|
|
92
|
+
|
|
93
|
+
## Development
|
|
94
|
+
|
|
95
|
+
```bash
|
|
96
|
+
# Run exhaustive tests (362 ops)
|
|
97
|
+
python tests/test_exhaustive.py
|
|
98
|
+
|
|
99
|
+
# Run CNN benchmarks
|
|
100
|
+
python benchmarks/benchmark_cnn.py # JAX/Flax
|
|
101
|
+
python benchmarks/benchmark_mlx_native.py # Native MLX Python
|
|
102
|
+
|
|
103
|
+
# Build and run C++ benchmark
|
|
104
|
+
cd benchmarks && cmake -B build && cmake --build build
|
|
105
|
+
./build/mlx_cpp_benchmark
|
|
106
|
+
```
|
|
107
|
+
|
|
108
|
+
## Architecture
|
|
109
|
+
|
|
110
|
+
The plugin implements the PJRT (Portable JAX Runtime) C API. StableHLO operations from JAX are parsed and converted to MLX operations at runtime using a lightweight MLIR parser. The plugin uses `mx.compile()` for GPU kernel fusion.
|
|
111
|
+
|
|
112
|
+
## Known Limitations
|
|
113
|
+
|
|
114
|
+
- **Float64:** Not supported on Metal GPU (use Float32)
|
|
115
|
+
- **While loops:** Prevent kernel fusion, because they require runtime evaluation
|
|
116
|
+
|
|
117
|
+
## License
|
|
118
|
+
|
|
119
|
+
MIT License - see [LICENSE](LICENSE) for details.
|
|
@@ -0,0 +1,87 @@
|
|
|
1
|
+
# JAX MLX Plugin
|
|
2
|
+
|
|
3
|
+
A PJRT plugin enabling JAX to use Apple's MLX framework as a backend on Apple Silicon Macs.
|
|
4
|
+
|
|
5
|
+
## Status
|
|
6
|
+
|
|
7
|
+
✅ **362 ops tested and passing**
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
## Requirements
|
|
11
|
+
|
|
12
|
+
- **Apple Silicon Mac** (M1/M2/M3/M4)
|
|
13
|
+
- **Python:** 3.11+
|
|
14
|
+
- **Dependencies:** jax, jaxlib, mlx
|
|
15
|
+
|
|
16
|
+
## Installation
|
|
17
|
+
|
|
18
|
+
```bash
|
|
19
|
+
# Install build dependencies
|
|
20
|
+
pip install mlx jaxlib jax
|
|
21
|
+
|
|
22
|
+
# Install the plugin
|
|
23
|
+
pip install .
|
|
24
|
+
```
|
|
25
|
+
|
|
26
|
+
## Usage
|
|
27
|
+
|
|
28
|
+
```python
|
|
29
|
+
import jax
|
|
30
|
+
import jax.numpy as jnp
|
|
31
|
+
|
|
32
|
+
# List available devices
|
|
33
|
+
print(jax.devices()) # [mlx:0]
|
|
34
|
+
|
|
35
|
+
# Use MLX as default device
|
|
36
|
+
mlx = jax.devices('mlx')[0]
|
|
37
|
+
with jax.default_device(mlx):
|
|
38
|
+
x = jnp.array([1.0, 2.0, 3.0])
|
|
39
|
+
y = jnp.sin(x) + jnp.cos(x)
|
|
40
|
+
print(y)
|
|
41
|
+
```
|
|
42
|
+
|
|
43
|
+
## Features
|
|
44
|
+
|
|
45
|
+
- ✅ All core JAX operations (362 tested)
|
|
46
|
+
- ✅ Full autodiff support (`jax.grad`, `jax.value_and_grad`)
|
|
47
|
+
- ✅ JIT compilation with `mx.compile()` kernel fusion
|
|
48
|
+
- ✅ Vectorization (`jax.vmap`)
|
|
49
|
+
- ✅ Control flow (`lax.cond`, `lax.while_loop`, `lax.scan`)
|
|
50
|
+
- ✅ Linear algebra, FFT, convolutions
|
|
51
|
+
- ✅ Neural network training (Flax, Optax)
|
|
52
|
+
|
|
53
|
+
## Environment Variables
|
|
54
|
+
|
|
55
|
+
| Variable | Description |
|
|
56
|
+
|----------|-------------|
|
|
57
|
+
| `MLX_PJRT_DEBUG=1` | Enable verbose debug logging |
|
|
58
|
+
| `MLX_NO_COMPILE=1` | Disable mx.compile() kernel fusion |
|
|
59
|
+
| `MLX_TIMING=1` | Enable timing output |
|
|
60
|
+
|
|
61
|
+
## Development
|
|
62
|
+
|
|
63
|
+
```bash
|
|
64
|
+
# Run exhaustive tests (362 ops)
|
|
65
|
+
python tests/test_exhaustive.py
|
|
66
|
+
|
|
67
|
+
# Run CNN benchmarks
|
|
68
|
+
python benchmarks/benchmark_cnn.py # JAX/Flax
|
|
69
|
+
python benchmarks/benchmark_mlx_native.py # Native MLX Python
|
|
70
|
+
|
|
71
|
+
# Build and run C++ benchmark
|
|
72
|
+
cd benchmarks && cmake -B build && cmake --build build
|
|
73
|
+
./build/mlx_cpp_benchmark
|
|
74
|
+
```
|
|
75
|
+
|
|
76
|
+
## Architecture
|
|
77
|
+
|
|
78
|
+
The plugin implements the PJRT (Portable JAX Runtime) C API. StableHLO operations from JAX are parsed and converted to MLX operations at runtime using a lightweight MLIR parser. The plugin uses `mx.compile()` for GPU kernel fusion.
|
|
79
|
+
|
|
80
|
+
## Known Limitations
|
|
81
|
+
|
|
82
|
+
- **Float64:** Not supported on Metal GPU (use Float32)
|
|
83
|
+
- **While loops:** Prevent kernel fusion, because they require runtime evaluation
|
|
84
|
+
|
|
85
|
+
## License
|
|
86
|
+
|
|
87
|
+
MIT License - see [LICENSE](LICENSE) for details.
|
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
[build-system]
# setuptools >= 61 is the first release that reads the PEP 621 [project] table
# below; older versions would silently ignore it.
# mlx is needed at build time: CMakeLists.txt locates its headers and
# libmlx.dylib through the build interpreter.
# NOTE(review): the plugin links against the build environment's libmlx.dylib
# but loads the runtime environment's copy via RPATH (@loader_path/../mlx/lib);
# consider constraining mlx to ABI-compatible versions.
requires = ["setuptools>=61", "wheel", "pybind11", "cmake", "mlx"]
build-backend = "setuptools.build_meta"

[project]
name = "jax-mlx-plugin"
version = "0.0.1"
description = "JAX PJRT plugin for Apple Silicon using MLX"
readme = "README.md"
license = {text = "MIT"}
requires-python = ">=3.11"
authors = [
    {name = "Thomas Summe"}
]
classifiers = [
    "Development Status :: 3 - Alpha",
    "Environment :: GPU",
    "Environment :: MacOS X",
    "Intended Audience :: Developers",
    "Intended Audience :: Science/Research",
    "License :: OSI Approved :: MIT License",
    "Operating System :: MacOS :: MacOS X",
    "Programming Language :: Python :: 3",
    "Programming Language :: Python :: 3.11",
    "Programming Language :: Python :: 3.12",
    "Programming Language :: Python :: 3.13",
    "Topic :: Scientific/Engineering :: Artificial Intelligence",
]
keywords = ["jax", "mlx", "apple-silicon", "machine-learning", "pjrt"]
dependencies = [
    "jax>=0.5.0",
    "jaxlib>=0.5.0",
    "mlx",
]

[project.urls]
Homepage = "https://github.com/tsumme1/jax-mlx"
Repository = "https://github.com/tsumme1/jax-mlx"

# JAX discovers PJRT plugins through this entry-point group.
[project.entry-points."jax_plugins"]
mlx_plugin = "jax_mlx.plugin"
|
|
@@ -0,0 +1,83 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import sys
|
|
3
|
+
import subprocess
|
|
4
|
+
import shutil
|
|
5
|
+
from setuptools import setup, Extension, find_packages
|
|
6
|
+
from setuptools.command.build_ext import build_ext
|
|
7
|
+
|
|
8
|
+
class CMakeExtension(Extension):
    """Extension stub whose actual build is delegated to CMake.

    It declares an empty source list, so setuptools has nothing to compile
    itself. ``sourcedir`` is stored as an absolute path and is later handed to
    ``cmake`` as the source directory to configure.
    """

    def __init__(self, name, sourcedir=''):
        super().__init__(name, sources=[])
        # Absolutize now: the CMake step runs with cwd set to the build dir.
        self.sourcedir = os.path.abspath(sourcedir)
|
|
12
|
+
|
|
13
|
+
class CMakeBuild(build_ext):
    """``build_ext`` command that configures and builds each extension with CMake.

    Each extension is configured and built inside ``self.build_temp``. CMake is
    told to place the shared library directly in the directory setuptools uses
    for ``ext.name``. The built ``libmlx_pjrt_plugin.dylib`` is additionally
    copied into ``src/jax_mlx`` when that directory exists.
    """

    def run(self):
        # Overrides build_ext.run() without calling it: only the CMake builds
        # below are executed, no distutils compiler step.
        for ext in self.extensions:
            self.build_extension(ext)

    def build_extension(self, ext):
        """Configure, build and (optionally) mirror one CMakeExtension.

        Raises:
            subprocess.CalledProcessError: if the CMake configure or build step fails.
        """
        extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.name)))
        if not extdir.endswith(os.path.sep):
            extdir += os.path.sep

        cfg = 'Debug' if self.debug else 'Release'

        cmake_args = [
            f'-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={extdir}',
            f'-DPython3_EXECUTABLE={sys.executable}',
            f'-DCMAKE_BUILD_TYPE={cfg}',
        ]
        build_args = ['--config', cfg]

        os.makedirs(self.build_temp, exist_ok=True)

        subprocess.check_call(['cmake', ext.sourcedir] + cmake_args, cwd=self.build_temp)
        subprocess.check_call(['cmake', '--build', '.', '-j'] + build_args, cwd=self.build_temp)

        # Copy the built library to the package source directory as well.
        # For an in-place build, extdir already *is* that directory. Copying a
        # file onto itself raises shutil.SameFileError, so skip the copy then.
        src_package_dir = os.path.join(ext.sourcedir, 'src', 'jax_mlx')
        built_lib = os.path.join(extdir, 'libmlx_pjrt_plugin.dylib')
        if os.path.exists(built_lib) and os.path.isdir(src_package_dir):
            destination = os.path.join(src_package_dir, os.path.basename(built_lib))
            already_there = os.path.exists(destination) and os.path.samefile(built_lib, destination)
            if not already_there:
                shutil.copy2(built_lib, src_package_dir)
|
|
44
|
+
|
|
45
|
+
# README.md contains non-ASCII characters (e.g. "✅"), so decode it explicitly
# as UTF-8 instead of with the locale encoding, and close the file promptly.
with open('README.md', encoding='utf-8') as readme_file:
    long_description = readme_file.read()

setup(
    name='jax-mlx-plugin',
    version='0.0.1',
    author='Thomas Summe',
    description='JAX PJRT plugin for Apple Silicon using MLX',
    long_description=long_description,
    long_description_content_type='text/markdown',
    url='https://github.com/tsumme1/jax-mlx',
    packages=find_packages(where='src'),
    package_dir={'': 'src'},
    # The native plugin is produced by CMake (see CMakeBuild), not by setuptools.
    ext_modules=[CMakeExtension('jax_mlx.mlx_pjrt_plugin')],
    cmdclass=dict(build_ext=CMakeBuild),
    package_data={'jax_mlx': ['*.dylib', '*.so']},
    include_package_data=True,
    zip_safe=False,
    python_requires='>=3.11',
    install_requires=[
        'jax>=0.5.0',
        'jaxlib>=0.5.0',
        'mlx',
    ],
    entry_points={
        "jax_plugins": [
            "mlx_plugin = jax_mlx.plugin",
        ],
    },
    classifiers=[
        'Development Status :: 3 - Alpha',
        'Intended Audience :: Developers',
        'Intended Audience :: Science/Research',
        'License :: OSI Approved :: MIT License',
        'Operating System :: MacOS :: MacOS X',
        'Programming Language :: Python :: 3',
        'Programming Language :: Python :: 3.11',
        'Programming Language :: Python :: 3.12',
        'Programming Language :: Python :: 3.13',
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
    ],
)
|