jax-mlx-plugin 0.0.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,94 @@
1
cmake_minimum_required(VERSION 3.15)

# CMAKE_OSX_DEPLOYMENT_TARGET must be set before project(): the toolchain
# checks performed by project() already depend on it. Declaring it as a cache
# entry (instead of a normal variable) means a user-supplied
# -DCMAKE_OSX_DEPLOYMENT_TARGET=<version> is respected rather than shadowed.
set(CMAKE_OSX_DEPLOYMENT_TARGET "26.0" CACHE STRING "Minimum macOS version to target")

project(jax_mlx LANGUAGES CXX)

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

# --- Dependencies ---

# 1. Find Python and pybind11
find_package(Python3 REQUIRED COMPONENTS Interpreter Development)

# Ask the active interpreter where pip installed pybind11's CMake package.
execute_process(
  COMMAND "${Python3_EXECUTABLE}" -c "import pybind11; print(pybind11.get_cmake_dir())"
  OUTPUT_VARIABLE PYBIND11_CMAKE_DIR
  OUTPUT_STRIP_TRAILING_WHITESPACE
  RESULT_VARIABLE PYBIND11_FIND_RESULT
)
if(NOT PYBIND11_FIND_RESULT EQUAL 0)
  message(FATAL_ERROR "Could not find pybind11. Install with: pip install pybind11")
endif()
list(APPEND CMAKE_PREFIX_PATH "${PYBIND11_CMAKE_DIR}")

find_package(pybind11 REQUIRED)

# 2. Find MLX from the pip installation.
# The mlx package installs headers in site-packages/mlx/include
# and libmlx.dylib in site-packages/mlx/lib.
execute_process(
  COMMAND "${Python3_EXECUTABLE}" -c "import mlx.core; import os; print(os.path.dirname(mlx.core.__file__))"
  OUTPUT_VARIABLE MLX_PYTHON_PATH
  OUTPUT_STRIP_TRAILING_WHITESPACE
  RESULT_VARIABLE MLX_FIND_RESULT
)
if(NOT MLX_FIND_RESULT EQUAL 0)
  message(FATAL_ERROR "Could not find MLX Python package. Install with: pip install mlx")
endif()

set(MLX_INCLUDE_DIR "${MLX_PYTHON_PATH}/include")
set(MLX_LIB_DIR "${MLX_PYTHON_PATH}/lib")
set(MLX_LIBRARY "${MLX_LIB_DIR}/libmlx.dylib")

message(STATUS "Found MLX at: ${MLX_PYTHON_PATH}")
message(STATUS " Include: ${MLX_INCLUDE_DIR}")
message(STATUS " Lib: ${MLX_LIB_DIR}")

# Verify paths exist before wiring them into targets.
if(NOT EXISTS "${MLX_INCLUDE_DIR}")
  message(FATAL_ERROR "MLX include directory not found: ${MLX_INCLUDE_DIR}")
endif()
if(NOT EXISTS "${MLX_LIBRARY}")
  message(FATAL_ERROR "MLX library not found: ${MLX_LIBRARY}")
endif()

# Wrap the pip-installed libmlx in an imported target so its include path
# travels with it as a usage requirement (instead of linking a raw file path).
add_library(MLX::mlx UNKNOWN IMPORTED)
set_target_properties(MLX::mlx PROPERTIES
  IMPORTED_LOCATION "${MLX_LIBRARY}"
  INTERFACE_INCLUDE_DIRECTORIES "${MLX_INCLUDE_DIR}"
)

# --- The Plugin Library ---
add_library(mlx_pjrt_plugin SHARED
  src/jax_mlx_pjrt.cpp
)

# Include directories (target-scoped, not directory-wide):
#   3. OpenXLA headers (for the PJRT C API) vendored under third_party/.
target_include_directories(mlx_pjrt_plugin PRIVATE
  "${CMAKE_CURRENT_SOURCE_DIR}/third_party"
  ${Python3_INCLUDE_DIRS}
)

# Link against MLX and pybind11.
target_link_libraries(mlx_pjrt_plugin
  PRIVATE
    MLX::mlx
    pybind11::pybind11
)

# Set RPATH to find libmlx.dylib at runtime.
# Plugin installs to site-packages/jax_mlx/
# MLX is at site-packages/mlx/lib/
# So the relative path is @loader_path/../mlx/lib
set_target_properties(mlx_pjrt_plugin PROPERTIES
  INSTALL_RPATH "@loader_path;@loader_path/../mlx/lib"
  BUILD_WITH_INSTALL_RPATH TRUE
)

# macOS specific: allow undefined symbols (Python provides them at runtime).
if(APPLE)
  target_link_options(mlx_pjrt_plugin PRIVATE "-undefined" "dynamic_lookup")
endif()

# Install target
install(TARGETS mlx_pjrt_plugin
  LIBRARY DESTINATION jax_mlx
  RUNTIME DESTINATION jax_mlx
)
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2026 Thomas Summe
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -0,0 +1,15 @@
1
# Sources required to build the native extension with CMake
include CMakeLists.txt
recursive-include src *.h *.cpp
recursive-include third_party *.h

# Top-level documentation and license
include LICENSE
include README.md

# Keep caches and compiled artifacts out of the sdist
global-exclude __pycache__
global-exclude *.pyc *.pyo *.pyd
global-exclude *.dylib *.so
prune benchmarks/build
@@ -0,0 +1,119 @@
1
+ Metadata-Version: 2.4
2
+ Name: jax-mlx-plugin
3
+ Version: 0.0.1
4
+ Summary: JAX PJRT plugin for Apple Silicon using MLX
5
+ Home-page: https://github.com/tsumme1/jax-mlx
6
+ Author: Thomas Summe
7
+ License: MIT
8
+ Project-URL: Homepage, https://github.com/tsumme1/jax-mlx
9
+ Project-URL: Repository, https://github.com/tsumme1/jax-mlx
10
+ Keywords: jax,mlx,apple-silicon,machine-learning,pjrt
11
+ Classifier: Development Status :: 3 - Alpha
12
+ Classifier: Environment :: GPU
13
+ Classifier: Environment :: MacOS X
14
+ Classifier: Intended Audience :: Developers
15
+ Classifier: Intended Audience :: Science/Research
16
+ Classifier: License :: OSI Approved :: MIT License
17
+ Classifier: Operating System :: MacOS :: MacOS X
18
+ Classifier: Programming Language :: Python :: 3
19
+ Classifier: Programming Language :: Python :: 3.11
20
+ Classifier: Programming Language :: Python :: 3.12
21
+ Classifier: Programming Language :: Python :: 3.13
22
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
23
+ Requires-Python: >=3.11
24
+ Description-Content-Type: text/markdown
25
+ License-File: LICENSE
26
+ Requires-Dist: jax>=0.5.0
27
+ Requires-Dist: jaxlib>=0.5.0
28
+ Requires-Dist: mlx
29
+ Dynamic: home-page
30
+ Dynamic: license-file
31
+ Dynamic: requires-python
32
+
33
+ # JAX MLX Plugin
34
+
35
+ A PJRT plugin enabling JAX to use Apple's MLX framework as a backend on Apple Silicon Macs.
36
+
37
+ ## Status
38
+
39
+ ✅ **362 ops tested and passing**
40
+
41
+
42
+ ## Requirements
43
+
44
+ - **Apple Silicon Mac** (M1/M2/M3/M4)
45
+ - **Python:** 3.11+
46
+ - **Dependencies:** jax, jaxlib, mlx
47
+
48
+ ## Installation
49
+
50
+ ```bash
51
+ # Install runtime dependencies (pip fetches build requirements automatically)
52
+ pip install mlx jaxlib jax
53
+
54
+ # Install the plugin
55
+ pip install .
56
+ ```
57
+
58
+ ## Usage
59
+
60
+ ```python
61
+ import jax
62
+ import jax.numpy as jnp
63
+
64
+ # List available devices
65
+ print(jax.devices()) # [mlx:0]
66
+
67
+ # Use MLX as default device
68
+ mlx = jax.devices('mlx')[0]
69
+ with jax.default_device(mlx):
70
+ x = jnp.array([1.0, 2.0, 3.0])
71
+ y = jnp.sin(x) + jnp.cos(x)
72
+ print(y)
73
+ ```
74
+
75
+ ## Features
76
+
77
+ - ✅ All core JAX operations (362 tested)
78
+ - ✅ Full autodiff support (`jax.grad`, `jax.value_and_grad`)
79
+ - ✅ JIT compilation with `mx.compile()` kernel fusion
80
+ - ✅ Vectorization (`jax.vmap`)
81
+ - ✅ Control flow (`lax.cond`, `lax.while_loop`, `lax.scan`)
82
+ - ✅ Linear algebra, FFT, convolutions
83
+ - ✅ Neural network training (Flax, Optax)
84
+
85
+ ## Environment Variables
86
+
87
+ | Variable | Description |
88
+ |----------|-------------|
89
+ | `MLX_PJRT_DEBUG=1` | Enable verbose debug logging |
90
+ | `MLX_NO_COMPILE=1` | Disable mx.compile() kernel fusion |
91
+ | `MLX_TIMING=1` | Enable timing output |
92
+
93
+ ## Development
94
+
95
+ ```bash
96
+ # Run exhaustive tests (362 ops)
97
+ python tests/test_exhaustive.py
98
+
99
+ # Run CNN benchmarks
100
+ python benchmarks/benchmark_cnn.py # JAX/Flax
101
+ python benchmarks/benchmark_mlx_native.py # Native MLX Python
102
+
103
+ # Build and run C++ benchmark
104
+ cd benchmarks && cmake -B build && cmake --build build
105
+ ./build/mlx_cpp_benchmark
106
+ ```
107
+
108
+ ## Architecture
109
+
110
+ The plugin implements the PJRT C API, the plugin interface JAX uses to run on hardware backends. StableHLO operations from JAX are parsed and converted to MLX operations at runtime using a lightweight MLIR parser. The plugin uses `mx.compile()` for GPU kernel fusion.
111
+
112
+ ## Known Limitations
113
+
114
+ - **Float64:** Not supported on Metal GPU (use Float32)
115
+ - **While loops:** Block kernel fusion (they require evaluation at runtime)
116
+
117
+ ## License
118
+
119
+ MIT License - see [LICENSE](LICENSE) for details.
@@ -0,0 +1,87 @@
1
+ # JAX MLX Plugin
2
+
3
+ A PJRT plugin enabling JAX to use Apple's MLX framework as a backend on Apple Silicon Macs.
4
+
5
+ ## Status
6
+
7
+ ✅ **362 ops tested and passing**
8
+
9
+
10
+ ## Requirements
11
+
12
+ - **Apple Silicon Mac** (M1/M2/M3/M4)
13
+ - **Python:** 3.11+
14
+ - **Dependencies:** jax, jaxlib, mlx
15
+
16
+ ## Installation
17
+
18
+ ```bash
19
+ # Install runtime dependencies (pip fetches build requirements automatically)
20
+ pip install mlx jaxlib jax
21
+
22
+ # Install the plugin
23
+ pip install .
24
+ ```
25
+
26
+ ## Usage
27
+
28
+ ```python
29
+ import jax
30
+ import jax.numpy as jnp
31
+
32
+ # List available devices
33
+ print(jax.devices()) # [mlx:0]
34
+
35
+ # Use MLX as default device
36
+ mlx = jax.devices('mlx')[0]
37
+ with jax.default_device(mlx):
38
+ x = jnp.array([1.0, 2.0, 3.0])
39
+ y = jnp.sin(x) + jnp.cos(x)
40
+ print(y)
41
+ ```
42
+
43
+ ## Features
44
+
45
+ - ✅ All core JAX operations (362 tested)
46
+ - ✅ Full autodiff support (`jax.grad`, `jax.value_and_grad`)
47
+ - ✅ JIT compilation with `mx.compile()` kernel fusion
48
+ - ✅ Vectorization (`jax.vmap`)
49
+ - ✅ Control flow (`lax.cond`, `lax.while_loop`, `lax.scan`)
50
+ - ✅ Linear algebra, FFT, convolutions
51
+ - ✅ Neural network training (Flax, Optax)
52
+
53
+ ## Environment Variables
54
+
55
+ | Variable | Description |
56
+ |----------|-------------|
57
+ | `MLX_PJRT_DEBUG=1` | Enable verbose debug logging |
58
+ | `MLX_NO_COMPILE=1` | Disable mx.compile() kernel fusion |
59
+ | `MLX_TIMING=1` | Enable timing output |
60
+
61
+ ## Development
62
+
63
+ ```bash
64
+ # Run exhaustive tests (362 ops)
65
+ python tests/test_exhaustive.py
66
+
67
+ # Run CNN benchmarks
68
+ python benchmarks/benchmark_cnn.py # JAX/Flax
69
+ python benchmarks/benchmark_mlx_native.py # Native MLX Python
70
+
71
+ # Build and run C++ benchmark
72
+ cd benchmarks && cmake -B build && cmake --build build
73
+ ./build/mlx_cpp_benchmark
74
+ ```
75
+
76
+ ## Architecture
77
+
78
+ The plugin implements the PJRT C API, the plugin interface JAX uses to run on hardware backends. StableHLO operations from JAX are parsed and converted to MLX operations at runtime using a lightweight MLIR parser. The plugin uses `mx.compile()` for GPU kernel fusion.
79
+
80
+ ## Known Limitations
81
+
82
+ - **Float64:** Not supported on Metal GPU (use Float32)
83
+ - **While loops:** Block kernel fusion (they require evaluation at runtime)
84
+
85
+ ## License
86
+
87
+ MIT License - see [LICENSE](LICENSE) for details.
@@ -0,0 +1,41 @@
1
[build-system]
# setuptools>=61 is the first release that reads the PEP 621 [project] table
# below; older releases do not load it.
# NOTE(review): the plugin is compiled against the headers/libmlx.dylib of the
# build-time "mlx" and loads site-packages/mlx/lib at runtime (see the RPATH in
# CMakeLists.txt) — the two mlx versions must be ABI-compatible; consider pinning.
requires = ["setuptools>=61", "wheel", "pybind11", "cmake", "mlx"]
build-backend = "setuptools.build_meta"

[project]
name = "jax-mlx-plugin"
version = "0.0.1"
description = "JAX PJRT plugin for Apple Silicon using MLX"
readme = "README.md"
license = {text = "MIT"}
requires-python = ">=3.11"
authors = [
    {name = "Thomas Summe"}
]
classifiers = [
    "Development Status :: 3 - Alpha",
    "Environment :: GPU",
    "Environment :: MacOS X",
    "Intended Audience :: Developers",
    "Intended Audience :: Science/Research",
    "License :: OSI Approved :: MIT License",
    "Operating System :: MacOS :: MacOS X",
    "Programming Language :: Python :: 3",
    "Programming Language :: Python :: 3.11",
    "Programming Language :: Python :: 3.12",
    "Programming Language :: Python :: 3.13",
    "Topic :: Scientific/Engineering :: Artificial Intelligence",
]
keywords = ["jax", "mlx", "apple-silicon", "machine-learning", "pjrt"]
dependencies = [
    "jax>=0.5.0",
    "jaxlib>=0.5.0",
    "mlx",
]

[project.urls]
Homepage = "https://github.com/tsumme1/jax-mlx"
Repository = "https://github.com/tsumme1/jax-mlx"

# JAX discovers PJRT plugins through this entry-point group.
[project.entry-points."jax_plugins"]
mlx_plugin = "jax_mlx.plugin"
@@ -0,0 +1,4 @@
1
# setuptools egg_info options recorded in this sdist.
# An empty tag_build and tag_date = 0 add no extra tag (such as a ".dev"
# suffix or a date stamp) to the version string.
[egg_info]
tag_build =
tag_date = 0
@@ -0,0 +1,83 @@
1
import os
import sys
import subprocess
import shutil
from setuptools import setup, Extension, find_packages
from setuptools.command.build_ext import build_ext


class CMakeExtension(Extension):
    """A setuptools ``Extension`` whose build is driven entirely by CMake.

    It declares no ``sources`` of its own; ``sourcedir`` is the directory that
    holds the top-level ``CMakeLists.txt`` and is stored as an absolute path.
    """

    def __init__(self, name, sourcedir=''):
        Extension.__init__(self, name, sources=[])
        self.sourcedir = os.path.abspath(sourcedir)


class CMakeBuild(build_ext):
    """``build_ext`` replacement that configures and builds each extension with CMake."""

    def run(self):
        # Skip the default compiler-driven build; CMake does all the work.
        for ext in self.extensions:
            self.build_extension(ext)

    def build_extension(self, ext):
        """Configure and build ``ext`` in ``self.build_temp``.

        The shared library is written to the directory where setuptools expects
        the extension, then also copied into ``src/jax_mlx`` when that exists.
        Raises ``subprocess.CalledProcessError`` if either CMake step fails.
        """
        extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.name)))
        if not extdir.endswith(os.path.sep):
            extdir += os.path.sep

        cfg = 'Debug' if self.debug else 'Release'

        cmake_args = [
            f'-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={extdir}',
            # Multi-config generators (e.g. Xcode) would otherwise append a
            # per-config subdirectory ("Release/") to the output directory.
            f'-DCMAKE_LIBRARY_OUTPUT_DIRECTORY_{cfg.upper()}={extdir}',
            f'-DPython3_EXECUTABLE={sys.executable}',
            f'-DCMAKE_BUILD_TYPE={cfg}',
        ]

        build_args = ['--config', cfg]

        os.makedirs(self.build_temp, exist_ok=True)

        subprocess.check_call(['cmake', ext.sourcedir] + cmake_args, cwd=self.build_temp)
        subprocess.check_call(['cmake', '--build', '.', '-j'] + build_args, cwd=self.build_temp)

        # Copy the built library to the package source directory as well
        src_package_dir = os.path.join(ext.sourcedir, 'src', 'jax_mlx')
        built_lib = os.path.join(extdir, 'libmlx_pjrt_plugin.dylib')
        if os.path.exists(built_lib) and os.path.isdir(src_package_dir):
            shutil.copy2(built_lib, src_package_dir)


# README.md contains non-ASCII characters, so read it explicitly as UTF-8 and
# close the handle instead of leaking it.
with open('README.md', encoding='utf-8') as readme_file:
    long_description = readme_file.read()

# NOTE: pyproject.toml's [project] table declares the same metadata and, on
# setuptools>=61, takes precedence over these arguments — keep both in sync.
setup(
    name='jax-mlx-plugin',
    version='0.0.1',
    author='Thomas Summe',
    description='JAX PJRT plugin for Apple Silicon using MLX',
    long_description=long_description,
    long_description_content_type='text/markdown',
    url='https://github.com/tsumme1/jax-mlx',
    packages=find_packages(where='src'),
    package_dir={'': 'src'},
    ext_modules=[CMakeExtension('jax_mlx.mlx_pjrt_plugin')],
    cmdclass=dict(build_ext=CMakeBuild),
    package_data={'jax_mlx': ['*.dylib', '*.so']},
    include_package_data=True,
    zip_safe=False,
    python_requires='>=3.11',
    install_requires=[
        'jax>=0.5.0',
        'jaxlib>=0.5.0',
        'mlx',
    ],
    entry_points={
        "jax_plugins": [
            "mlx_plugin = jax_mlx.plugin",
        ],
    },
    classifiers=[
        'Development Status :: 3 - Alpha',
        'Intended Audience :: Developers',
        'Intended Audience :: Science/Research',
        'License :: OSI Approved :: MIT License',
        'Operating System :: MacOS :: MacOS X',
        'Programming Language :: Python :: 3',
        'Programming Language :: Python :: 3.11',
        'Programming Language :: Python :: 3.12',
        'Programming Language :: Python :: 3.13',
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
    ],
)
@@ -0,0 +1,6 @@
1
+ # jax-mlx: JAX PJRT plugin for Apple Silicon using MLX
2
+ #
3
+ # This package provides MLX acceleration for JAX via the PJRT plugin interface.
4
+ # Registration is handled automatically by plugin.py via JAX's plugin mechanism.
5
+
6
+ __version__ = "0.0.1"