mlx-cluster 0.0.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
+++ CMakeLists.txt
@@ -0,0 +1,66 @@
1
+ cmake_minimum_required(VERSION 3.27)
2
+ project(_ext LANGUAGES CXX)
3
+
4
+ # ----- Setup required -----
5
+ set(CMAKE_CXX_STANDARD 17)
6
+ set(CMAKE_CXX_STANDARD_REQUIRED ON)
7
+ set(CMAKE_POSITION_INDEPENDENT_CODE ON)
8
+
9
+ option(BUILD_SHARED_LIBS "Build extensions as a shared library" ON)
10
+
11
+ # ----- Dependencies required ----
12
+ find_package(fmt REQUIRED)
13
+ find_package(MLX CONFIG REQUIRED)
14
+ find_package(Python 3.8 COMPONENTS Interpreter Development.Module REQUIRED)
15
+ execute_process(
16
+ COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir
17
+ OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE NB_DIR)
18
+ list(APPEND CMAKE_PREFIX_PATH "${NB_DIR}")
19
+ find_package(nanobind CONFIG REQUIRED)
20
+
21
+ # ------ Adding extensions to the library -----
22
+
23
+ # Add library
24
+ add_library(mlx_cluster)
25
+
26
+ target_sources(mlx_cluster
27
+ PUBLIC
28
+ ${CMAKE_CURRENT_LIST_DIR}/random_walks/RandomWalk.cpp
29
+ ${CMAKE_CURRENT_LIST_DIR}/random_walks/BiasedRandomWalk.cpp
30
+ )
31
+
32
+ target_include_directories(mlx_cluster
33
+ PUBLIC
34
+ ${CMAKE_CURRENT_LIST_DIR})
35
+
36
+ target_link_libraries(mlx_cluster PUBLIC mlx)
37
+
38
+
39
+ if(MLX_BUILD_METAL)
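+ # Compile the Metal kernels into mlx_cluster.metallib next to the extension so it
+ # can be loaded at runtime (only when MLX itself was built with Metal support).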
40
+ mlx_build_metallib(
41
+ TARGET mlx_cluster_metallib
42
+ TITLE mlx_cluster
43
+ SOURCES ${CMAKE_CURRENT_LIST_DIR}/random_walks/random_walk.metal
44
+ INCLUDE_DIRS ${PROJECT_SOURCE_DIR} ${MLX_INCLUDE_DIRS}
45
+ OUTPUT_DIRECTORY ${CMAKE_LIBRARY_OUTPUT_DIRECTORY}
46
+ )
47
+
48
+ add_dependencies(
49
+ mlx_cluster
50
+ mlx_cluster_metallib
51
+ )
52
+
53
+ endif()
54
+ # ----- Nanobind module -----
55
+ nanobind_add_module(
56
+ _ext
57
+ NB_STATIC STABLE_ABI LTO NOMINSIZE
58
+ NB_DOMAIN mlx
59
+ ${CMAKE_CURRENT_LIST_DIR}/bindings.cpp
60
+ )
61
+
62
+ target_link_libraries(_ext PRIVATE mlx_cluster)
63
+
64
+ if(BUILD_SHARED_LIBS)
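+ # On macOS, point the module's rpath at its own directory so it can locate
+ # libmlx_cluster.dylib, which ships alongside it in the installed package.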
65
+ target_link_options(_ext PRIVATE -Wl,-rpath,@loader_path)
66
+ endif()
+++ LICENSE
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2024 vinayhpandya
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
+++ MANIFEST.in
@@ -0,0 +1 @@
1
+ include CMakeLists.txt README.md bindings.cpp LICENSE random_walks/*.*
+++ PKG-INFO
@@ -0,0 +1,74 @@
1
+ Metadata-Version: 2.1
2
+ Name: mlx_cluster
3
+ Version: 0.0.1
4
+ Summary: C++ and Metal extensions for generating random walks with MLX
5
+ Author-email: Vinay Pandya <vinayharshadpandya27@gmail.com>
6
+ Project-URL: Homepage, https://github.com/vinayhpandya/mlx_cluster
7
+ Project-URL: Issues, https://github.com/vinayhpandya/mlx_cluster/Issues
8
+ Classifier: Development Status :: 3 - Alpha
9
+ Classifier: Programming Language :: Python :: 3
10
+ Classifier: Programming Language :: C++
11
+ Classifier: License :: OSI Approved :: MIT License
12
+ Classifier: Operating System :: OS Independent
13
+ Requires-Python: >=3.8
14
+ Description-Content-Type: text/markdown
15
+ License-File: LICENSE
16
+ Provides-Extra: dev
17
+ Provides-Extra: test
18
+ Requires-Dist: torch_geometric; extra == "test"
19
+ Requires-Dist: pytest; extra == "test"
20
+
21
+ # mlx_cluster
22
+
23
+ A C++ extension for generating random walks on homogeneous graphs using MLX
24
+
25
+ ## Installation
26
+
27
+ To install the necessary dependencies:
28
+
29
+ Clone the repository:
30
+ ```bash
31
+ git clone https://github.com/vinayhpandya/mlx_cluster.git
32
+ ```
33
+
34
+ After cloning the repository, build and install the library with:
35
+
36
+ ```bash
37
+ python setup.py build_ext -j8 --inplace
38
+ ```
39
+
40
+ For running the tests you also need `mlx-graphs`, `torch_geometric`, and `pytest` installed.
41
+
42
+ ## Usage
43
+
44
+
45
+ ```python
+ import mlx.core as mx
+ import numpy as np
+
+ from mlx_graphs.datasets import PlanetoidDataset
+ from mlx_graphs.utils.sorting import sort_edge_index
+ from mlx_cluster import random_walk
+
+ cora_dataset = PlanetoidDataset(name="cora", base_dir="~")
+ edge_index = cora_dataset.graphs[0].edge_index.astype(mx.int64)
+ sorted_edge_index = sort_edge_index(edge_index=edge_index)
+ row_mlx = sorted_edge_index[0][0]
+ col_mlx = sorted_edge_index[0][1]
+
+ # Build a CSR row pointer from the sorted source nodes
+ _, counts_mlx = np.unique(np.array(row_mlx, copy=False), return_counts=True)
+ cum_sum_mlx = counts_mlx.cumsum()
+ row_ptr_mlx = mx.concatenate([mx.array([0]), mx.array(cum_sum_mlx)])
+
+ # Starting nodes and one uniform random value per walk step
+ start = mx.arange(0, 1000).astype(mx.int64)
+ walk_length = 100
+ rand = mx.random.uniform(shape=[start.shape[0], walk_length])
+
+ node_sequence = random_walk(row_ptr_mlx, col_mlx, start, rand, walk_length, stream=mx.gpu)
+ ```
65
+
66
+ ## TODO
67
+
68
+ - [x] Add metal shaders to optimize the code
69
+ - [ ] Benchmark random walk against different frameworks
70
+ - [ ] Add more algorithms
71
+
72
+ ## Credits:
73
+
74
+ torch_cluster random walk implementation : [random_walk](https://github.com/rusty1s/pytorch_cluster/blob/master/csrc/cpu/rw_cpu.cpp)
+++ README.md
@@ -0,0 +1,54 @@
1
+ # mlx_cluster
2
+
3
+ A C++ extension for generating random walks on homogeneous graphs using MLX
4
+
5
+ ## Installation
6
+
7
+ To install the necessary dependencies:
8
+
9
+ Clone the repository:
10
+ ```bash
11
+ git clone https://github.com/vinayhpandya/mlx_cluster.git
12
+ ```
13
+
14
+ After cloning the repository, build and install the library with:
15
+
16
+ ```bash
17
+ python setup.py build_ext -j8 --inplace
18
+ ```
19
+
20
+ For running the tests you also need `mlx-graphs`, `torch_geometric`, and `pytest` installed.
21
+
22
+ ## Usage
23
+
24
+
25
+ ```python
+ import mlx.core as mx
+ import numpy as np
+
+ from mlx_graphs.datasets import PlanetoidDataset
+ from mlx_graphs.utils.sorting import sort_edge_index
+ from mlx_cluster import random_walk
+
+ cora_dataset = PlanetoidDataset(name="cora", base_dir="~")
+ edge_index = cora_dataset.graphs[0].edge_index.astype(mx.int64)
+ sorted_edge_index = sort_edge_index(edge_index=edge_index)
+ row_mlx = sorted_edge_index[0][0]
+ col_mlx = sorted_edge_index[0][1]
+
+ # Build a CSR row pointer from the sorted source nodes
+ _, counts_mlx = np.unique(np.array(row_mlx, copy=False), return_counts=True)
+ cum_sum_mlx = counts_mlx.cumsum()
+ row_ptr_mlx = mx.concatenate([mx.array([0]), mx.array(cum_sum_mlx)])
+
+ # Starting nodes and one uniform random value per walk step
+ start = mx.arange(0, 1000).astype(mx.int64)
+ walk_length = 100
+ rand = mx.random.uniform(shape=[start.shape[0], walk_length])
+
+ node_sequence = random_walk(row_ptr_mlx, col_mlx, start, rand, walk_length, stream=mx.gpu)
+ ```
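+
+ The extension also exposes `rejection_sampling` for biased (node2vec-style) walks, where `p` controls the likelihood of immediately revisiting a node and `q` interpolates between breadth-first and depth-first exploration. A minimal sketch reusing the CSR arrays built above; the `p`/`q` values here are illustrative, and the biased walk currently only has a CPU implementation, so it runs on the CPU stream:
+
+ ```python
+ from mlx_cluster import rejection_sampling
+
+ # Returns the visited nodes, shaped (num_start_nodes, walk_length + 1)
+ biased_walks = rejection_sampling(
+     row_ptr_mlx, col_mlx, start, walk_length, 1.0, 3.0, stream=mx.cpu
+ )
+ ```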
45
+
46
+ ## TODO
47
+
48
+ - [x] Add metal shaders to optimize the code
49
+ - [ ] Benchmark random walk against different frameworks
50
+ - [ ] Add more algorithms
51
+
52
+ ## Credits:
53
+
54
+ torch_cluster random walk implementation : [random_walk](https://github.com/rusty1s/pytorch_cluster/blob/master/csrc/cpu/rw_cpu.cpp)
+++ bindings.cpp
@@ -0,0 +1,65 @@
1
+ #include <nanobind/nanobind.h>
2
+ #include <nanobind/stl/variant.h>
3
+ #include <random_walks/RandomWalk.h>
4
+ #include <random_walks/BiasedRandomWalk.h>
5
+
6
+ namespace nb = nanobind;
7
+ using namespace nb::literals;
8
+ using namespace mlx::core;
9
+
10
+ NB_MODULE(_ext, m){
11
+
12
+ m.def(
13
+ "random_walk",
14
+ &random_walk,
15
+ "rowptr"_a,
16
+ "col"_a,
17
+ "start"_a,
18
+ "rand"_a,
19
+ "walk_length"_a,
20
+ nb::kw_only(),
21
+ "stream"_a = nb::none(),
22
+ R"(
23
+ Uniformly sample random walks over a graph given in CSR format.
24
+
25
+
26
+ Args:
27
+ rowptr (array): rowptr of graph in csr format.
28
+ col (array): edges in csr format.
29
+ start (array): starting nodes, one walk per entry.
+ rand (array): pre-sampled uniform random values of shape (len(start), walk_length).
+ walk_length (int): length of each random walk.
30
+
31
+ Returns:
32
+ array: nodes visited on each walk, with shape (len(start), walk_length + 1).
33
+ )");
34
+
35
+ m.def(
36
+ "rejection_sampling",
37
+ &rejection_sampling,
38
+ "rowptr"_a,
39
+ "col"_a,
40
+ "start"_a,
41
+ "walk_length"_a,
42
+ "p"_a,
43
+ "q"_a,
44
+ nb::kw_only(),
45
+ "stream"_a = nb::none(),
46
+ R"(
47
+ Sample nodes from the graph by sampling neighbors based
48
+ on probabilities determined by p and q.
49
+
50
+
51
+ Args:
52
+ rowptr (array): rowptr of graph in csr format.
53
+ col (array): edges in csr format.
54
+ start (array): starting node of graph from which
55
+ biased sampling will be performed.
56
+ walk_length (int): length of each random walk.
57
+ p : Likelihood of immediately revisiting a node in the walk.
58
+ q : Control parameter to interpolate between
59
+ breadth-first strategy and depth-first strategy
60
+
61
+ Returns:
62
+ array: nodes visited on each biased walk, with shape (len(start), walk_length + 1).
63
+ )");
64
+ }
65
+
+++ mlx_cluster/__init__.py
@@ -0,0 +1,4 @@
1
+ import mlx.core as mx
2
+
3
+ from ._ext import random_walk
4
+ from ._ext import rejection_sampling
+++ mlx_cluster.egg-info/PKG-INFO
@@ -0,0 +1,74 @@
1
+ Metadata-Version: 2.1
2
+ Name: mlx_cluster
3
+ Version: 0.0.1
4
+ Summary: C++ and Metal extensions for generating random walks with MLX
5
+ Author-email: Vinay Pandya <vinayharshadpandya27@gmail.com>
6
+ Project-URL: Homepage, https://github.com/vinayhpandya/mlx_cluster
7
+ Project-URL: Issues, https://github.com/vinayhpandya/mlx_cluster/Issues
8
+ Classifier: Development Status :: 3 - Alpha
9
+ Classifier: Programming Language :: Python :: 3
10
+ Classifier: Programming Language :: C++
11
+ Classifier: License :: OSI Approved :: MIT License
12
+ Classifier: Operating System :: OS Independent
13
+ Requires-Python: >=3.8
14
+ Description-Content-Type: text/markdown
15
+ License-File: LICENSE
16
+ Provides-Extra: dev
17
+ Provides-Extra: test
18
+ Requires-Dist: torch_geometric; extra == "test"
19
+ Requires-Dist: pytest; extra == "test"
20
+
21
+ # mlx_cluster
22
+
23
+ A C++ extension for generating random walks on homogeneous graphs using MLX
24
+
25
+ ## Installation
26
+
27
+ To install the necessary dependencies:
28
+
29
+ Clone the repository:
30
+ ```bash
31
+ git clone https://github.com/vinayhpandya/mlx_cluster.git
32
+ ```
33
+
34
+ After cloning the repository, build and install the library with:
35
+
36
+ ```bash
37
+ python setup.py build_ext -j8 --inplace
38
+ ```
39
+
40
+ For running the tests you also need `mlx-graphs`, `torch_geometric`, and `pytest` installed.
41
+
42
+ ## Usage
43
+
44
+
45
+ ```python
+ import mlx.core as mx
+ import numpy as np
+
+ from mlx_graphs.datasets import PlanetoidDataset
+ from mlx_graphs.utils.sorting import sort_edge_index
+ from mlx_cluster import random_walk
+
+ cora_dataset = PlanetoidDataset(name="cora", base_dir="~")
+ edge_index = cora_dataset.graphs[0].edge_index.astype(mx.int64)
+ sorted_edge_index = sort_edge_index(edge_index=edge_index)
+ row_mlx = sorted_edge_index[0][0]
+ col_mlx = sorted_edge_index[0][1]
+
+ # Build a CSR row pointer from the sorted source nodes
+ _, counts_mlx = np.unique(np.array(row_mlx, copy=False), return_counts=True)
+ cum_sum_mlx = counts_mlx.cumsum()
+ row_ptr_mlx = mx.concatenate([mx.array([0]), mx.array(cum_sum_mlx)])
+
+ # Starting nodes and one uniform random value per walk step
+ start = mx.arange(0, 1000).astype(mx.int64)
+ walk_length = 100
+ rand = mx.random.uniform(shape=[start.shape[0], walk_length])
+
+ node_sequence = random_walk(row_ptr_mlx, col_mlx, start, rand, walk_length, stream=mx.gpu)
+ ```
65
+
66
+ ## TODO
67
+
68
+ - [x] Add metal shaders to optimize the code
69
+ - [ ] Benchmark random walk against different frameworks
70
+ - [ ] Add more algorithms
71
+
72
+ ## Credits:
73
+
74
+ torch_cluster random walk implementation : [random_walk](https://github.com/rusty1s/pytorch_cluster/blob/master/csrc/cpu/rw_cpu.cpp)
+++ mlx_cluster.egg-info/SOURCES.txt
@@ -0,0 +1,24 @@
1
+ CMakeLists.txt
2
+ LICENSE
3
+ MANIFEST.in
4
+ README.md
5
+ bindings.cpp
6
+ pyproject.toml
7
+ setup.py
8
+ mlx_cluster/__init__.py
9
+ mlx_cluster/_ext.cpython-311-darwin.so
10
+ mlx_cluster/libmlx_cluster.dylib
11
+ mlx_cluster/mlx_cluster.metallib
12
+ mlx_cluster.egg-info/PKG-INFO
13
+ mlx_cluster.egg-info/SOURCES.txt
14
+ mlx_cluster.egg-info/dependency_links.txt
15
+ mlx_cluster.egg-info/not-zip-safe
16
+ mlx_cluster.egg-info/requires.txt
17
+ mlx_cluster.egg-info/top_level.txt
18
+ random_walks/BiasedRandomWalk.cpp
19
+ random_walks/BiasedRandomWalk.h
20
+ random_walks/RandomWalk.cpp
21
+ random_walks/RandomWalk.h
22
+ random_walks/random_walk.metal
23
+ tests/test_random_walk.py
24
+ tests/test_rejection_sampling.py
+++ mlx_cluster.egg-info/requires.txt
@@ -0,0 +1,6 @@
1
+
2
+ [dev]
3
+
4
+ [test]
5
+ torch_geometric
6
+ pytest
+++ mlx_cluster.egg-info/top_level.txt
@@ -0,0 +1 @@
1
+ mlx_cluster
+++ pyproject.toml
@@ -0,0 +1,33 @@
1
+ [project]
2
+ name = "mlx_cluster"
3
+ version = "0.0.1"
4
+ authors = [
5
+ { name = "Vinay Pandya", email = "vinayharshadpandya27@gmail.com" },
6
+ ]
7
+ description = "C++ and Metal extensions for MLX CTC Loss"
8
+ readme = "README.md"
9
+ requires-python = ">=3.8"
10
+ classifiers = [
11
+ "Development Status :: 3 - Alpha",
12
+ "Programming Language :: Python :: 3",
13
+ "Programming Language :: C++",
14
+ "License :: OSI Approved :: MIT License",
15
+ "Operating System :: OS Independent",
16
+ ]
17
+
18
+ [project.optional-dependencies]
19
+ dev = []
20
+
21
+ [project.urls]
22
+ Homepage = "https://github.com/vinayhpandya/mlx_cluster"
23
+ Issues = "https://github.com/vinayhpandya/mlx_cluster/Issues"
24
+
25
+
26
+ [build-system]
27
+ requires = [
28
+ "setuptools>=42",
29
+ "cmake>=3.24",
30
+ "mlx>=0.9.0",
31
+ "nanobind@git+https://github.com/wjakob/nanobind.git@2f04eac452a6d9142dedb957701bdb20125561e4",
32
+ ]
33
+ build-backend = "setuptools.build_meta"
+++ random_walks/BiasedRandomWalk.cpp
@@ -0,0 +1,194 @@
1
+ #include <cassert>
2
+ #include <iostream>
3
+ #include <sstream>
4
+ #include <random>
5
+
6
+ #include "mlx/backend/common/copy.h"
7
+ #include "mlx/backend/common/utils.h"
8
+ #include "mlx/utils.h"
9
+ #include "mlx/random.h"
10
+ #include "mlx/ops.h"
11
+ #include "mlx/array.h"
12
+ #ifdef _METAL_
13
+ #include "mlx/backend/metal/device.h"
14
+ #include "mlx/backend/metal/utils.h"
15
+ #endif
16
+ #include "random_walks/BiasedRandomWalk.h"
17
+
18
+ namespace mlx::core {
19
+
20
+ bool inline is_neighbor(const int64_t *rowptr, const int64_t *col, int64_t v,
21
+ int64_t w) {
22
+ int64_t row_start = rowptr[v], row_end = rowptr[v + 1];
23
+ for (auto i = row_start; i < row_end; i++) {
24
+ if (col[i] == w)
25
+ return true;
26
+ }
27
+ return false;
28
+ }
29
+
30
+ void BiasedRandomWalk::eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs) {
31
+ auto& rowptr = inputs[0];
32
+ auto& col = inputs[1];
33
+ auto& start = inputs[2];
35
+ int numel = start.size();
36
+
37
+ // Initialize outputs
38
+ assert(outputs.size() == 2);
39
+ // Allocate memory for outputs if not already allocated
40
+ outputs[0].set_data(allocator::malloc_or_wait(numel*(walk_length_+1)*sizeof(int64_t)));
41
+ outputs[1].set_data(allocator::malloc_or_wait(numel*walk_length_*sizeof(int64_t)));
42
+ auto& n_out = outputs[0];
43
+ auto& e_out = outputs[1];
44
+
45
+ auto* n_out_ptr = n_out.data<int64_t>();
46
+ auto* e_out_ptr = e_out.data<int64_t>();
47
+ auto* start_values = start.data<int64_t>();
48
+ auto* row_ptr = rowptr.data<int64_t>();
49
+ auto* col_values = col.data<int64_t>();
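+ // node2vec-style rejection sampling: a candidate neighbor x of the current node v
+ // is accepted with probability 1/p if x is the previous node t, 1 if x is also a
+ // neighbor of t, and 1/q otherwise, all normalized by max(1/p, 1, 1/q).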
51
+
52
+ double max_prob = fmax(fmax(1. / p_, 1.), 1. / q_);
53
+ double prob_0 = 1. / p_ / max_prob;
54
+ double prob_1 = 1. / max_prob;
55
+ double prob_2 = 1. / q_ / max_prob;
56
+
57
+ for (int64_t n = 0; n < numel; n++) {
58
+ int64_t t = start_values[n], v, x, e_cur, row_start, row_end;
59
+ n_out_ptr[n * (walk_length_ + 1)] = t;
60
+ row_start = row_ptr[t], row_end = row_ptr[t + 1];
61
+ if (row_end - row_start == 0) {
62
+ e_cur = -1;
63
+ v = t;
64
+ } else {
65
+ e_cur = row_start + (std::rand() % (row_end - row_start));
66
+ v = col_values[e_cur];
67
+ }
68
+ n_out_ptr[n * (walk_length_ + 1) + 1] = v;
69
+ e_out_ptr[n * walk_length_] = e_cur;
70
+ for (auto l = 1; l < walk_length_; l++) {
71
+ row_start = row_ptr[v], row_end = row_ptr[v + 1];
72
+
73
+ if (row_end - row_start == 0) {
74
+ e_cur = -1;
75
+ x = v;
76
+ } else if (row_end - row_start == 1) {
77
+ e_cur = row_start;
78
+ x = col_values[e_cur];
79
+ } else {
80
+ while (true) {
81
+ e_cur = row_start + (std::rand() % (row_end - row_start));
82
+ x = col_values[e_cur];
83
+
84
+ auto r = ((double)std::rand() / (RAND_MAX)); // [0, 1)
85
+
86
+ if (x == t && r < prob_0)
87
+ break;
88
+ else if (is_neighbor(row_ptr, col_values, x, t) && r < prob_1)
89
+ break;
90
+ else if (r < prob_2)
91
+ break;
92
+ }
93
+ }
94
+
95
+ n_out_ptr[n * (walk_length_ + 1) + (l + 1)] = x;
96
+ e_out_ptr[n * walk_length_ + l] = e_cur;
97
+ t = v;
98
+ v = x;
99
+ }
100
+ }
101
+
102
+ };
103
+
104
+ std::vector<array> BiasedRandomWalk::jvp(
105
+ const std::vector<array>& primals,
106
+ const std::vector<array>& tangents,
107
+ const std::vector<int>& argnums)
108
+ {
109
+ // Random walk is not differentiable; no JVP is provided.
110
+ throw std::runtime_error("Biased random walk has no jvp implementation.");
111
+ }
112
+ // #ifdef _METAL_
113
+ // void BiasedRandomWalk::eval_gpu(
114
+ // const std::vector<array>& inputs,
115
+ // std::vector<array>& outputs
116
+ // ){
117
+ // auto& rowptr = inputs[0];
118
+ // auto& col = inputs[1];
119
+ // auto& start = inputs[2];
120
+ // auto& rand = inputs[3];
121
+ // int numel = start.size();
122
+
123
+ // assert(outputs.size() == 2);
124
+ // outputs[0].set_data(allocator::malloc_or_wait(numel * (walk_length_ + 1) * sizeof(int64_t)));
125
+ // outputs[1].set_data(allocator::malloc_or_wait(numel * walk_length_ * sizeof(int64_t)));
126
+ // std::cout<<"after setting data"<<std::endl;
127
+ // auto& s = stream();
128
+ // auto& d = metal::device(s.device);
129
+
130
+ // d.register_library("mlx_cluster", metal::get_colocated_mtllib_path);
131
+ // auto kernel = d.get_kernel("random_walk", "mlx_cluster");
132
+
133
+ // auto& compute_encoder = d.get_command_encoder(s.index);
134
+ // compute_encoder->setComputePipelineState(kernel);
135
+
136
+ // compute_encoder.set_input_array(rowptr, 0);
137
+ // compute_encoder.set_input_array(col, 1);
138
+ // compute_encoder.set_input_array(start, 2);
139
+ // compute_encoder.set_input_array(rand, 3);
140
+ // compute_encoder.set_output_array(outputs[0], 4);
141
+ // compute_encoder.set_output_array(outputs[1], 5);
142
+ // compute_encoder->setBytes(&walk_length_, sizeof(int32), 6);
143
+
144
+ // MTL::Size grid_size = MTL::Size(numel, 1, 1);
145
+ // MTL::Size thread_group_size = MTL::Size(kernel->maxTotalThreadsPerThreadgroup(), 1, 1);
146
+
147
+ // compute_encoder.dispatchThreads(grid_size, thread_group_size);
148
+ // }
149
+ // #endif
150
+ void BiasedRandomWalk::eval_gpu(
151
+ const std::vector<array>& inputs, std::vector<array>& outputs
152
+ )
153
+ {
154
+ throw std::runtime_error("Random walk has no GPU implementation.");
155
+ }
156
+ std::vector<array> BiasedRandomWalk::vjp(
157
+ const std::vector<array>& primals,
158
+ const std::vector<array>& cotangents,
159
+ const std::vector<int>& argnums,
160
+ const std::vector<array>& outputs)
161
+ {
162
+ // Random walk is not differentiable; no VJP is provided.
163
+ throw std::runtime_error("Random walk has no JVP implementation.");
164
+ }
165
+
166
+ std::pair<std::vector<array>, std::vector<int>> BiasedRandomWalk::vmap(
167
+ const std::vector<array>& inputs,
168
+ const std::vector<int>& axes)
169
+ {
170
+ throw std::runtime_error("vmap not implemented for biasedRandomWalk");
171
+ }
172
+
173
+ bool BiasedRandomWalk::is_equivalent(const Primitive& other) const
174
+ {
175
+ throw std::runtime_error("biased Random walk has no GPU implementation.");
176
+ }
177
+
178
+ std::vector<std::vector<int>> BiasedRandomWalk::output_shapes(const std::vector<array>& inputs)
179
+ {
180
+ throw std::runtime_error("biased Random walk has no GPU implementation.");
181
+ }
182
+
183
+ array rejection_sampling(const array& rowptr, const array& col, const array& start, int walk_length, const double p,
184
+ const double q, StreamOrDevice s)
185
+ {
186
+ int nodes = start.size();
187
+ auto primitive = std::make_shared<BiasedRandomWalk>(to_stream(s), walk_length, p, q);
188
+ return array::make_arrays({{nodes,walk_length+1},{nodes, walk_length}},
189
+ {rowptr.dtype(), rowptr.dtype()},
190
+ primitive,
191
+ {rowptr, col, start}
192
+ )[0];
193
+ }
194
+ }
+++ random_walks/BiasedRandomWalk.h
@@ -0,0 +1,66 @@
1
+ #pragma once
2
+
3
+ #include <mlx/array.h>
4
+ #include <mlx/ops.h>
5
+ #include <mlx/primitives.h>
6
+
7
+ namespace mlx::core{
8
+
9
+ class BiasedRandomWalk : public Primitive {
10
+ public:
11
+ BiasedRandomWalk(Stream stream, int walk_length, double p, double q)
12
+ : Primitive(stream), walk_length_(walk_length), p_(p), q_(q) {}
13
+ void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
14
+ override;
15
+ void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
16
+ override;
17
+
18
+ /** The Jacobian-vector product. */
19
+ std::vector<array> jvp(
20
+ const std::vector<array>& primals,
21
+ const std::vector<array>& tangents,
22
+ const std::vector<int>& argnums) override;
23
+
24
+ /** The vector-Jacobian product. */
25
+ std::vector<array> vjp(
26
+ const std::vector<array>& primals,
27
+ const std::vector<array>& cotangents,
28
+ const std::vector<int>& argnums,
29
+ const std::vector<array>& outputs) override;
30
+
31
+ /**
32
+ * The primitive must know how to vectorize itself across
33
+ * the given axes. The output is a pair containing the array
34
+ * representing the vectorized computation and the axis which
35
+ * corresponds to the output vectorized dimension.
36
+ */
37
+ std::pair<std::vector<array>, std::vector<int>> vmap(
38
+ const std::vector<array>& inputs,
39
+ const std::vector<int>& axes) override;
40
+
41
+ /** Print the primitive. */
42
+ void print(std::ostream& os) override {
43
+ os << "biased random walk implementation";
44
+ }
45
+
46
+ /** Equivalence check **/
47
+ bool is_equivalent(const Primitive& other) const override;
48
+
49
+ std::vector<std::vector<int>> output_shapes(const std::vector<array>& inputs) override;
50
+
51
+ private:
52
+ int walk_length_;
53
+ double p_;
54
+ double q_;
55
+
56
+ };
57
+
58
+ array rejection_sampling(const array& rowptr,
59
+ const array& col,
60
+ const array& start,
61
+ int walk_length,
62
+ const double p,
63
+ const double q,
64
+ StreamOrDevice s = {});
65
+
66
+ };
+++ random_walks/RandomWalk.cpp
@@ -0,0 +1,147 @@
1
+ #include <cassert>
2
+ #include <iostream>
3
+ #include <sstream>
4
+
5
+ #include "mlx/backend/common/copy.h"
6
+ #include "mlx/backend/common/utils.h"
7
+ #include "mlx/utils.h"
8
+ #include "mlx/random.h"
9
+ #include "mlx/ops.h"
10
+ #include "mlx/array.h"
11
+ #ifdef _METAL_
12
+ #include "mlx/backend/metal/device.h"
13
+ #include "mlx/backend/metal/utils.h"
14
+ #endif
15
+ #include "random_walks/RandomWalk.h"
16
+
17
+ namespace mlx::core {
18
+ void RandomWalk::eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs) {
19
+ auto& rowptr = inputs[0];
20
+ auto& col = inputs[1];
21
+ auto& start = inputs[2];
22
+ auto& rand = inputs[3];
23
+ int numel = start.size();
24
+
25
+ // Initialize outputs
26
+ assert(outputs.size() == 2);
27
+ // Allocate memory for outputs if not already allocated
28
+ outputs[0].set_data(allocator::malloc_or_wait(numel*(walk_length_+1)*sizeof(int64_t)));
29
+ outputs[1].set_data(allocator::malloc_or_wait(numel*walk_length_*sizeof(int64_t)));
30
+ auto& n_out = outputs[0];
31
+ auto& e_out = outputs[1];
32
+
33
+ auto* n_out_ptr = n_out.data<int64_t>();
34
+ auto* e_out_ptr = e_out.data<int64_t>();
35
+ auto* start_values = start.data<int64_t>();
36
+ auto* row_ptr = rowptr.data<int64_t>();
37
+ auto* col_values = col.data<int64_t>();
38
+ auto* rand_values = rand.data<float>();
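+ // For each start node, repeatedly pick a uniformly random outgoing edge from the
+ // CSR row of the current node, using the pre-sampled values in `rand`; nodes with
+ // no outgoing edges stay in place and the edge id is recorded as -1.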
39
+
40
+ for (int64_t n = 0; n < numel; n++) {
41
+ int64_t n_cur = start_values[n];
42
+ n_out_ptr[n * (walk_length_ + 1)] = n_cur;
43
+ for (int l = 0; l < walk_length_; l++) {
44
+ int64_t row_start = row_ptr[n_cur];
45
+ int64_t row_end = row_ptr[n_cur+1];
46
+ int64_t e_cur;
47
+ if (row_end - row_start == 0) {
48
+ e_cur = -1;
49
+ } else {
50
+ float r = rand_values[n*walk_length_+l];
51
+ int64_t idx = static_cast<int64_t>(r * (row_end - row_start));
52
+ e_cur = row_start + idx;
53
+ n_cur = col_values[e_cur];
54
+ }
55
+
56
+ n_out_ptr[n * (walk_length_ + 1) + (l + 1)] = n_cur;
57
+ e_out_ptr[n * walk_length_ + l] = e_cur;
58
+ }
59
+ }
60
+
61
+ };
62
+
63
+ std::vector<array> RandomWalk::jvp(
64
+ const std::vector<array>& primals,
65
+ const std::vector<array>& tangents,
66
+ const std::vector<int>& argnums)
67
+ {
68
+ // Random walk is not differentiable; no JVP is provided.
69
+ throw std::runtime_error("Random walk has no GPU implementation.");
70
+ }
71
+ #ifdef _METAL_
72
+ void RandomWalk::eval_gpu(
73
+ const std::vector<array>& inputs,
74
+ std::vector<array>& outputs
75
+ ){
76
+ auto& rowptr = inputs[0];
77
+ auto& col = inputs[1];
78
+ auto& start = inputs[2];
79
+ auto& rand = inputs[3];
80
+ int numel = start.size();
81
+
82
+ assert(outputs.size() == 2);
83
+ outputs[0].set_data(allocator::malloc_or_wait(numel * (walk_length_ + 1) * sizeof(int64_t)));
84
+ outputs[1].set_data(allocator::malloc_or_wait(numel * walk_length_ * sizeof(int64_t)));
85
+ std::cout<<"after setting data"<<std::endl;
86
+ auto& s = stream();
87
+ auto& d = metal::device(s.device);
88
+
89
+ d.register_library("mlx_cluster");
90
+ auto kernel = d.get_kernel("random_walk", "mlx_cluster");
91
+
92
+ auto& compute_encoder = d.get_command_encoder(s.index);
93
+ compute_encoder->setComputePipelineState(kernel);
94
+
95
+ compute_encoder.set_input_array(rowptr, 0);
96
+ compute_encoder.set_input_array(col, 1);
97
+ compute_encoder.set_input_array(start, 2);
98
+ compute_encoder.set_input_array(rand, 3);
99
+ compute_encoder.set_output_array(outputs[0], 4);
100
+ compute_encoder.set_output_array(outputs[1], 5);
101
+ compute_encoder->setBytes(&walk_length_, sizeof(int32), 6);
102
+
103
+ MTL::Size grid_size = MTL::Size(numel, 1, 1);
104
+ MTL::Size thread_group_size = MTL::Size(kernel->maxTotalThreadsPerThreadgroup(), 1, 1);
105
+
106
+ compute_encoder.dispatchThreads(grid_size, thread_group_size);
107
+ }
108
+ #endif
109
+
110
+ std::vector<array> RandomWalk::vjp(
111
+ const std::vector<array>& primals,
112
+ const std::vector<array>& cotangents,
113
+ const std::vector<int>& argnums,
114
+ const std::vector<array>& outputs)
115
+ {
116
+ // Random walk is not differentiable; no VJP is provided.
117
+ throw std::runtime_error("Random walk has no GPU implementation.");
118
+ }
119
+
120
+ std::pair<std::vector<array>, std::vector<int>> RandomWalk::vmap(
121
+ const std::vector<array>& inputs,
122
+ const std::vector<int>& axes)
123
+ {
124
+ throw std::runtime_error("vmap not implemented for RandomWalk");
125
+ }
126
+
127
+ bool RandomWalk::is_equivalent(const Primitive& other) const
128
+ {
129
+ throw std::runtime_error("Random walk has no GPU implementation.");
130
+ }
131
+
132
+ std::vector<std::vector<int>> RandomWalk::output_shapes(const std::vector<array>& inputs)
133
+ {
134
+ throw std::runtime_error("Random walk has no GPU implementation.");
135
+ }
136
+
137
+ array random_walk(const array& rowptr, const array& col, const array& start, const array& rand, int walk_length, StreamOrDevice s)
138
+ {
139
+ int nodes = start.size();
140
+ auto primitive = std::make_shared<RandomWalk>(walk_length, to_stream(s));
141
+ return array::make_arrays({{nodes,walk_length+1},{nodes, walk_length}},
142
+ {start.dtype(), start.dtype()},
143
+ primitive,
144
+ {rowptr, col, start, rand}
145
+ )[0];
146
+ }
147
+ }
+++ random_walks/RandomWalk.h
@@ -0,0 +1,63 @@
1
+ #pragma once
2
+
3
+ #include <mlx/array.h>
4
+ #include <mlx/ops.h>
5
+ #include <mlx/primitives.h>
6
+
7
+ namespace mlx::core{
8
+
9
+ class RandomWalk : public Primitive {
10
+ public:
11
+ explicit RandomWalk(int walk_length, Stream stream):
12
+ Primitive(stream), walk_length_(walk_length) {};
13
+ void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
14
+ override;
15
+ void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
16
+ override;
17
+
18
+ /** The Jacobian-vector product. */
19
+ std::vector<array> jvp(
20
+ const std::vector<array>& primals,
21
+ const std::vector<array>& tangents,
22
+ const std::vector<int>& argnums) override;
23
+
24
+ /** The vector-Jacobian product. */
25
+ std::vector<array> vjp(
26
+ const std::vector<array>& primals,
27
+ const std::vector<array>& cotangents,
28
+ const std::vector<int>& argnums,
29
+ const std::vector<array>& outputs) override;
30
+
31
+ /**
32
+ * The primitive must know how to vectorize itself across
33
+ * the given axes. The output is a pair containing the array
34
+ * representing the vectorized computation and the axis which
35
+ * corresponds to the output vectorized dimension.
36
+ */
37
+ std::pair<std::vector<array>, std::vector<int>> vmap(
38
+ const std::vector<array>& inputs,
39
+ const std::vector<int>& axes) override;
40
+
41
+ /** Print the primitive. */
42
+ void print(std::ostream& os) override {
43
+ os << "Random walk implementation";
44
+ }
45
+
46
+ /** Equivalence check **/
47
+ bool is_equivalent(const Primitive& other) const override;
48
+
49
+ std::vector<std::vector<int>> output_shapes(const std::vector<array>& inputs) override;
50
+
51
+ private:
52
+ int walk_length_;
53
+
54
+ };
55
+
56
+ array random_walk(const array& rowptr,
57
+ const array& col,
58
+ const array& start,
59
+ const array& rand,
60
+ int walk_length,
61
+ StreamOrDevice s = {});
62
+
63
+ };
+++ random_walks/random_walk.metal
@@ -0,0 +1,35 @@
1
+ #include <metal_stdlib>
2
+ #include "mlx/backend/metal/kernels/utils.h"
3
+ using namespace metal;
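+ // One thread per start node: each thread performs a walk_length-step uniform random
+ // walk over the CSR graph, writing visited nodes to n_out and traversed edge ids to e_out.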
4
+
5
+ kernel void random_walk(
6
+ const device int64_t* rowptr [[buffer(0)]],
7
+ const device int64_t* col [[buffer(1)]],
8
+ const device int64_t* start [[buffer(2)]],
9
+ const device float* rand [[buffer(3)]],
10
+ device int64_t* n_out [[buffer(4)]],
11
+ device int64_t* e_out [[buffer(5)]],
12
+ constant int& walk_length [[buffer(6)]],
13
+ uint tid [[thread_position_in_grid]]
14
+ ) {
15
+ int64_t n_cur = start[tid];
16
+ n_out[tid * (walk_length + 1)] = n_cur;
17
+
18
+ for (int l = 0; l < walk_length; l++) {
19
+ int64_t row_start = rowptr[n_cur];
20
+ int64_t row_end = rowptr[n_cur + 1];
21
+ int64_t e_cur;
22
+
23
+ if (row_end - row_start == 0) {
24
+ e_cur = -1;
25
+ } else {
26
+ float r = rand[tid * walk_length + l];
27
+ int64_t idx = static_cast<int64_t>(r * (row_end - row_start));
28
+ e_cur = row_start + idx;
29
+ n_cur = col[e_cur];
30
+ }
31
+
32
+ n_out[tid * (walk_length + 1) + (l + 1)] = n_cur;
33
+ e_out[tid * walk_length + l] = e_cur;
34
+ }
35
+ }
+++ setup.cfg
@@ -0,0 +1,4 @@
1
+ [egg_info]
2
+ tag_build =
3
+ tag_date = 0
4
+
+++ setup.py
@@ -0,0 +1,16 @@
1
+ from setuptools import setup
2
+ from mlx import extension
3
+
4
+ if __name__ == "__main__":
5
+ setup(
6
+ name="mlx_cluster",
7
+ version="0.0.1",
8
+ description="Sample C++ and Metal extensions for MLX primitives.",
9
+ ext_modules=[extension.CMakeExtension("mlx_cluster._ext")],
10
+ cmdclass={"build_ext": extension.CMakeBuild},
11
+ packages=["mlx_cluster"],
12
+ package_data={"mlx_cluster": ["*.so", "*.dylib", "*.metallib"]},
13
+ extras_require={"dev": [], "test": ["torch_geometric", "pytest"]},
14
+ zip_safe=False,
15
+ python_requires=">=3.8",
16
+ )
+++ tests/test_random_walk.py
@@ -0,0 +1,58 @@
1
+ import mlx.core as mx
2
+ import numpy as np
3
+ import time
4
+
5
+ # Torch dataset
6
+ import torch
7
+ import torch_geometric.datasets as pyg_datasets
8
+ from torch_geometric.utils import sort_edge_index
9
+ from torch_geometric.utils.num_nodes import maybe_num_nodes
10
+ from torch_geometric.utils.sparse import index2ptr
11
+ from torch.utils.data import DataLoader
12
+
13
+ torch_planetoid = pyg_datasets.Planetoid(root="data/Cora", name="Cora")
14
+ edge_index_torch = torch_planetoid.edge_index
15
+ num_nodes = maybe_num_nodes(edge_index=edge_index_torch)
16
+ row, col = sort_edge_index(edge_index=edge_index_torch, num_nodes=num_nodes)
17
+ row_ptr, col = index2ptr(row, num_nodes), col
18
+ loader = DataLoader(range(2708), batch_size=2000)
19
+ start_indices = next(iter(loader))
20
+ print(edge_index_torch.dtype)
21
+ print(row_ptr.dtype)
22
+ print(col.dtype)
23
+ print(start_indices.dtype)
24
+ random_walks = torch.ops.torch_cluster.random_walk(
25
+ row_ptr, col, start_indices, 5, 1.0, 1.0
26
+ )
27
+
28
+ from mlx_graphs.datasets import PlanetoidDataset
29
+ from mlx_graphs.utils.sorting import sort_edge_index
30
+ from torch.utils.data import DataLoader
31
+ from mlx_cluster import random_walk
32
+
33
+ cora_dataset = PlanetoidDataset(name="cora", base_dir="~")
34
+ edge_index = cora_dataset.graphs[0].edge_index.astype(mx.int64)
35
+ sorted_edge_index = sort_edge_index(edge_index=edge_index)
36
+ print(edge_index.dtype)
37
+ row_mlx = sorted_edge_index[0][0]
38
+ col_mlx = sorted_edge_index[0][1]
39
+ _, counts_mlx = np.unique(np.array(row_mlx, copy=False), return_counts=True)
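+ # Note: building rowptr from np.unique counts assumes every node appears as a source
+ # at least once; index2ptr above handles nodes without outgoing edges correctly.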
40
+ cum_sum_mlx = counts_mlx.cumsum()
41
+ row_ptr_mlx = mx.concatenate([mx.array([0]), mx.array(cum_sum_mlx)])
42
+ start_indices = mx.array(start_indices.numpy())
43
+ print("Start indices data type is ", start_indices.dtype)
44
+ print("Col mlx data type is ", col_mlx.dtype)
45
+ print("Row mlx data type is ", row_ptr_mlx.dtype)
46
+ assert mx.array_equal(row_ptr_mlx, mx.array(row_ptr.numpy())), "Arrays not equal"
47
+ assert mx.array_equal(col_mlx, mx.array(col.numpy())), "Col arrays are not equal"
48
+ rand_data = mx.random.uniform(shape=[start_indices.shape[0], 5])
49
+ start_time = time.time()
50
+ print("Start indices data type is ", start_indices.dtype)
51
+ print("Col mlx data type is ", col_mlx.dtype)
52
+ print("Row mlx data type is ", row_ptr_mlx.dtype)
53
+ node_sequence = random_walk(
54
+ row_ptr_mlx, col_mlx, start_indices, rand_data, 5, stream=mx.gpu
55
+ )
56
+ # print("Time taken to complete 1000 random walks : ", time.time() - start_time)
57
+ print("Torch random walks are", random_walks[0])
58
+ print("MLX random walks are", node_sequence)
+++ tests/test_rejection_sampling.py
@@ -0,0 +1,49 @@
1
+ import mlx.core as mx
2
+ import numpy as np
3
+ import time
4
+
5
+ # Torch dataset
6
+ import torch
7
+ import torch_geometric.datasets as pyg_datasets
8
+ from torch_geometric.utils import sort_edge_index
9
+ from torch_geometric.utils.num_nodes import maybe_num_nodes
10
+ from torch_geometric.utils.sparse import index2ptr
11
+ from torch.utils.data import DataLoader
12
+
13
+ torch_planetoid = pyg_datasets.Planetoid(root="data/Cora", name="Cora")
14
+ edge_index_torch = torch_planetoid.edge_index
15
+ num_nodes = maybe_num_nodes(edge_index=edge_index_torch)
16
+ row, col = sort_edge_index(edge_index=edge_index_torch, num_nodes=num_nodes)
17
+ row_ptr, col = index2ptr(row, num_nodes), col
18
+ loader = DataLoader(range(2708), batch_size=2000)
19
+ start_indices = next(iter(loader))
20
+ # random_walks = torch.ops.torch_cluster.random_walk(
21
+ # row_ptr, col, start_indices, 5, 1.0, 3.0
22
+ # )
23
+
24
+ from mlx_graphs.datasets import PlanetoidDataset
25
+ from mlx_graphs.utils.sorting import sort_edge_index
26
+ from torch.utils.data import DataLoader
27
+ from mlx_cluster import rejection_sampling
28
+
29
+ cora_dataset = PlanetoidDataset(name="cora", base_dir="~")
30
+ edge_index = cora_dataset.graphs[0].edge_index.astype(mx.int64)
31
+ sorted_edge_index = sort_edge_index(edge_index=edge_index)
32
+ row_mlx = sorted_edge_index[0][0]
33
+ col_mlx = sorted_edge_index[0][1]
34
+ _, counts_mlx = np.unique(np.array(row_mlx, copy=False), return_counts=True)
35
+ cum_sum_mlx = counts_mlx.cumsum()
36
+ row_ptr_mlx = mx.concatenate([mx.array([0]), mx.array(cum_sum_mlx)])
37
+ start_indices = mx.array(start_indices.numpy())
38
+ print("row pointer datatype", row_ptr_mlx.dtype)
39
+ print("col datatype", col_mlx.dtype)
40
+ print("start pointer datatype", start_indices.dtype)
41
+ assert mx.array_equal(row_ptr_mlx, mx.array(row_ptr.numpy())), "Arrays not equal"
42
+ assert mx.array_equal(col_mlx, mx.array(col.numpy())), "Col arrays are not equal"
43
+ rand_data = mx.random.uniform(shape=[start_indices.shape[0], 5])
44
+ start_time = time.time()
45
+ node_sequence = rejection_sampling(
46
+ row_ptr_mlx, col_mlx, start_indices, 5, 1.0, 3.0, stream=mx.cpu
47
+ )
48
+ # print("Time taken to complete 1000 random walks : ", time.time() - start_time)
49
+ print(node_sequence)