mlx-cluster 0.0.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlx_cluster-0.0.1/CMakeLists.txt +66 -0
- mlx_cluster-0.0.1/LICENSE +21 -0
- mlx_cluster-0.0.1/MANIFEST.in +1 -0
- mlx_cluster-0.0.1/PKG-INFO +74 -0
- mlx_cluster-0.0.1/README.md +54 -0
- mlx_cluster-0.0.1/bindings.cpp +65 -0
- mlx_cluster-0.0.1/mlx_cluster/__init__.py +4 -0
- mlx_cluster-0.0.1/mlx_cluster/_ext.cpython-311-darwin.so +0 -0
- mlx_cluster-0.0.1/mlx_cluster/libmlx_cluster.dylib +0 -0
- mlx_cluster-0.0.1/mlx_cluster/mlx_cluster.metallib +0 -0
- mlx_cluster-0.0.1/mlx_cluster.egg-info/PKG-INFO +74 -0
- mlx_cluster-0.0.1/mlx_cluster.egg-info/SOURCES.txt +24 -0
- mlx_cluster-0.0.1/mlx_cluster.egg-info/dependency_links.txt +1 -0
- mlx_cluster-0.0.1/mlx_cluster.egg-info/not-zip-safe +1 -0
- mlx_cluster-0.0.1/mlx_cluster.egg-info/requires.txt +6 -0
- mlx_cluster-0.0.1/mlx_cluster.egg-info/top_level.txt +1 -0
- mlx_cluster-0.0.1/pyproject.toml +33 -0
- mlx_cluster-0.0.1/random_walks/BiasedRandomWalk.cpp +194 -0
- mlx_cluster-0.0.1/random_walks/BiasedRandomWalk.h +66 -0
- mlx_cluster-0.0.1/random_walks/RandomWalk.cpp +147 -0
- mlx_cluster-0.0.1/random_walks/RandomWalk.h +63 -0
- mlx_cluster-0.0.1/random_walks/random_walk.metal +35 -0
- mlx_cluster-0.0.1/setup.cfg +4 -0
- mlx_cluster-0.0.1/setup.py +16 -0
- mlx_cluster-0.0.1/tests/test_random_walk.py +58 -0
- mlx_cluster-0.0.1/tests/test_rejection_sampling.py +49 -0
mlx_cluster-0.0.1/CMakeLists.txt
@@ -0,0 +1,66 @@
cmake_minimum_required(VERSION 3.27)
project(_ext LANGUAGES CXX)

# ----- Setup required -----
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_POSITION_INDEPENDENT_CODE ON)

option(BUILD_SHARED_LIBS "Build extensions as a shared library" ON)

# ----- Dependencies required -----
find_package(fmt REQUIRED)
find_package(MLX CONFIG REQUIRED)
find_package(Python 3.8 COMPONENTS Interpreter Development.Module REQUIRED)
execute_process(
  COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir
  OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE NB_DIR)
list(APPEND CMAKE_PREFIX_PATH "${NB_DIR}")
find_package(nanobind CONFIG REQUIRED)

# ----- Adding extensions to the library -----

# Add library
add_library(mlx_cluster)

target_sources(mlx_cluster
  PUBLIC
  ${CMAKE_CURRENT_LIST_DIR}/random_walks/RandomWalk.cpp
  ${CMAKE_CURRENT_LIST_DIR}/random_walks/BiasedRandomWalk.cpp
)

target_include_directories(mlx_cluster
  PUBLIC
  ${CMAKE_CURRENT_LIST_DIR})

target_link_libraries(mlx_cluster PUBLIC mlx)

if(MLX_BUILD_METAL)
  mlx_build_metallib(
    TARGET mlx_cluster_metallib
    TITLE mlx_cluster
    SOURCES ${CMAKE_CURRENT_LIST_DIR}/random_walks/random_walk.metal
    INCLUDE_DIRS ${PROJECT_SOURCE_DIR} ${MLX_INCLUDE_DIRS}
    OUTPUT_DIRECTORY ${CMAKE_LIBRARY_OUTPUT_DIRECTORY}
  )

  add_dependencies(
    mlx_cluster
    mlx_cluster_metallib
  )

endif()

# ----- Nanobind module -----
nanobind_add_module(
  _ext
  NB_STATIC STABLE_ABI LTO NOMINSIZE
  NB_DOMAIN mlx
  ${CMAKE_CURRENT_LIST_DIR}/bindings.cpp
)

target_link_libraries(_ext PRIVATE mlx_cluster)

if(BUILD_SHARED_LIBS)
  target_link_options(_ext PRIVATE -Wl,-rpath,@loader_path)
endif()
mlx_cluster-0.0.1/LICENSE
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2024 vinayhpandya

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
mlx_cluster-0.0.1/MANIFEST.in
@@ -0,0 +1 @@
include CMakeLists.txt README.md bindings.cpp LICENSE random_walks/*.*
mlx_cluster-0.0.1/PKG-INFO
@@ -0,0 +1,74 @@
Metadata-Version: 2.1
Name: mlx_cluster
Version: 0.0.1
Summary: C++ and Metal extensions for MLX random walks
Author-email: Vinay Pandya <vinayharshadpandya27@gmail.com>
Project-URL: Homepage, https://github.com/vinayhpandya/mlx_cluster
Project-URL: Issues, https://github.com/vinayhpandya/mlx_cluster/Issues
Classifier: Development Status :: 3 - Alpha
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: C++
Classifier: License :: OSI Approved :: MIT License
Classifier: Operating System :: OS Independent
Requires-Python: >=3.8
Description-Content-Type: text/markdown
License-File: LICENSE
Provides-Extra: dev
Provides-Extra: test
Requires-Dist: torch_geometric; extra == "test"
Requires-Dist: pytest; extra == "test"

# mlx_cluster

A C++ extension for generating random walks on homogeneous graphs using MLX.

## Installation

To install the necessary dependencies:

Clone the repository:
```bash
git clone https://github.com/vinayhpandya/mlx_cluster.git
```

After cloning the repository, install the library with

```bash
python setup.py build_ext -j8 --inplace
```

For testing purposes you need to have `mlx-graphs` installed.

## Usage

```python
import time

import mlx.core as mx
import numpy as np

from mlx_graphs.datasets import PlanetoidDataset
from mlx_graphs.utils.sorting import sort_edge_index
from mlx_cluster import random_walk

cora_dataset = PlanetoidDataset(name="cora", base_dir="~")
start = mx.arange(0, 1000)
start_time = time.time()
edge_index = cora_dataset.graphs[0].edge_index
num_nodes = cora_dataset.graphs[0].num_nodes
sorted_edge_index = sort_edge_index(edge_index=edge_index)
row_mlx = sorted_edge_index[0][0]
col_mlx = sorted_edge_index[0][1]
_, counts_mlx = np.unique(np.array(row_mlx, copy=False), return_counts=True)
cum_sum_mlx = counts_mlx.cumsum()
row_ptr_mlx = mx.concatenate([mx.array([0]), mx.array(cum_sum_mlx)])
# One uniform draw per step: rand must have shape [num_walks, walk_length].
walk_length = 100
rand = mx.random.uniform(shape=[start.shape[0], walk_length])
random_walk(row_ptr_mlx, col_mlx, start, rand, walk_length, stream=mx.gpu)
```

## TODO

- [x] Add metal shaders to optimize the code
- [ ] Benchmark random walk against different frameworks
- [ ] Add more algorithms

## Credits

torch_cluster random walk implementation: [random_walk](https://github.com/rusty1s/pytorch_cluster/blob/master/csrc/cpu/rw_cpu.cpp)
mlx_cluster-0.0.1/README.md
@@ -0,0 +1,54 @@
# mlx_cluster

A C++ extension for generating random walks on homogeneous graphs using MLX.

## Installation

To install the necessary dependencies:

Clone the repository:
```bash
git clone https://github.com/vinayhpandya/mlx_cluster.git
```

After cloning the repository, install the library with

```bash
python setup.py build_ext -j8 --inplace
```

For testing purposes you need to have `mlx-graphs` installed.

## Usage

```python
import time

import mlx.core as mx
import numpy as np

from mlx_graphs.datasets import PlanetoidDataset
from mlx_graphs.utils.sorting import sort_edge_index
from mlx_cluster import random_walk

cora_dataset = PlanetoidDataset(name="cora", base_dir="~")
start = mx.arange(0, 1000)
start_time = time.time()
edge_index = cora_dataset.graphs[0].edge_index
num_nodes = cora_dataset.graphs[0].num_nodes
sorted_edge_index = sort_edge_index(edge_index=edge_index)
row_mlx = sorted_edge_index[0][0]
col_mlx = sorted_edge_index[0][1]
_, counts_mlx = np.unique(np.array(row_mlx, copy=False), return_counts=True)
cum_sum_mlx = counts_mlx.cumsum()
row_ptr_mlx = mx.concatenate([mx.array([0]), mx.array(cum_sum_mlx)])
# One uniform draw per step: rand must have shape [num_walks, walk_length].
walk_length = 100
rand = mx.random.uniform(shape=[start.shape[0], walk_length])
random_walk(row_ptr_mlx, col_mlx, start, rand, walk_length, stream=mx.gpu)
```

## TODO

- [x] Add metal shaders to optimize the code
- [ ] Benchmark random walk against different frameworks
- [ ] Add more algorithms

## Credits

torch_cluster random walk implementation: [random_walk](https://github.com/rusty1s/pytorch_cluster/blob/master/csrc/cpu/rw_cpu.cpp)
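The usage snippet above converts a sorted `edge_index` into a CSR representation: `row_ptr_mlx` holds, for each node, the offset of its first outgoing edge, and `col_mlx` holds the destination of every edge. A minimal NumPy sketch of that conversion on a toy graph (illustrative only; `np.bincount` is used instead of the README's `np.unique` so that nodes with no outgoing edges still get an entry):

```python
import numpy as np

# Toy graph: 4 nodes, 5 directed edges, sources sorted ascending.
# edge_index[0] = source nodes, edge_index[1] = destination nodes.
edge_index = np.array([[0, 0, 1, 2, 3],
                       [1, 2, 3, 3, 0]])

num_nodes = 4
counts = np.bincount(edge_index[0], minlength=num_nodes)  # out-degree per node
rowptr = np.concatenate([[0], counts.cumsum()])           # CSR row pointers
col = edge_index[1]                                       # CSR column indices

assert rowptr.tolist() == [0, 2, 3, 4, 5]
# The neighbors of node v are col[rowptr[v]:rowptr[v + 1]].
print(col[rowptr[0]:rowptr[1]])  # -> [1 2]
```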
mlx_cluster-0.0.1/bindings.cpp
@@ -0,0 +1,65 @@
#include <nanobind/nanobind.h>
#include <nanobind/stl/variant.h>
#include <random_walks/RandomWalk.h>
#include <random_walks/BiasedRandomWalk.h>

namespace nb = nanobind;
using namespace nb::literals;
using namespace mlx::core;

NB_MODULE(_ext, m) {
  m.def(
      "random_walk",
      &random_walk,
      "rowptr"_a,
      "col"_a,
      "start"_a,
      "rand"_a,
      "walk_length"_a,
      nb::kw_only(),
      "stream"_a = nb::none(),
      R"(
        Uniformly sample random walks over a graph.

        Args:
            rowptr (array): row pointers of the graph in CSR format.
            col (array): column indices (edges) in CSR format.
            start (array): starting node of each walk.
            rand (array): uniform random numbers, one per (walk, step).
            walk_length (int): length of each random walk.

        Returns:
            array: the nodes visited on each random walk.
      )");

  m.def(
      "rejection_sampling",
      &rejection_sampling,
      "rowptr"_a,
      "col"_a,
      "start"_a,
      "walk_length"_a,
      "p"_a,
      "q"_a,
      nb::kw_only(),
      "stream"_a = nb::none(),
      R"(
        Sample nodes from the graph by sampling neighbors based
        on probabilities derived from p and q.

        Args:
            rowptr (array): row pointers of the graph in CSR format.
            col (array): column indices (edges) in CSR format.
            start (array): starting node of each walk, from which
                biased sampling is performed.
            walk_length (int): length of each random walk.
            p: likelihood of immediately revisiting a node in the walk.
            q: control parameter to interpolate between
                breadth-first and depth-first strategies.

        Returns:
            array: the nodes visited on each random walk.
      )");
}
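Once built, both bound functions are callable from Python. A minimal sketch on a hand-built CSR triangle graph, assuming the package imports as `mlx_cluster` (per setup.py); `random_walk`'s GPU stream needs the Metal backend on Apple silicon, and the biased walk only has a CPU path:

```python
import mlx.core as mx
from mlx_cluster import random_walk, rejection_sampling

# Triangle graph 0-1-2 in CSR form; the C++ primitives read int64 buffers.
rowptr = mx.array([0, 2, 4, 6], dtype=mx.int64)
col = mx.array([1, 2, 0, 2, 0, 1], dtype=mx.int64)
start = mx.array([0, 1, 2], dtype=mx.int64)

walk_length = 4
# One uniform draw per (walk, step) for the unbiased walk.
rand = mx.random.uniform(shape=[start.shape[0], walk_length])

nodes = random_walk(rowptr, col, start, rand, walk_length, stream=mx.gpu)
# rejection_sampling draws its own randomness and must run on the CPU stream.
biased = rejection_sampling(rowptr, col, start, walk_length, 1.0, 2.0, stream=mx.cpu)
print(nodes.shape, biased.shape)  # (3, walk_length + 1) each
```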
mlx_cluster-0.0.1/mlx_cluster/_ext.cpython-311-darwin.so
Binary file

mlx_cluster-0.0.1/mlx_cluster/libmlx_cluster.dylib
Binary file

mlx_cluster-0.0.1/mlx_cluster/mlx_cluster.metallib
Binary file
mlx_cluster-0.0.1/mlx_cluster.egg-info/PKG-INFO
@@ -0,0 +1,74 @@
Metadata-Version: 2.1
Name: mlx_cluster
Version: 0.0.1
Summary: C++ and Metal extensions for MLX random walks
Author-email: Vinay Pandya <vinayharshadpandya27@gmail.com>
Project-URL: Homepage, https://github.com/vinayhpandya/mlx_cluster
Project-URL: Issues, https://github.com/vinayhpandya/mlx_cluster/Issues
Classifier: Development Status :: 3 - Alpha
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: C++
Classifier: License :: OSI Approved :: MIT License
Classifier: Operating System :: OS Independent
Requires-Python: >=3.8
Description-Content-Type: text/markdown
License-File: LICENSE
Provides-Extra: dev
Provides-Extra: test
Requires-Dist: torch_geometric; extra == "test"
Requires-Dist: pytest; extra == "test"

# mlx_cluster

A C++ extension for generating random walks on homogeneous graphs using MLX.

## Installation

To install the necessary dependencies:

Clone the repository:
```bash
git clone https://github.com/vinayhpandya/mlx_cluster.git
```

After cloning the repository, install the library with

```bash
python setup.py build_ext -j8 --inplace
```

For testing purposes you need to have `mlx-graphs` installed.

## Usage

```python
import time

import mlx.core as mx
import numpy as np

from mlx_graphs.datasets import PlanetoidDataset
from mlx_graphs.utils.sorting import sort_edge_index
from mlx_cluster import random_walk

cora_dataset = PlanetoidDataset(name="cora", base_dir="~")
start = mx.arange(0, 1000)
start_time = time.time()
edge_index = cora_dataset.graphs[0].edge_index
num_nodes = cora_dataset.graphs[0].num_nodes
sorted_edge_index = sort_edge_index(edge_index=edge_index)
row_mlx = sorted_edge_index[0][0]
col_mlx = sorted_edge_index[0][1]
_, counts_mlx = np.unique(np.array(row_mlx, copy=False), return_counts=True)
cum_sum_mlx = counts_mlx.cumsum()
row_ptr_mlx = mx.concatenate([mx.array([0]), mx.array(cum_sum_mlx)])
# One uniform draw per step: rand must have shape [num_walks, walk_length].
walk_length = 100
rand = mx.random.uniform(shape=[start.shape[0], walk_length])
random_walk(row_ptr_mlx, col_mlx, start, rand, walk_length, stream=mx.gpu)
```

## TODO

- [x] Add metal shaders to optimize the code
- [ ] Benchmark random walk against different frameworks
- [ ] Add more algorithms

## Credits

torch_cluster random walk implementation: [random_walk](https://github.com/rusty1s/pytorch_cluster/blob/master/csrc/cpu/rw_cpu.cpp)
mlx_cluster-0.0.1/mlx_cluster.egg-info/SOURCES.txt
@@ -0,0 +1,24 @@
CMakeLists.txt
LICENSE
MANIFEST.in
README.md
bindings.cpp
pyproject.toml
setup.py
mlx_cluster/__init__.py
mlx_cluster/_ext.cpython-311-darwin.so
mlx_cluster/libmlx_cluster.dylib
mlx_cluster/mlx_cluster.metallib
mlx_cluster.egg-info/PKG-INFO
mlx_cluster.egg-info/SOURCES.txt
mlx_cluster.egg-info/dependency_links.txt
mlx_cluster.egg-info/not-zip-safe
mlx_cluster.egg-info/requires.txt
mlx_cluster.egg-info/top_level.txt
random_walks/BiasedRandomWalk.cpp
random_walks/BiasedRandomWalk.h
random_walks/RandomWalk.cpp
random_walks/RandomWalk.h
random_walks/random_walk.metal
tests/test_random_walk.py
tests/test_rejection_sampling.py
mlx_cluster-0.0.1/mlx_cluster.egg-info/dependency_links.txt
@@ -0,0 +1 @@

mlx_cluster-0.0.1/mlx_cluster.egg-info/not-zip-safe
@@ -0,0 +1 @@

mlx_cluster-0.0.1/mlx_cluster.egg-info/top_level.txt
@@ -0,0 +1 @@
mlx_cluster
mlx_cluster-0.0.1/pyproject.toml
@@ -0,0 +1,33 @@
[project]
name = "mlx_cluster"
version = "0.0.1"
authors = [
  { name = "Vinay Pandya", email = "vinayharshadpandya27@gmail.com" },
]
description = "C++ and Metal extensions for MLX random walks"
readme = "README.md"
requires-python = ">=3.8"
classifiers = [
  "Development Status :: 3 - Alpha",
  "Programming Language :: Python :: 3",
  "Programming Language :: C++",
  "License :: OSI Approved :: MIT License",
  "Operating System :: OS Independent",
]

[project.optional-dependencies]
dev = []

[project.urls]
Homepage = "https://github.com/vinayhpandya/mlx_cluster"
Issues = "https://github.com/vinayhpandya/mlx_cluster/Issues"

[build-system]
requires = [
  "setuptools>=42",
  "cmake>=3.24",
  "mlx>=0.9.0",
  "nanobind@git+https://github.com/wjakob/nanobind.git@2f04eac452a6d9142dedb957701bdb20125561e4",
]
build-backend = "setuptools.build_meta"
mlx_cluster-0.0.1/random_walks/BiasedRandomWalk.cpp
@@ -0,0 +1,194 @@
#include <cassert>
#include <cmath>
#include <iostream>
#include <sstream>
#include <random>

#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/utils.h"
#include "mlx/utils.h"
#include "mlx/random.h"
#include "mlx/ops.h"
#include "mlx/array.h"
#ifdef _METAL_
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/utils.h"
#endif
#include "random_walks/BiasedRandomWalk.h"

namespace mlx::core {

// Returns true if w is a neighbor of v in the CSR graph.
bool inline is_neighbor(const int64_t* rowptr, const int64_t* col, int64_t v,
                        int64_t w) {
  int64_t row_start = rowptr[v], row_end = rowptr[v + 1];
  for (auto i = row_start; i < row_end; i++) {
    if (col[i] == w)
      return true;
  }
  return false;
}

void BiasedRandomWalk::eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs) {
  // Only three inputs are passed by rejection_sampling (rowptr, col, start);
  // randomness comes from std::rand() below.
  auto& rowptr = inputs[0];
  auto& col = inputs[1];
  auto& start = inputs[2];
  int numel = start.size();

  // Initialize outputs
  assert(outputs.size() == 2);
  // Allocate memory for outputs if not already allocated
  outputs[0].set_data(allocator::malloc_or_wait(numel * (walk_length_ + 1) * sizeof(int64_t)));
  outputs[1].set_data(allocator::malloc_or_wait(numel * walk_length_ * sizeof(int64_t)));
  auto& n_out = outputs[0];
  auto& e_out = outputs[1];

  auto* n_out_ptr = n_out.data<int64_t>();
  auto* e_out_ptr = e_out.data<int64_t>();
  auto* start_values = start.data<int64_t>();
  auto* row_ptr = rowptr.data<int64_t>();
  auto* col_values = col.data<int64_t>();

  // node2vec acceptance probabilities, normalized by the largest weight.
  double max_prob = fmax(fmax(1. / p_, 1.), 1. / q_);
  double prob_0 = 1. / p_ / max_prob;  // candidate revisits the previous node
  double prob_1 = 1. / max_prob;       // candidate neighbors the previous node
  double prob_2 = 1. / q_ / max_prob;  // candidate moves further away

  for (int64_t n = 0; n < numel; n++) {
    int64_t t = start_values[n], v, x, e_cur, row_start, row_end;
    n_out_ptr[n * (walk_length_ + 1)] = t;
    row_start = row_ptr[t], row_end = row_ptr[t + 1];
    if (row_end - row_start == 0) {
      e_cur = -1;
      v = t;
    } else {
      e_cur = row_start + (std::rand() % (row_end - row_start));
      v = col_values[e_cur];
    }
    n_out_ptr[n * (walk_length_ + 1) + 1] = v;
    e_out_ptr[n * walk_length_] = e_cur;
    for (auto l = 1; l < walk_length_; l++) {
      row_start = row_ptr[v], row_end = row_ptr[v + 1];

      if (row_end - row_start == 0) {
        e_cur = -1;
        x = v;
      } else if (row_end - row_start == 1) {
        e_cur = row_start;
        x = col_values[e_cur];
      } else {
        while (true) {
          e_cur = row_start + (std::rand() % (row_end - row_start));
          x = col_values[e_cur];

          auto r = ((double)std::rand() / (RAND_MAX)); // [0, 1)

          if (x == t && r < prob_0)
            break;
          else if (is_neighbor(row_ptr, col_values, x, t) && r < prob_1)
            break;
          else if (r < prob_2)
            break;
        }
      }

      n_out_ptr[n * (walk_length_ + 1) + (l + 1)] = x;
      e_out_ptr[n * walk_length_ + l] = e_cur;
      t = v;
      v = x;
    }
  }
}

std::vector<array> BiasedRandomWalk::jvp(
    const std::vector<array>& primals,
    const std::vector<array>& tangents,
    const std::vector<int>& argnums)
{
  // Random walks are not differentiable.
  throw std::runtime_error("Biased random walk has no JVP implementation.");
}
// Disabled Metal path, kept for reference.
// #ifdef _METAL_
// void BiasedRandomWalk::eval_gpu(
//     const std::vector<array>& inputs,
//     std::vector<array>& outputs
// ){
//   auto& rowptr = inputs[0];
//   auto& col = inputs[1];
//   auto& start = inputs[2];
//   auto& rand = inputs[3];
//   int numel = start.size();
//
//   assert(outputs.size() == 2);
//   outputs[0].set_data(allocator::malloc_or_wait(numel * (walk_length_ + 1) * sizeof(int64_t)));
//   outputs[1].set_data(allocator::malloc_or_wait(numel * walk_length_ * sizeof(int64_t)));
//   auto& s = stream();
//   auto& d = metal::device(s.device);
//
//   d.register_library("mlx_cluster", metal::get_colocated_mtllib_path);
//   auto kernel = d.get_kernel("random_walk", "mlx_cluster");
//
//   auto& compute_encoder = d.get_command_encoder(s.index);
//   compute_encoder->setComputePipelineState(kernel);
//
//   compute_encoder.set_input_array(rowptr, 0);
//   compute_encoder.set_input_array(col, 1);
//   compute_encoder.set_input_array(start, 2);
//   compute_encoder.set_input_array(rand, 3);
//   compute_encoder.set_output_array(outputs[0], 4);
//   compute_encoder.set_output_array(outputs[1], 5);
//   compute_encoder->setBytes(&walk_length_, sizeof(int), 6);
//
//   MTL::Size grid_size = MTL::Size(numel, 1, 1);
//   MTL::Size thread_group_size = MTL::Size(kernel->maxTotalThreadsPerThreadgroup(), 1, 1);
//
//   compute_encoder.dispatchThreads(grid_size, thread_group_size);
// }
// #endif
void BiasedRandomWalk::eval_gpu(
    const std::vector<array>& inputs, std::vector<array>& outputs
)
{
  throw std::runtime_error("Biased random walk has no GPU implementation.");
}
std::vector<array> BiasedRandomWalk::vjp(
    const std::vector<array>& primals,
    const std::vector<array>& cotangents,
    const std::vector<int>& argnums,
    const std::vector<array>& outputs)
{
  // Random walks are not differentiable.
  throw std::runtime_error("Biased random walk has no VJP implementation.");
}

std::pair<std::vector<array>, std::vector<int>> BiasedRandomWalk::vmap(
    const std::vector<array>& inputs,
    const std::vector<int>& axes)
{
  throw std::runtime_error("vmap not implemented for BiasedRandomWalk.");
}

bool BiasedRandomWalk::is_equivalent(const Primitive& other) const
{
  throw std::runtime_error("is_equivalent not implemented for BiasedRandomWalk.");
}

std::vector<std::vector<int>> BiasedRandomWalk::output_shapes(const std::vector<array>& inputs)
{
  throw std::runtime_error("output_shapes not implemented for BiasedRandomWalk.");
}

array rejection_sampling(const array& rowptr, const array& col, const array& start, int walk_length, const double p,
                         const double q, StreamOrDevice s)
{
  int nodes = start.size();
  auto primitive = std::make_shared<BiasedRandomWalk>(to_stream(s), walk_length, p, q);
  return array::make_arrays({{nodes, walk_length + 1}, {nodes, walk_length}},
                            {rowptr.dtype(), rowptr.dtype()},
                            primitive,
                            {rowptr, col, start})[0];
}
} // namespace mlx::core
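The rejection loop in `eval_cpu` above implements node2vec-style sampling: a candidate neighbor `x` is accepted with probability proportional to `1/p` if it returns to the previous node `t`, `1` if it is a neighbor of `t`, and `1/q` otherwise, each normalized by the largest of the three weights. A NumPy sketch of the same acceptance rule (illustrative only, not part of the package):

```python
import numpy as np

def accept(prev, cand, cand_is_neighbor_of_prev, p, q, rng=np.random.default_rng()):
    """Mirror one acceptance test of the C++ rejection loop."""
    max_prob = max(1.0 / p, 1.0, 1.0 / q)
    r = rng.random()  # uniform in [0, 1)
    if cand == prev:                       # distance 0: weight 1/p
        return r < (1.0 / p) / max_prob
    if cand_is_neighbor_of_prev:           # distance 1: weight 1
        return r < 1.0 / max_prob
    return r < (1.0 / q) / max_prob        # distance 2: weight 1/q
```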
mlx_cluster-0.0.1/random_walks/BiasedRandomWalk.h
@@ -0,0 +1,66 @@
#pragma once

#include <mlx/array.h>
#include <mlx/ops.h>
#include <mlx/primitives.h>

namespace mlx::core {

class BiasedRandomWalk : public Primitive {
 public:
  BiasedRandomWalk(Stream stream, int walk_length, double p, double q)
      : Primitive(stream), walk_length_(walk_length), p_(p), q_(q) {}
  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
      override;
  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
      override;

  /** The Jacobian-vector product. */
  std::vector<array> jvp(
      const std::vector<array>& primals,
      const std::vector<array>& tangents,
      const std::vector<int>& argnums) override;

  /** The vector-Jacobian product. */
  std::vector<array> vjp(
      const std::vector<array>& primals,
      const std::vector<array>& cotangents,
      const std::vector<int>& argnums,
      const std::vector<array>& outputs) override;

  /**
   * The primitive must know how to vectorize itself across
   * the given axes. The output is a pair containing the array
   * representing the vectorized computation and the axis which
   * corresponds to the output vectorized dimension.
   */
  std::pair<std::vector<array>, std::vector<int>> vmap(
      const std::vector<array>& inputs,
      const std::vector<int>& axes) override;

  /** Print the primitive. */
  void print(std::ostream& os) override {
    os << "biased random walk implementation";
  }

  /** Equivalence check. **/
  bool is_equivalent(const Primitive& other) const override;

  std::vector<std::vector<int>> output_shapes(const std::vector<array>& inputs) override;

 private:
  int walk_length_;
  double p_;
  double q_;
};

array rejection_sampling(const array& rowptr,
                         const array& col,
                         const array& start,
                         int walk_length,
                         const double p,
                         const double q,
                         StreamOrDevice s = {});

} // namespace mlx::core
mlx_cluster-0.0.1/random_walks/RandomWalk.cpp
@@ -0,0 +1,147 @@
#include <cassert>
#include <iostream>
#include <sstream>

#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/utils.h"
#include "mlx/utils.h"
#include "mlx/random.h"
#include "mlx/ops.h"
#include "mlx/array.h"
#ifdef _METAL_
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/utils.h"
#endif
#include "random_walks/RandomWalk.h"

namespace mlx::core {
void RandomWalk::eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs) {
  auto& rowptr = inputs[0];
  auto& col = inputs[1];
  auto& start = inputs[2];
  auto& rand = inputs[3];
  int numel = start.size();

  // Initialize outputs
  assert(outputs.size() == 2);
  // Allocate memory for outputs if not already allocated
  outputs[0].set_data(allocator::malloc_or_wait(numel * (walk_length_ + 1) * sizeof(int64_t)));
  outputs[1].set_data(allocator::malloc_or_wait(numel * walk_length_ * sizeof(int64_t)));
  auto& n_out = outputs[0];
  auto& e_out = outputs[1];

  auto* n_out_ptr = n_out.data<int64_t>();
  auto* e_out_ptr = e_out.data<int64_t>();
  auto* start_values = start.data<int64_t>();
  auto* row_ptr = rowptr.data<int64_t>();
  auto* col_values = col.data<int64_t>();
  auto* rand_values = rand.data<float>();

  for (int64_t n = 0; n < numel; n++) {
    int64_t n_cur = start_values[n];
    n_out_ptr[n * (walk_length_ + 1)] = n_cur;
    for (int l = 0; l < walk_length_; l++) {
      int64_t row_start = row_ptr[n_cur];
      int64_t row_end = row_ptr[n_cur + 1];
      int64_t e_cur;
      if (row_end - row_start == 0) {
        e_cur = -1;  // isolated node: the walk stays in place
      } else {
        float r = rand_values[n * walk_length_ + l];
        int64_t idx = static_cast<int64_t>(r * (row_end - row_start));
        e_cur = row_start + idx;
        n_cur = col_values[e_cur];
      }

      n_out_ptr[n * (walk_length_ + 1) + (l + 1)] = n_cur;
      e_out_ptr[n * walk_length_ + l] = e_cur;
    }
  }
}

std::vector<array> RandomWalk::jvp(
    const std::vector<array>& primals,
    const std::vector<array>& tangents,
    const std::vector<int>& argnums)
{
  // Random walks are not differentiable.
  throw std::runtime_error("Random walk has no JVP implementation.");
}
#ifdef _METAL_
void RandomWalk::eval_gpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs
) {
  auto& rowptr = inputs[0];
  auto& col = inputs[1];
  auto& start = inputs[2];
  auto& rand = inputs[3];
  int numel = start.size();

  assert(outputs.size() == 2);
  outputs[0].set_data(allocator::malloc_or_wait(numel * (walk_length_ + 1) * sizeof(int64_t)));
  outputs[1].set_data(allocator::malloc_or_wait(numel * walk_length_ * sizeof(int64_t)));
  auto& s = stream();
  auto& d = metal::device(s.device);

  d.register_library("mlx_cluster");
  auto kernel = d.get_kernel("random_walk", "mlx_cluster");

  auto& compute_encoder = d.get_command_encoder(s.index);
  compute_encoder->setComputePipelineState(kernel);

  compute_encoder.set_input_array(rowptr, 0);
  compute_encoder.set_input_array(col, 1);
  compute_encoder.set_input_array(start, 2);
  compute_encoder.set_input_array(rand, 3);
  compute_encoder.set_output_array(outputs[0], 4);
  compute_encoder.set_output_array(outputs[1], 5);
  compute_encoder->setBytes(&walk_length_, sizeof(int), 6);

  MTL::Size grid_size = MTL::Size(numel, 1, 1);
  MTL::Size thread_group_size = MTL::Size(kernel->maxTotalThreadsPerThreadgroup(), 1, 1);

  compute_encoder.dispatchThreads(grid_size, thread_group_size);
}
#endif

std::vector<array> RandomWalk::vjp(
    const std::vector<array>& primals,
    const std::vector<array>& cotangents,
    const std::vector<int>& argnums,
    const std::vector<array>& outputs)
{
  // Random walks are not differentiable.
  throw std::runtime_error("Random walk has no VJP implementation.");
}

std::pair<std::vector<array>, std::vector<int>> RandomWalk::vmap(
    const std::vector<array>& inputs,
    const std::vector<int>& axes)
{
  throw std::runtime_error("vmap not implemented for RandomWalk.");
}

bool RandomWalk::is_equivalent(const Primitive& other) const
{
  throw std::runtime_error("is_equivalent not implemented for RandomWalk.");
}

std::vector<std::vector<int>> RandomWalk::output_shapes(const std::vector<array>& inputs)
{
  throw std::runtime_error("output_shapes not implemented for RandomWalk.");
}

array random_walk(const array& rowptr, const array& col, const array& start, const array& rand, int walk_length, StreamOrDevice s)
{
  int nodes = start.size();
  auto primitive = std::make_shared<RandomWalk>(walk_length, to_stream(s));
  return array::make_arrays({{nodes, walk_length + 1}, {nodes, walk_length}},
                            {start.dtype(), start.dtype()},
                            primitive,
                            {rowptr, col, start, rand})[0];
}
} // namespace mlx::core
mlx_cluster-0.0.1/random_walks/RandomWalk.h
@@ -0,0 +1,63 @@
#pragma once

#include <mlx/array.h>
#include <mlx/ops.h>
#include <mlx/primitives.h>

namespace mlx::core {

class RandomWalk : public Primitive {
 public:
  explicit RandomWalk(int walk_length, Stream stream)
      : Primitive(stream), walk_length_(walk_length) {}
  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
      override;
  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
      override;

  /** The Jacobian-vector product. */
  std::vector<array> jvp(
      const std::vector<array>& primals,
      const std::vector<array>& tangents,
      const std::vector<int>& argnums) override;

  /** The vector-Jacobian product. */
  std::vector<array> vjp(
      const std::vector<array>& primals,
      const std::vector<array>& cotangents,
      const std::vector<int>& argnums,
      const std::vector<array>& outputs) override;

  /**
   * The primitive must know how to vectorize itself across
   * the given axes. The output is a pair containing the array
   * representing the vectorized computation and the axis which
   * corresponds to the output vectorized dimension.
   */
  std::pair<std::vector<array>, std::vector<int>> vmap(
      const std::vector<array>& inputs,
      const std::vector<int>& axes) override;

  /** Print the primitive. */
  void print(std::ostream& os) override {
    os << "Random walk implementation";
  }

  /** Equivalence check. **/
  bool is_equivalent(const Primitive& other) const override;

  std::vector<std::vector<int>> output_shapes(const std::vector<array>& inputs) override;

 private:
  int walk_length_;
};

array random_walk(const array& rowptr,
                  const array& col,
                  const array& start,
                  const array& rand,
                  int walk_length,
                  StreamOrDevice s = {});

} // namespace mlx::core
mlx_cluster-0.0.1/random_walks/random_walk.metal
@@ -0,0 +1,35 @@
#include <metal_stdlib>
#include "mlx/backend/metal/kernels/utils.h"
using namespace metal;

// One thread per walk: thread tid advances start[tid] for walk_length steps.
kernel void random_walk(
    const device int64_t* rowptr [[buffer(0)]],
    const device int64_t* col [[buffer(1)]],
    const device int64_t* start [[buffer(2)]],
    const device float* rand [[buffer(3)]],
    device int64_t* n_out [[buffer(4)]],
    device int64_t* e_out [[buffer(5)]],
    constant int& walk_length [[buffer(6)]],
    uint tid [[thread_position_in_grid]]
) {
  int64_t n_cur = start[tid];
  n_out[tid * (walk_length + 1)] = n_cur;

  for (int l = 0; l < walk_length; l++) {
    int64_t row_start = rowptr[n_cur];
    int64_t row_end = rowptr[n_cur + 1];
    int64_t e_cur;

    if (row_end - row_start == 0) {
      e_cur = -1;  // isolated node: the walk stays in place
    } else {
      float r = rand[tid * walk_length + l];
      int64_t idx = static_cast<int64_t>(r * (row_end - row_start));
      e_cur = row_start + idx;
      n_cur = col[e_cur];
    }

    n_out[tid * (walk_length + 1) + (l + 1)] = n_cur;
    e_out[tid * walk_length + l] = e_cur;
  }
}
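The kernel assigns one thread per walk: thread `tid` starts at `start[tid]`, and step `l` consumes `rand[tid * walk_length + l]`, mapping it to a neighbor via `floor(r * degree)`. A NumPy reference of the same per-thread logic, handy for checking kernel output on small graphs (illustrative; it reproduces only the node sequence, not `e_out`):

```python
import numpy as np

def walk_reference(rowptr, col, start, rand, walk_length):
    """CPU mirror of the Metal kernel's per-thread loop."""
    n_out = np.empty((len(start), walk_length + 1), dtype=np.int64)
    for tid, n_cur in enumerate(start):
        n_out[tid, 0] = n_cur
        for l in range(walk_length):
            row_start, row_end = rowptr[n_cur], rowptr[n_cur + 1]
            if row_end > row_start:  # pick neighbor floor(r * degree)
                r = rand[tid, l]     # same draw the kernel reads, laid out 2D
                n_cur = col[row_start + int(r * (row_end - row_start))]
            n_out[tid, l + 1] = n_cur
    return n_out
```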
mlx_cluster-0.0.1/setup.py
@@ -0,0 +1,16 @@
from setuptools import setup
from mlx import extension

if __name__ == "__main__":
    setup(
        name="mlx_cluster",
        version="0.0.1",
        description="Sample C++ and Metal extensions for MLX primitives.",
        ext_modules=[extension.CMakeExtension("mlx_cluster._ext")],
        cmdclass={"build_ext": extension.CMakeBuild},
        packages=["mlx_cluster"],
        package_data={"mlx_cluster": ["*.so", "*.dylib", "*.metallib"]},
        extras_require={"dev": [], "test": ["torch_geometric", "pytest"]},
        zip_safe=False,
        python_requires=">=3.8",
    )
mlx_cluster-0.0.1/tests/test_random_walk.py
@@ -0,0 +1,58 @@
import mlx.core as mx
import numpy as np
import time

# Torch dataset
import torch
import torch_cluster  # noqa: F401 -- registers torch.ops.torch_cluster.random_walk
import torch_geometric.datasets as pyg_datasets
from torch_geometric.utils import sort_edge_index
from torch_geometric.utils.num_nodes import maybe_num_nodes
from torch_geometric.utils.sparse import index2ptr
from torch.utils.data import DataLoader

torch_planetoid = pyg_datasets.Planetoid(root="data/Cora", name="Cora")
edge_index_torch = torch_planetoid.edge_index
num_nodes = maybe_num_nodes(edge_index=edge_index_torch)
row, col = sort_edge_index(edge_index=edge_index_torch, num_nodes=num_nodes)
row_ptr, col = index2ptr(row, num_nodes), col
loader = DataLoader(range(2708), batch_size=2000)
start_indices = next(iter(loader))
print(edge_index_torch.dtype)
print(row_ptr.dtype)
print(col.dtype)
print(start_indices.dtype)
random_walks = torch.ops.torch_cluster.random_walk(
    row_ptr, col, start_indices, 5, 1.0, 1.0
)

# The mlx_graphs import below shadows torch_geometric's sort_edge_index above.
from mlx_graphs.datasets import PlanetoidDataset
from mlx_graphs.utils.sorting import sort_edge_index
from mlx_cluster import random_walk

cora_dataset = PlanetoidDataset(name="cora", base_dir="~")
edge_index = cora_dataset.graphs[0].edge_index.astype(mx.int64)
sorted_edge_index = sort_edge_index(edge_index=edge_index)
print(edge_index.dtype)
row_mlx = sorted_edge_index[0][0]
col_mlx = sorted_edge_index[0][1]
_, counts_mlx = np.unique(np.array(row_mlx, copy=False), return_counts=True)
cum_sum_mlx = counts_mlx.cumsum()
row_ptr_mlx = mx.concatenate([mx.array([0]), mx.array(cum_sum_mlx)])
start_indices = mx.array(start_indices.numpy())
print("Start indices data type is ", start_indices.dtype)
print("Col mlx data type is ", col_mlx.dtype)
print("Row mlx data type is ", row_ptr_mlx.dtype)
assert mx.array_equal(row_ptr_mlx, mx.array(row_ptr.numpy())), "Arrays not equal"
assert mx.array_equal(col_mlx, mx.array(col.numpy())), "Col arrays are not equal"
rand_data = mx.random.uniform(shape=[start_indices.shape[0], 5])
start_time = time.time()
node_sequence = random_walk(
    row_ptr_mlx, col_mlx, start_indices, rand_data, 5, stream=mx.gpu
)
# print("Time taken to complete 1000 random walks : ", time.time() - start_time)
print("Torch random walks are", random_walks[0])
print("MLX random walks are", node_sequence)
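The script above prints the torch_cluster and MLX walks side by side rather than asserting on them; exact equality is not expected, since the two implementations consume randomness differently. A hedged sketch of a check pytest could collect (names follow the script above; the bounds are assumptions about valid output, not part of the package):

```python
def test_random_walk_shapes():
    # One row per start node, walk_length + 1 visited nodes per row.
    assert node_sequence.shape == (start_indices.shape[0], 5 + 1)
    # Every visited node must be a valid node id.
    assert node_sequence.min().item() >= 0
    assert node_sequence.max().item() < num_nodes
```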
mlx_cluster-0.0.1/tests/test_rejection_sampling.py
@@ -0,0 +1,49 @@
import mlx.core as mx
import numpy as np
import time

# Torch dataset
import torch
import torch_geometric.datasets as pyg_datasets
from torch_geometric.utils import sort_edge_index
from torch_geometric.utils.num_nodes import maybe_num_nodes
from torch_geometric.utils.sparse import index2ptr
from torch.utils.data import DataLoader

torch_planetoid = pyg_datasets.Planetoid(root="data/Cora", name="Cora")
edge_index_torch = torch_planetoid.edge_index
num_nodes = maybe_num_nodes(edge_index=edge_index_torch)
row, col = sort_edge_index(edge_index=edge_index_torch, num_nodes=num_nodes)
row_ptr, col = index2ptr(row, num_nodes), col
loader = DataLoader(range(2708), batch_size=2000)
start_indices = next(iter(loader))
# random_walks = torch.ops.torch_cluster.random_walk(
#     row_ptr, col, start_indices, 5, 1.0, 3.0
# )

# The mlx_graphs import below shadows torch_geometric's sort_edge_index above.
from mlx_graphs.datasets import PlanetoidDataset
from mlx_graphs.utils.sorting import sort_edge_index
from mlx_cluster import rejection_sampling

cora_dataset = PlanetoidDataset(name="cora", base_dir="~")
edge_index = cora_dataset.graphs[0].edge_index.astype(mx.int64)
sorted_edge_index = sort_edge_index(edge_index=edge_index)
row_mlx = sorted_edge_index[0][0]
col_mlx = sorted_edge_index[0][1]
_, counts_mlx = np.unique(np.array(row_mlx, copy=False), return_counts=True)
cum_sum_mlx = counts_mlx.cumsum()
row_ptr_mlx = mx.concatenate([mx.array([0]), mx.array(cum_sum_mlx)])
start_indices = mx.array(start_indices.numpy())
print("row pointer datatype", row_ptr_mlx.dtype)
print("col datatype", col_mlx.dtype)
print("start pointer datatype", start_indices.dtype)
assert mx.array_equal(row_ptr_mlx, mx.array(row_ptr.numpy())), "Arrays not equal"
assert mx.array_equal(col_mlx, mx.array(col.numpy())), "Col arrays are not equal"
rand_data = mx.random.uniform(shape=[start_indices.shape[0], 5])  # unused: rejection_sampling draws its own randomness
start_time = time.time()
node_sequence = rejection_sampling(
    row_ptr_mlx, col_mlx, start_indices, 5, 1.0, 3.0, stream=mx.cpu
)
# print("Time taken to complete 1000 random walks : ", time.time() - start_time)
print(node_sequence)