mlx-cluster 0.0.3__tar.gz → 0.0.5__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mlx_cluster-0.0.3/mlx_cluster.egg-info → mlx_cluster-0.0.5}/PKG-INFO +35 -15
- {mlx_cluster-0.0.3 → mlx_cluster-0.0.5}/README.md +22 -9
- mlx_cluster-0.0.5/bindings.cpp +81 -0
- {mlx_cluster-0.0.3 → mlx_cluster-0.0.5}/mlx_cluster/mlx_cluster.metallib +0 -0
- {mlx_cluster-0.0.3 → mlx_cluster-0.0.5/mlx_cluster.egg-info}/PKG-INFO +35 -15
- {mlx_cluster-0.0.3 → mlx_cluster-0.0.5}/mlx_cluster.egg-info/SOURCES.txt +0 -3
- mlx_cluster-0.0.5/mlx_cluster.egg-info/requires.txt +12 -0
- {mlx_cluster-0.0.3 → mlx_cluster-0.0.5}/pyproject.toml +19 -10
- {mlx_cluster-0.0.3 → mlx_cluster-0.0.5}/random_walks/BiasedRandomWalk.cpp +24 -30
- mlx_cluster-0.0.5/random_walks/BiasedRandomWalk.h +65 -0
- {mlx_cluster-0.0.3 → mlx_cluster-0.0.5}/random_walks/RandomWalk.cpp +43 -37
- mlx_cluster-0.0.5/random_walks/RandomWalk.h +62 -0
- {mlx_cluster-0.0.3 → mlx_cluster-0.0.5}/setup.py +2 -9
- mlx_cluster-0.0.5/tests/test_random_walk.py +72 -0
- mlx_cluster-0.0.5/tests/test_rejection_sampling.py +62 -0
- mlx_cluster-0.0.3/bindings.cpp +0 -65
- mlx_cluster-0.0.3/mlx_cluster/_ext.cpython-311-darwin.so +0 -0
- mlx_cluster-0.0.3/mlx_cluster/libmlx.dylib +0 -0
- mlx_cluster-0.0.3/mlx_cluster/libmlx_cluster.dylib +0 -0
- mlx_cluster-0.0.3/mlx_cluster.egg-info/requires.txt +0 -7
- mlx_cluster-0.0.3/random_walks/BiasedRandomWalk.h +0 -66
- mlx_cluster-0.0.3/random_walks/RandomWalk.h +0 -63
- mlx_cluster-0.0.3/tests/test_random_walk.py +0 -35
- mlx_cluster-0.0.3/tests/test_rejection_sampling.py +0 -34
- {mlx_cluster-0.0.3 → mlx_cluster-0.0.5}/CMakeLists.txt +0 -0
- {mlx_cluster-0.0.3 → mlx_cluster-0.0.5}/LICENSE +0 -0
- {mlx_cluster-0.0.3 → mlx_cluster-0.0.5}/MANIFEST.in +0 -0
- {mlx_cluster-0.0.3 → mlx_cluster-0.0.5}/mlx_cluster/__init__.py +0 -0
- {mlx_cluster-0.0.3 → mlx_cluster-0.0.5}/mlx_cluster.egg-info/dependency_links.txt +0 -0
- {mlx_cluster-0.0.3 → mlx_cluster-0.0.5}/mlx_cluster.egg-info/not-zip-safe +0 -0
- {mlx_cluster-0.0.3 → mlx_cluster-0.0.5}/mlx_cluster.egg-info/top_level.txt +0 -0
- {mlx_cluster-0.0.3 → mlx_cluster-0.0.5}/random_walks/random_walk.metal +0 -0
- {mlx_cluster-0.0.3 → mlx_cluster-0.0.5}/setup.cfg +0 -0
{mlx_cluster-0.0.3/mlx_cluster.egg-info → mlx_cluster-0.0.5}/PKG-INFO

@@ -1,7 +1,7 @@
-Metadata-Version: 2.
+Metadata-Version: 2.4
 Name: mlx_cluster
-Version: 0.0.3
-Summary: C++
+Version: 0.0.5
+Summary: C++ extension for generating random graphs
 Author-email: Vinay Pandya <vinayharshadpandya27@gmail.com>
 Project-URL: Homepage, https://github.com/vinayhpandya/mlx_cluster
 Project-URL: Issues, https://github.com/vinayhpandya/mlx_cluster/Issues
@@ -9,15 +9,22 @@ Classifier: Development Status :: 3 - Alpha
 Classifier: Programming Language :: Python :: 3
 Classifier: Programming Language :: C++
 Classifier: License :: OSI Approved :: MIT License
-Classifier: Operating System ::
+Classifier: Operating System :: MacOS
 Requires-Python: >=3.8
 Description-Content-Type: text/markdown
 License-File: LICENSE
 Provides-Extra: dev
 Provides-Extra: test
+Requires-Dist: mlx-graphs>=0.0.8; extra == "test"
+Requires-Dist: torch>=2.2.0; extra == "test"
+Requires-Dist: mlx>=0.26.0; extra == "test"
 Requires-Dist: pytest==7.4.4; extra == "test"
-Requires-Dist:
-Requires-Dist:
+Requires-Dist: scipy>=1.13.0; extra == "test"
+Requires-Dist: requests==2.31.0; extra == "test"
+Requires-Dist: fsspec[http]==2024.2.0; extra == "test"
+Requires-Dist: tqdm==4.66.1; extra == "test"
+Dynamic: license-file
+Dynamic: requires-python
 
 # mlx_cluster
 
@@ -50,24 +57,37 @@ for testing purposes you need to have `mlx-graphs` and `torch_geometric` instal
 
 
 ```
-
-
-from
+# Can also use mlx for generating starting indices
+import torch
+from torch.utils.data import DataLoader
+
+loader = DataLoader(range(2708), batch_size=2000)
+start_indices = next(iter(loader))
 
 
+from mlx_graphs.datasets import PlanetoidDataset
+from mlx_graphs.utils.sorting import sort_edge_index
+from torch.utils.data import DataLoader
+from mlx_cluster import random_walk
+
 cora_dataset = PlanetoidDataset(name="cora", base_dir="~")
-
-start_time = time.time()
+# For some reason int_64t and int_32t are not compatible
 edge_index = cora_dataset.graphs[0].edge_index.astype(mx.int64)
-
+
+# Convert edge index into a CSR matrix
 sorted_edge_index = sort_edge_index(edge_index=edge_index)
 row_mlx = sorted_edge_index[0][0]
 col_mlx = sorted_edge_index[0][1]
-
+_, counts_mlx = np.unique(np.array(row_mlx, copy=False), return_counts=True)
 cum_sum_mlx = counts_mlx.cumsum()
-rand = mx.random.uniform(shape=[start.shape[0], 100])
 row_ptr_mlx = mx.concatenate([mx.array([0]), mx.array(cum_sum_mlx)])
-
+start_indices = mx.array(start_indices.numpy())
+
+rand_data = mx.random.uniform(shape=[start_indices.shape[0], 5])
+
+node_sequence = random_walk(
+    row_ptr_mlx, col_mlx, start_indices, rand_data, 5, stream=mx.cpu
+)
 ```
 
 ## TODO
{mlx_cluster-0.0.3 → mlx_cluster-0.0.5}/README.md

@@ -29,24 +29,37 @@ for testing purposes you need to have `mlx-graphs` and `torch_geometric` instal

Identical to the usage-example hunk in the PKG-INFO diff above (PKG-INFO embeds the README); only the line numbers differ.
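The README snippet above assumes `mlx.core` and `numpy` are already imported, and it binds the result to a single name even though the 0.0.5 bindings return a `(nodes, edges)` tuple. A self-contained sketch of the same example, with the two missing imports added and the tuple unpacked (everything else taken from the README):

```python
import mlx.core as mx
import numpy as np
import torch
from torch.utils.data import DataLoader

from mlx_graphs.datasets import PlanetoidDataset
from mlx_graphs.utils.sorting import sort_edge_index
from mlx_cluster import random_walk

# 2708 = number of nodes in Cora; walk from the first 2000 of them
loader = DataLoader(range(2708), batch_size=2000)
start_indices = next(iter(loader))

cora_dataset = PlanetoidDataset(name="cora", base_dir="~")
edge_index = cora_dataset.graphs[0].edge_index.astype(mx.int64)

# Build a CSR row pointer from the sorted edge index
sorted_edge_index = sort_edge_index(edge_index=edge_index)
row_mlx, col_mlx = sorted_edge_index[0][0], sorted_edge_index[0][1]
_, counts_mlx = np.unique(np.array(row_mlx, copy=False), return_counts=True)
row_ptr_mlx = mx.concatenate([mx.array([0]), mx.array(counts_mlx.cumsum())])

start_indices = mx.array(start_indices.numpy())
rand_data = mx.random.uniform(shape=[start_indices.shape[0], 5])

nodes, edges = random_walk(
    row_ptr_mlx, col_mlx, start_indices, rand_data, 5, stream=mx.cpu
)
```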
mlx_cluster-0.0.5/bindings.cpp (new file)

@@ -0,0 +1,81 @@
+#include <nanobind/nanobind.h>
+#include <nanobind/stl/variant.h>
+#include <random_walks/RandomWalk.h>
+#include <random_walks/BiasedRandomWalk.h>
+
+namespace nb = nanobind;
+using namespace nb::literals;
+using namespace mlx::core;
+
+NB_MODULE(_ext, m){
+
+  m.def(
+      "random_walk",
+      [](const mx::array& rowptr,
+         const mx::array& col,
+         const mx::array& start,
+         const mx::array& rand,
+         int walk_length,
+         nb::object stream = nb::none()) {
+
+        // call the real C++ implementation
+        auto outs = mlx_random_walk::random_walk(
+            rowptr, col, start, rand, walk_length,
+            stream.is_none() ? mx::StreamOrDevice{}
+                             : nb::cast<mx::StreamOrDevice>(stream));
+
+        // vector -> tuple (move avoids a copy)
+        return nb::make_tuple(std::move(outs[0]), std::move(outs[1]));
+      },
+      "rowptr"_a, "col"_a, "start"_a, "rand"_a, "walk_length"_a,
+      nb::kw_only(), "stream"_a = nb::none(),
+      R"(
+      Uniform random walks.
+
+      Returns:
+          (nodes, edges) tuple of arrays
+      )",
+      nb::rv_policy::move);
+
+  m.def(
+      "rejection_sampling",
+      [](const mx::array& rowptr,
+         const mx::array& col,
+         const mx::array& start,
+         int walk_length,
+         float p,
+         float q,
+         nb::object stream = nb::none()
+      ){
+        auto outs = mlx_biased_random_walk::rejection_sampling(
+            rowptr, col, start, walk_length, p, q,
+            stream.is_none() ? mx::StreamOrDevice{}
+                             : nb::cast<mx::StreamOrDevice>(stream));
+        return nb::make_tuple(std::move(outs[0]), std::move(outs[1]));
+      },
+      "rowptr"_a,
+      "col"_a,
+      "start"_a,
+      "walk_length"_a,
+      "p"_a,
+      "q"_a,
+      nb::kw_only(), "stream"_a = nb::none(),
+      R"(
+      Sample nodes from the graph by sampling neighbors based
+      on probability p and q
+
+      Args:
+          rowptr (array): rowptr of graph in csr format.
+          col (array): edges in csr format.
+          start (array): starting node of graph from which
+              biased sampling will be performed.
+          walk_length (int): walk length of random graph
+          p: Likelihood of immediately revisiting a node in the walk.
+          q: Control parameter to interpolate between
+              breadth-first strategy and depth-first strategy
+
+      Returns:
+          (nodes, edges) tuple of arrays
+      )",
+      nb::rv_policy::move);
+}
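The key API change from 0.0.3 is that both bindings now go through lambdas that return a `(nodes, edges)` tuple instead of a single array. A toy call illustrating the resulting shapes (the three-node graph here is made up for illustration; only the signature and tuple return come from the bindings above):

```python
import mlx.core as mx
from mlx_cluster import random_walk

# Hypothetical 3-node cycle graph in CSR form (illustrative data only)
rowptr = mx.array([0, 2, 4, 6], dtype=mx.int64)
col = mx.array([1, 2, 0, 2, 0, 1], dtype=mx.int64)
start = mx.array([0, 1], dtype=mx.int64)
rand = mx.random.uniform(shape=[2, 4])  # one uniform float per step per walk

nodes, edges = random_walk(rowptr, col, start, rand, 4, stream=mx.cpu)
print(nodes.shape)  # (2, 5): each walk records walk_length + 1 nodes
print(edges.shape)  # (2, 4): the edge taken at each step
```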
{mlx_cluster-0.0.3 → mlx_cluster-0.0.5}/mlx_cluster/mlx_cluster.metallib: binary file (contents not shown).
{mlx_cluster-0.0.3 → mlx_cluster-0.0.5/mlx_cluster.egg-info}/PKG-INFO

Identical to the PKG-INFO diff shown at the top of this listing (the egg-info copy of the same metadata and README).
{mlx_cluster-0.0.3 → mlx_cluster-0.0.5}/mlx_cluster.egg-info/SOURCES.txt

@@ -6,9 +6,6 @@ bindings.cpp
 pyproject.toml
 setup.py
 mlx_cluster/__init__.py
-mlx_cluster/_ext.cpython-311-darwin.so
-mlx_cluster/libmlx.dylib
-mlx_cluster/libmlx_cluster.dylib
 mlx_cluster/mlx_cluster.metallib
 mlx_cluster.egg-info/PKG-INFO
 mlx_cluster.egg-info/SOURCES.txt
{mlx_cluster-0.0.3 → mlx_cluster-0.0.5}/pyproject.toml

@@ -1,39 +1,48 @@
 [project]
 name = "mlx_cluster"
-version = "0.0.3"
+version = "0.0.5"
 authors = [
     { name = "Vinay Pandya", email = "vinayharshadpandya27@gmail.com" },
 ]
-description = "C++
+description = "C++ extension for generating random graphs"
 readme = "README.md"
-requires-python = ">=3.8"
+requires-python = ">=3.10"
 classifiers = [
     "Development Status :: 3 - Alpha",
     "Programming Language :: Python :: 3",
     "Programming Language :: C++",
     "License :: OSI Approved :: MIT License",
-    "Operating System ::
+    "Operating System :: MacOS",
 ]
 
 [project.optional-dependencies]
 dev = []
 test = [
+    "mlx-graphs>=0.0.8",
+    "torch>=2.2.0",
+    "mlx>=0.26.0",
     "pytest==7.4.4",
-    "
-    "
+    "scipy>=1.13.0",
+    "requests==2.31.0",
+    "fsspec[http]==2024.2.0",
+    "tqdm==4.66.1",
 ]
+
 [project.urls]
 Homepage = "https://github.com/vinayhpandya/mlx_cluster"
 Issues = "https://github.com/vinayhpandya/mlx_cluster/Issues"
 
+[tool.pytest.ini_options]
+addopts = "-ra"
+markers = [
+    "slow: marks tests that download data, compile kernels, or are otherwise time-consuming (deselect with -m 'not slow')",
+]
 
 [build-system]
 requires = [
+    "nanobind==2.4.0",
     "setuptools>=42",
     "cmake>=3.24",
-    "mlx>=0.
-    "nanobind@git+https://github.com/wjakob/nanobind.git@2f04eac452a6d9142dedb957701bdb20125561e4",
+    "mlx>=0.26.0",
 ]
-
-
 build-backend = "setuptools.build_meta"
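The new `slow` marker means the dataset-downloading tests can be deselected, as the marker text itself suggests. A minimal sketch of doing that programmatically rather than from the shell (`pytest.main` is pytest's documented entry point; this assumes pytest and the test extras are installed):

```python
import sys
import pytest

# Run the suite without tests marked "slow" (same as: pytest -m "not slow")
sys.exit(pytest.main(["-m", "not slow"]))
```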
{mlx_cluster-0.0.3 → mlx_cluster-0.0.5}/random_walks/BiasedRandomWalk.cpp

@@ -15,7 +15,8 @@
 #endif
 #include "random_walks/BiasedRandomWalk.h"
 
-
+
+namespace mlx_biased_random_walk {
 
 bool inline is_neighbor(const int64_t *rowptr, const int64_t *col, int64_t v,
                         int64_t w) {
@@ -27,21 +28,20 @@ namespace mlx::core {
   return false;
 }
 
-void BiasedRandomWalk::eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs) {
+void BiasedRandomWalk::eval_cpu(const std::vector<mx::array>& inputs, std::vector<mx::array>& outputs) {
   auto& rowptr = inputs[0];
   auto& col = inputs[1];
   auto& start = inputs[2];
   auto& rand = inputs[3];
   int numel = start.size();
-
+  std::cout<<"Inside biased random walk"<<std::endl;
   // Initialize outputs
   assert(outputs.size() == 2);
   // Allocate memory for outputs if not already allocated
-  outputs[0].set_data(allocator::
-  outputs[1].set_data(allocator::
+  outputs[0].set_data(mx::allocator::malloc(numel*(walk_length_+1)*sizeof(int64_t)));
+  outputs[1].set_data(mx::allocator::malloc(numel*walk_length_*sizeof(int64_t)));
   auto& n_out = outputs[0];
   auto& e_out = outputs[1];
-
   auto* n_out_ptr = n_out.data<int64_t>();
   auto* e_out_ptr = e_out.data<int64_t>();
   auto* start_values = start.data<int64_t>();
@@ -53,7 +53,7 @@ namespace mlx::core {
   double prob_0 = 1. / p_ / max_prob;
   double prob_1 = 1. / max_prob;
   double prob_2 = 1. / q_ / max_prob;
-
+
   for (int64_t n = 0; n < numel; n++) {
     int64_t t = start_values[n], v, x, e_cur, row_start, row_end;
     n_out_ptr[n * (walk_length_ + 1)] = t;
@@ -91,7 +91,6 @@ namespace mlx::core {
         break;
       }
     }
-
     n_out_ptr[n * (walk_length_ + 1) + (l + 1)] = x;
     e_out_ptr[n * walk_length_ + l] = e_cur;
     t = v;
@@ -101,9 +100,9 @@ namespace mlx::core {
 
 };
 
-std::vector<array> BiasedRandomWalk::jvp(
-    const std::vector<array>& primals,
-    const std::vector<array>& tangents,
+std::vector<mx::array> BiasedRandomWalk::jvp(
+    const std::vector<mx::array>& primals,
+    const std::vector<mx::array>& tangents,
     const std::vector<int>& argnums)
 {
   // Random walk is not differentiable, so we return zero tangents
@@ -121,8 +120,8 @@ namespace mlx::core {
 // int numel = start.size();
 
 // assert(outputs.size() == 2);
-// outputs[0].set_data(allocator::
-// outputs[1].set_data(allocator::
+// outputs[0].set_data(allocator::malloc(numel * (walk_length_ + 1) * sizeof(int64_t)));
+// outputs[1].set_data(allocator::malloc(numel * walk_length_ * sizeof(int64_t)));
 // std::cout<<"after setting data"<<std::endl;
 // auto& s = stream();
 // auto& d = metal::device(s.device);
@@ -148,47 +147,42 @@ namespace mlx::core {
 // }
 // #endif
 void BiasedRandomWalk::eval_gpu(
-    const std::vector<array>& inputs, std::vector<array>& outputs
+    const std::vector<mx::array>& inputs, std::vector<mx::array>& outputs
 )
 {
   throw std::runtime_error("Random walk has no GPU implementation.");
 }
-std::vector<array> BiasedRandomWalk::vjp(
-    const std::vector<array>& primals,
-    const std::vector<array>& cotangents,
+std::vector<mx::array> BiasedRandomWalk::vjp(
+    const std::vector<mx::array>& primals,
+    const std::vector<mx::array>& cotangents,
     const std::vector<int>& argnums,
-    const std::vector<array>& outputs)
+    const std::vector<mx::array>& outputs)
 {
   // Random walk is not differentiable, so we return zero gradients
   throw std::runtime_error("Random walk has no JVP implementation.");
 }
 
-std::pair<std::vector<array>, std::vector<int>> BiasedRandomWalk::vmap(
-    const std::vector<array>& inputs,
+std::pair<std::vector<mx::array>, std::vector<int>> BiasedRandomWalk::vmap(
+    const std::vector<mx::array>& inputs,
     const std::vector<int>& axes)
 {
   throw std::runtime_error("vmap not implemented for biasedRandomWalk");
 }
 
-bool BiasedRandomWalk::is_equivalent(const Primitive& other) const
-{
-  throw std::runtime_error("biased Random walk has no GPU implementation.");
-}
-
-std::vector<std::vector<int>> BiasedRandomWalk::output_shapes(const std::vector<array>& inputs)
+bool BiasedRandomWalk::is_equivalent(const mx::Primitive& other) const
 {
   throw std::runtime_error("biased Random walk has no GPU implementation.");
 }
 
-array rejection_sampling(const array& rowptr, const array& col, const array& start, int walk_length, const double p,
-                         const double q, StreamOrDevice s)
+std::vector<mx::array> rejection_sampling(const mx::array& rowptr, const mx::array& col, const mx::array& start, int walk_length, const double p,
+                                          const double q, mx::StreamOrDevice s)
 {
   int nodes = start.size();
   auto primitive = std::make_shared<BiasedRandomWalk>(to_stream(s), walk_length, p, q);
-  return array::make_arrays({{nodes,walk_length+1},{nodes, walk_length}},
+  return mx::array::make_arrays({{nodes,walk_length+1},{nodes, walk_length}},
      {rowptr.dtype(), rowptr.dtype()},
      primitive,
      {rowptr, col, start}
-  )
+  );
 }
 }
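For readers unfamiliar with node2vec-style biasing, here is a plain-Python sketch of the acceptance rule visible in the kernel above (the `prob_0`/`prob_1`/`prob_2` thresholds and the `is_neighbor` check). The function and variable names are illustrative, not the library API:

```python
import random

def sample_next(rowptr, col, t, v, p, q):
    """Propose neighbors of v until one is accepted (node2vec bias)."""
    max_prob = max(1.0 / p, 1.0, 1.0 / q)
    prob_back = 1.0 / p / max_prob   # x == t: revisit the previous node
    prob_near = 1.0 / max_prob       # x is also a neighbor of t (distance 1)
    prob_far = 1.0 / q / max_prob    # x moves away from t (distance 2)
    t_nbrs = col[rowptr[t]:rowptr[t + 1]]
    while True:
        v_nbrs = col[rowptr[v]:rowptr[v + 1]]
        x = random.choice(v_nbrs)   # uniform proposal
        r = random.random()
        if x == t:
            if r < prob_back:
                return x
        elif x in t_nbrs:           # the kernel's is_neighbor(t, x) check
            if r < prob_near:
                return x
        elif r < prob_far:
            return x

# Toy CSR graph: edges 0-1, 0-2, 1-2 (undirected, illustrative only)
rowptr = [0, 2, 4, 6]
col = [1, 2, 0, 2, 0, 1]
print(sample_next(rowptr, col, t=0, v=1, p=1.0, q=3.0))
```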
mlx_cluster-0.0.5/random_walks/BiasedRandomWalk.h (new file)

@@ -0,0 +1,65 @@
+#pragma once
+
+#include <mlx/array.h>
+#include <mlx/ops.h>
+#include <mlx/primitives.h>
+
+namespace mx = mlx::core;
+namespace mlx_biased_random_walk{
+
+class BiasedRandomWalk : public mx::Primitive {
+ public:
+  BiasedRandomWalk(mx::Stream stream, int walk_length, double p, double q)
+      : mx::Primitive(stream), walk_length_(walk_length), p_(p), q_(q) {}
+  void eval_cpu(const std::vector<mx::array>& inputs, std::vector<mx::array>& outputs)
+      override;
+  void eval_gpu(const std::vector<mx::array>& inputs, std::vector<mx::array>& outputs)
+      override;
+
+  /** The Jacobian-vector product. */
+  std::vector<mx::array> jvp(
+      const std::vector<mx::array>& primals,
+      const std::vector<mx::array>& tangents,
+      const std::vector<int>& argnums) override;
+
+  /** The vector-Jacobian product. */
+  std::vector<mx::array> vjp(
+      const std::vector<mx::array>& primals,
+      const std::vector<mx::array>& cotangents,
+      const std::vector<int>& argnums,
+      const std::vector<mx::array>& outputs) override;
+
+  /**
+   * The primitive must know how to vectorize itself across
+   * the given axes. The output is a pair containing the array
+   * representing the vectorized computation and the axis which
+   * corresponds to the output vectorized dimension.
+   */
+  std::pair<std::vector<mx::array>, std::vector<int>> vmap(
+      const std::vector<mx::array>& inputs,
+      const std::vector<int>& axes) override;
+
+  /** Print the primitive. */
+  virtual const char* name() const override {
+    return "biased random walk implementation";
+  }
+
+  /** Equivalence check **/
+  bool is_equivalent(const mx::Primitive& other) const override;
+
+ private:
+  int walk_length_;
+  double p_;
+  double q_;
+
+};
+
+std::vector<mx::array> rejection_sampling(const mx::array& rowptr,
+                                          const mx::array& col,
+                                          const mx::array& start,
+                                          int walk_length,
+                                          const double p,
+                                          const double q,
+                                          mx::StreamOrDevice s = {});
+
+};
{mlx_cluster-0.0.3 → mlx_cluster-0.0.5}/random_walks/RandomWalk.cpp

@@ -1,7 +1,7 @@
 #include <cassert>
 #include <iostream>
 #include <sstream>
-
+#include <dlfcn.h>
 #include "mlx/backend/common/copy.h"
 #include "mlx/backend/common/utils.h"
 #include "mlx/utils.h"
@@ -14,19 +14,31 @@
 #endif
 #include "random_walks/RandomWalk.h"
 
-
-
+
+namespace mlx_random_walk {
+std::string current_binary_dir() {
+  static std::string binary_dir = []() {
+    Dl_info info;
+    if (!dladdr(reinterpret_cast<void*>(&current_binary_dir), &info)) {
+      throw std::runtime_error("Unable to get current binary dir.");
+    }
+    return std::filesystem::path(info.dli_fname).parent_path().string();
+  }();
+  return binary_dir;
+}
+
+void RandomWalk::eval_cpu(const std::vector<mx::array>& inputs, std::vector<mx::array>& outputs) {
   auto& rowptr = inputs[0];
   auto& col = inputs[1];
   auto& start = inputs[2];
   auto& rand = inputs[3];
   int numel = start.size();
-
+  std::cout<<"Its really inside cpu"<<std::endl;
   // Initialize outputs
   assert(outputs.size() == 2);
   // Allocate memory for outputs if not already allocated
-  outputs[0].set_data(allocator::
-  outputs[1].set_data(allocator::
+  outputs[0].set_data(mx::allocator::malloc(numel*(walk_length_+1)*sizeof(int64_t)));
+  outputs[1].set_data(mx::allocator::malloc(numel*walk_length_*sizeof(int64_t)));
   auto& n_out = outputs[0];
   auto& e_out = outputs[1];
 
@@ -37,7 +49,6 @@ namespace mlx::core {
   auto* col_values = col.data<int64_t>();
   auto* rand_values = rand.data<float>();
 
-  std::cout<<"After evaluating outputs"<<std::endl;
   for (int64_t n = 0; n < numel; n++) {
     int64_t n_cur = start_values[n];
     n_out_ptr[n * (walk_length_ + 1)] = n_cur;
@@ -61,9 +72,9 @@ namespace mlx::core {
 
 };
 
-std::vector<array> RandomWalk::jvp(
-    const std::vector<array>& primals,
-    const std::vector<array>& tangents,
+std::vector<mx::array> RandomWalk::jvp(
+    const std::vector<mx::array>& primals,
+    const std::vector<mx::array>& tangents,
     const std::vector<int>& argnums)
 {
   // Random walk is not differentiable, so we return zero tangents
@@ -71,8 +82,8 @@ namespace mlx::core {
 }
 #ifdef _METAL_
 void RandomWalk::eval_gpu(
-    const std::vector<array>& inputs,
-    std::vector<array>& outputs
+    const std::vector<mx::array>& inputs,
+    std::vector<mx::array>& outputs
 ){
   auto& rowptr = inputs[0];
   auto& col = inputs[1];
@@ -81,17 +92,16 @@ void RandomWalk::eval_gpu(
   int numel = start.size();
 
   assert(outputs.size() == 2);
-  outputs[0].set_data(allocator::
-  outputs[1].set_data(allocator::
-  std::cout<<"after setting data"<<std::endl;
+  outputs[0].set_data(mx::allocator::malloc(numel * (walk_length_ + 1) * sizeof(int64_t)));
+  outputs[1].set_data(mx::allocator::malloc(numel * walk_length_ * sizeof(int64_t)));
   auto& s = stream();
-  auto& d = metal::device(s.device);
-
-  d.
-  auto kernel = d.get_kernel("random_walk",
+  auto& d = mx::metal::device(s.device);
+  std::cout<<"Its really inside gpu"<<std::endl;
+  auto lib = d.get_library("mlx_cluster", current_binary_dir());
+  auto kernel = d.get_kernel("random_walk", lib);
 
   auto& compute_encoder = d.get_command_encoder(s.index);
-  compute_encoder
+  compute_encoder.set_compute_pipeline_state(kernel);
 
   compute_encoder.set_input_array(rowptr, 0);
   compute_encoder.set_input_array(col, 1);
@@ -99,50 +109,46 @@ void RandomWalk::eval_gpu(
   compute_encoder.set_input_array(rand, 3);
   compute_encoder.set_output_array(outputs[0], 4);
   compute_encoder.set_output_array(outputs[1], 5);
-  compute_encoder
+  compute_encoder.set_bytes(&walk_length_, sizeof(walk_length_), 6);
 
   MTL::Size grid_size = MTL::Size(numel, 1, 1);
   MTL::Size thread_group_size = MTL::Size(kernel->maxTotalThreadsPerThreadgroup(), 1, 1);
 
-  compute_encoder.
+  compute_encoder.dispatch_threads(grid_size, thread_group_size);
 }
 #endif
 
-std::vector<array> RandomWalk::vjp(
-    const std::vector<array>& primals,
-    const std::vector<array>& cotangents,
+std::vector<mx::array> RandomWalk::vjp(
+    const std::vector<mx::array>& primals,
+    const std::vector<mx::array>& cotangents,
     const std::vector<int>& argnums,
-    const std::vector<array>& outputs)
+    const std::vector<mx::array>& outputs)
 {
   // Random walk is not differentiable, so we return zero gradients
   throw std::runtime_error("Random walk has no GPU implementation.");
 }
 
-std::pair<std::vector<array>, std::vector<int>> RandomWalk::vmap(
-    const std::vector<array>& inputs,
+std::pair<std::vector<mx::array>, std::vector<int>> RandomWalk::vmap(
+    const std::vector<mx::array>& inputs,
     const std::vector<int>& axes)
 {
   throw std::runtime_error("vmap not implemented for RandomWalk");
 }
 
-bool RandomWalk::is_equivalent(const Primitive& other) const
-{
-  throw std::runtime_error("Random walk has no GPU implementation.");
-}
-
-std::vector<std::vector<int>> RandomWalk::output_shapes(const std::vector<array>& inputs)
+bool RandomWalk::is_equivalent(const mx::Primitive& other) const
 {
   throw std::runtime_error("Random walk has no GPU implementation.");
 }
 
-array random_walk(const array& rowptr, const array& col, const array& start, const array& rand, int walk_length, StreamOrDevice s)
+std::vector<mx::array> random_walk(const mx::array& rowptr, const mx::array& col, const mx::array& start, const mx::array& rand, int walk_length, mx::StreamOrDevice s)
 {
+  std::cout<<"Inside random walk"<<std::endl;
   int nodes = start.size();
   auto primitive = std::make_shared<RandomWalk>(walk_length, to_stream(s));
-  return array::make_arrays({{nodes,walk_length+1},{nodes, walk_length}},
+  return mx::array::make_arrays({{nodes,walk_length+1},{nodes, walk_length}},
      {start.dtype(), start.dtype()},
      primitive,
      {rowptr, col, start, rand}
-  )
+  );
 }
 }
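Only the start of the `eval_cpu` walk loop survives in this hunk. A plain-Python sketch of what a uniform CSR walk driven by a precomputed `rand` matrix typically looks like — this is an assumption based on the kernel's inputs (`rowptr`, `col`, one float per step per walk), not the kernel's verbatim logic:

```python
def uniform_walk(rowptr, col, start, rand, walk_length):
    """Uniform random walk on a CSR graph; rand[n][l] in [0, 1) picks step l of walk n."""
    node_seqs = []
    for n, cur in enumerate(start):
        seq = [cur]
        for l in range(walk_length):
            row_start, row_end = rowptr[cur], rowptr[cur + 1]
            if row_end > row_start:  # node has neighbors; otherwise stay put
                cur = col[row_start + int(rand[n][l] * (row_end - row_start))]
            seq.append(cur)
        node_seqs.append(seq)
    return node_seqs

# Toy CSR graph (illustrative): edges 0-1, 0-2, 1-2
print(uniform_walk([0, 2, 4, 6], [1, 2, 0, 2, 0, 1], [0], [[0.1, 0.9, 0.4]], 3))
```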
mlx_cluster-0.0.5/random_walks/RandomWalk.h (new file)

@@ -0,0 +1,62 @@
+#pragma once
+
+#include <mlx/array.h>
+#include <mlx/ops.h>
+#include <mlx/primitives.h>
+
+namespace mx = mlx::core;
+namespace mlx_random_walk{
+
+class RandomWalk : public mx::Primitive {
+ public:
+  explicit RandomWalk(int walk_length, mx::Stream stream):
+      mx::Primitive(stream), walk_length_(walk_length) {};
+  void eval_cpu(const std::vector<mx::array>& inputs, std::vector<mx::array>& outputs)
+      override;
+  void eval_gpu(const std::vector<mx::array>& inputs, std::vector<mx::array>& outputs)
+      override;
+
+  /** The Jacobian-vector product. */
+  std::vector<mx::array> jvp(
+      const std::vector<mx::array>& primals,
+      const std::vector<mx::array>& tangents,
+      const std::vector<int>& argnums) override;
+
+  /** The vector-Jacobian product. */
+  std::vector<mx::array> vjp(
+      const std::vector<mx::array>& primals,
+      const std::vector<mx::array>& cotangents,
+      const std::vector<int>& argnums,
+      const std::vector<mx::array>& outputs) override;
+
+  /**
+   * The primitive must know how to vectorize itself across
+   * the given axes. The output is a pair containing the array
+   * representing the vectorized computation and the axis which
+   * corresponds to the output vectorized dimension.
+   */
+  std::pair<std::vector<mx::array>, std::vector<int>> vmap(
+      const std::vector<mx::array>& inputs,
+      const std::vector<int>& axes) override;
+
+  /** Print the primitive. */
+  virtual const char* name() const override {
+    return "Random walk implementation";
+  }
+
+  /** Equivalence check **/
+  bool is_equivalent(const mx::Primitive& other) const override;
+
+ private:
+  int walk_length_;
+
+};
+
+std::vector<mx::array> random_walk(const mx::array& rowptr,
+                                   const mx::array& col,
+                                   const mx::array& start,
+                                   const mx::array& rand,
+                                   int walk_length,
+                                   mx::StreamOrDevice s = {});
+
+};
{mlx_cluster-0.0.3 → mlx_cluster-0.0.5}/setup.py

@@ -4,20 +4,13 @@ from mlx import extension
 if __name__ == "__main__":
     setup(
         name="mlx_cluster",
-        version="0.0.3",
+        version="0.0.5",
         description="Sample C++ and Metal extensions for MLX primitives.",
         ext_modules=[extension.CMakeExtension("mlx_cluster._ext")],
         cmdclass={"build_ext": extension.CMakeBuild},
         packages=["mlx_cluster"],
         package_data={"mlx_cluster": ["*.so", "*.dylib", "*.metallib"]},
-        extras_require={
-            "dev": [],
-            "test": [
-                "mlx_graphs",
-                "torch",
-                "pytest",
-            ],
-        },
+        extras_require={"dev": []},
         zip_safe=False,
         python_requires=">=3.8",
     )
mlx_cluster-0.0.5/tests/test_random_walk.py (new file)

@@ -0,0 +1,72 @@
+import mlx.core as mx
+import numpy as np
+import time
+
+# Torch dataset
+import torch
+from torch.utils.data import DataLoader
+
+loader = DataLoader(range(2708), batch_size=2000)
+start_indices = next(iter(loader))
+
+
+from mlx_graphs.datasets import PlanetoidDataset
+from mlx_graphs.utils.sorting import sort_edge_index
+from torch.utils.data import DataLoader
+from mlx_cluster import random_walk
+import pytest
+import time
+
+@pytest.mark.slow  # give download/compile plenty of time on CI
+def test_random_walk(tmp_path):
+    """
+    Runs 1 000 random walks of length 10 on the Cora graph and checks:
+      1. output tensor shape == (num_start_nodes, walk_length + 1)
+      2. all returned node indices are valid ( < num_nodes )
+    """
+
+    # ---------- Dataset (downloaded to the temp dir) ----------
+    data_dir = tmp_path / "mlx_datasets"
+    cora = PlanetoidDataset(name="cora", base_dir=data_dir)
+
+    edge_index = cora.graphs[0].edge_index.astype(mx.int64)
+
+    # CSR conversion
+    sorted_edge_index = sort_edge_index(edge_index=edge_index)
+    row = sorted_edge_index[0][0]
+    col = sorted_edge_index[0][1]
+    _, counts = np.unique(np.array(row, copy=False), return_counts=True)
+    row_ptr = mx.concatenate([mx.array([0]), mx.array(counts.cumsum())])
+
+    # pick 1 000 random start nodes
+    num_starts = 1_000
+    rng = np.random.default_rng(42)
+    start_idx = mx.array(rng.integers(low=0, high=row.max().item() + 1,
+                                      size=num_starts, dtype=np.int64))
+
+    # random numbers for the kernel (shape [num_starts, walk_length])
+    walk_len = 10
+    rand_data = mx.random.uniform(shape=[num_starts, walk_len])
+
+    # ---------- Warm-up ----------
+    mx.eval(row_ptr, col, start_idx, rand_data)
+
+    # ---------- Run kernel ----------
+    t0 = time.time()
+    node_seq = random_walk(row_ptr, col, start_idx, rand_data,
+                           walk_len, stream=mx.cpu)
+    elapsed = time.time() - t0
+    print(f"Random-walk kernel took {elapsed:.3f} s")
+    print("Node sequence is ", node_seq[0])
+    # ---------- Assertions ----------
+    assert node_seq[0].shape == (num_starts, walk_len + 1)
+
+    # num_nodes = cora.graphs[0].num_nodes
+    # assert (node_seq < num_nodes).all().item(), \
+    #     "Random walk produced invalid node indices"
+    t0 = time.time()
+    node_seq_gpu = random_walk(row_ptr, col, start_idx, rand_data,
+                               walk_len, stream=mx.gpu)
+    elapsed = time.time() - t0
+    print(f"Random-walk kernel on gpu took {elapsed:.3f} s")
+    print("Node sequence is ", node_seq_gpu[0])
mlx_cluster-0.0.5/tests/test_rejection_sampling.py (new file)

@@ -0,0 +1,62 @@
+import mlx.core as mx
+import numpy as np
+import time
+import pytest
+
+# Torch dataset
+import torch
+from torch.utils.data import DataLoader
+
+loader = DataLoader(range(2708), batch_size=2000)
+start_indices = next(iter(loader))
+# random_walks = torch.ops.torch_cluster.random_walk(
+#     row_ptr, col, start_indices, 5, 1.0, 3.0
+# )
+
+from mlx_graphs.datasets import PlanetoidDataset
+from mlx_graphs.utils.sorting import sort_edge_index
+from torch.utils.data import DataLoader
+from mlx_cluster import rejection_sampling
+
+@pytest.mark.slow  # give download/compile plenty of time on CI
+def test_random_walk(tmp_path):
+    """
+    Runs 1 000 random walks of length 10 on the Cora graph and checks:
+      1. output tensor shape == (num_start_nodes, walk_length + 1)
+      2. all returned node indices are valid ( < num_nodes )
+    """
+
+    # ---------- Dataset (downloaded to the temp dir) ----------
+    data_dir = tmp_path / "mlx_datasets"
+    cora = PlanetoidDataset(name="cora", base_dir=data_dir)
+
+    edge_index = cora.graphs[0].edge_index.astype(mx.int64)
+
+    # CSR conversion
+    sorted_edge_index = sort_edge_index(edge_index=edge_index)
+    row = sorted_edge_index[0][0]
+    col = sorted_edge_index[0][1]
+    _, counts = np.unique(np.array(row, copy=False), return_counts=True)
+    row_ptr = mx.concatenate([mx.array([0]), mx.array(counts.cumsum())])
+
+    # pick 1 000 random start nodes
+    num_starts = 1_000
+    rng = np.random.default_rng(42)
+    start_idx = mx.array(rng.integers(low=0, high=row.max().item() + 1,
+                                      size=num_starts, dtype=np.int64))
+
+    # random numbers for the kernel (shape [num_starts, walk_length])
+    walk_len = 10
+    rand_data = mx.random.uniform(shape=[num_starts, walk_len])
+
+    # ---------- Warm-up ----------
+    mx.eval(row_ptr, col, start_idx, rand_data)
+
+    # ---------- Run kernel ----------
+    t0 = time.time()
+    node_seq = rejection_sampling(row_ptr, col, start_idx, walk_len, 1.0, 3.0, stream=mx.cpu)
+    elapsed = time.time() - t0
+    print(f"Random-walk kernel took {elapsed:.3f} s")
+    print("Node sequence is ", node_seq)
+    # ---------- Assertions ----------
+    assert node_seq[0].shape == (num_starts, walk_len + 1)
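A smaller call mirroring the test above, useful for a quick sanity check without the Cora download (the three-node graph is hypothetical; the signature and output shapes come from the binding and factory shown earlier):

```python
import mlx.core as mx
from mlx_cluster import rejection_sampling

# Hypothetical 3-node cycle graph in CSR form (illustrative data only)
rowptr = mx.array([0, 2, 4, 6], dtype=mx.int64)
col = mx.array([1, 2, 0, 2, 0, 1], dtype=mx.int64)
start = mx.array([0, 2], dtype=mx.int64)

nodes, edges = rejection_sampling(rowptr, col, start, 4, 1.0, 3.0, stream=mx.cpu)
assert nodes.shape == (2, 5)  # walk_length + 1 nodes per walk
assert edges.shape == (2, 4)  # one edge per step
```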
mlx_cluster-0.0.3/bindings.cpp (deleted)

@@ -1,65 +0,0 @@
-#include <nanobind/nanobind.h>
-#include <nanobind/stl/variant.h>
-#include <random_walks/RandomWalk.h>
-#include <random_walks/BiasedRandomWalk.h>
-
-namespace nb = nanobind;
-using namespace nb::literals;
-using namespace mlx::core;
-
-NB_MODULE(_ext, m){
-
-  m.def(
-      "random_walk",
-      &random_walk,
-      "rowptr"_a,
-      "col"_a,
-      "start"_a,
-      "rand"_a,
-      "walk_length"_a,
-      nb::kw_only(),
-      "stream"_a = nb::none(),
-      R"(
-      uniformly sample a graph
-
-
-      Args:
-          rowptr (array): rowptr of graph in csr format.
-          col (array): edges in csr format.
-          walk_length (int): walk length of random graph
-
-      Returns:
-          array: consisting of nodes visited on random walk
-      )");
-
-  m.def(
-      "rejection_sampling",
-      &rejection_sampling,
-      "rowptr"_a,
-      "col"_a,
-      "start"_a,
-      "walk_length"_a,
-      "p"_a,
-      "q"_a,
-      nb::kw_only(),
-      "stream"_a = nb::none(),
-      R"(
-      Sample nodes from the graph by sampling neighbors based
-      on probability p and q
-
-
-      Args:
-          rowptr (array): rowptr of graph in csr format.
-          col (array): edges in csr format.
-          start (array): starting node of graph from which
-              biased sampling will be performed.
-          walk_length (int): walk length of random graph
-          p: Likelihood of immediately revisiting a node in the walk.
-          q: Control parameter to interpolate between
-              breadth-first strategy and depth-first strategy
-
-      Returns:
-          array: consisting of nodes visited on random walk
-      )");
-}
-
|
|
Binary file
|
|
Binary file
|
|
mlx_cluster-0.0.3/random_walks/BiasedRandomWalk.h (deleted)

@@ -1,66 +0,0 @@
-#pragma once
-
-#include <mlx/array.h>
-#include <mlx/ops.h>
-#include <mlx/primitives.h>
-
-namespace mlx::core{
-
-class BiasedRandomWalk : public Primitive {
- public:
-  BiasedRandomWalk(Stream stream, int walk_length, double p, double q)
-      : Primitive(stream), walk_length_(walk_length), p_(p), q_(q) {}
-  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
-      override;
-  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
-      override;
-
-  /** The Jacobian-vector product. */
-  std::vector<array> jvp(
-      const std::vector<array>& primals,
-      const std::vector<array>& tangents,
-      const std::vector<int>& argnums) override;
-
-  /** The vector-Jacobian product. */
-  std::vector<array> vjp(
-      const std::vector<array>& primals,
-      const std::vector<array>& cotangents,
-      const std::vector<int>& argnums,
-      const std::vector<array>& outputs) override;
-
-  /**
-   * The primitive must know how to vectorize itself across
-   * the given axes. The output is a pair containing the array
-   * representing the vectorized computation and the axis which
-   * corresponds to the output vectorized dimension.
-   */
-  std::pair<std::vector<array>, std::vector<int>> vmap(
-      const std::vector<array>& inputs,
-      const std::vector<int>& axes) override;
-
-  /** Print the primitive. */
-  void print(std::ostream& os) override {
-    os << "biased random walk implementation";
-  }
-
-  /** Equivalence check **/
-  bool is_equivalent(const Primitive& other) const override;
-
-  std::vector<std::vector<int>> output_shapes(const std::vector<array>& inputs) override;
-
- private:
-  int walk_length_;
-  double p_;
-  double q_;
-
-};
-
-array rejection_sampling(const array& rowptr,
-                         const array& col,
-                         const array& start,
-                         int walk_length,
-                         const double p,
-                         const double q,
-                         StreamOrDevice s = {});
-
-};
mlx_cluster-0.0.3/random_walks/RandomWalk.h (deleted)

@@ -1,63 +0,0 @@
-#pragma once
-
-#include <mlx/array.h>
-#include <mlx/ops.h>
-#include <mlx/primitives.h>
-
-namespace mlx::core{
-
-class RandomWalk : public Primitive {
- public:
-  explicit RandomWalk(int walk_length, Stream stream):
-      Primitive(stream), walk_length_(walk_length) {};
-  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
-      override;
-  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
-      override;
-
-  /** The Jacobian-vector product. */
-  std::vector<array> jvp(
-      const std::vector<array>& primals,
-      const std::vector<array>& tangents,
-      const std::vector<int>& argnums) override;
-
-  /** The vector-Jacobian product. */
-  std::vector<array> vjp(
-      const std::vector<array>& primals,
-      const std::vector<array>& cotangents,
-      const std::vector<int>& argnums,
-      const std::vector<array>& outputs) override;
-
-  /**
-   * The primitive must know how to vectorize itself across
-   * the given axes. The output is a pair containing the array
-   * representing the vectorized computation and the axis which
-   * corresponds to the output vectorized dimension.
-   */
-  std::pair<std::vector<array>, std::vector<int>> vmap(
-      const std::vector<array>& inputs,
-      const std::vector<int>& axes) override;
-
-  /** Print the primitive. */
-  void print(std::ostream& os) override {
-    os << "Random walk implementation";
-  }
-
-  /** Equivalence check **/
-  bool is_equivalent(const Primitive& other) const override;
-
-  std::vector<std::vector<int>> output_shapes(const std::vector<array>& inputs) override;
-
- private:
-  int walk_length_;
-
-};
-
-array random_walk(const array& rowptr,
-                  const array& col,
-                  const array& start,
-                  const array& rand,
-                  int walk_length,
-                  StreamOrDevice s = {});
-
-};
mlx_cluster-0.0.3/tests/test_random_walk.py (deleted)

@@ -1,35 +0,0 @@
-import mlx.core as mx
-import numpy as np
-import time
-
-# Torch dataset
-from torch.utils.data import DataLoader
-
-loader = DataLoader(range(2708), batch_size=2000)
-start_indices = next(iter(loader))
-
-from mlx_graphs.datasets import PlanetoidDataset
-from mlx_graphs.utils.sorting import sort_edge_index
-from torch.utils.data import DataLoader
-from mlx_cluster import random_walk
-
-cora_dataset = PlanetoidDataset(name="cora", base_dir="~")
-# For some reason int_64t and int_32t are not compatible
-edge_index = cora_dataset.graphs[0].edge_index.astype(mx.int64)
-# Convert edge index into a CSR matrix
-sorted_edge_index = sort_edge_index(edge_index=edge_index)
-row_mlx = sorted_edge_index[0][0]
-col_mlx = sorted_edge_index[0][1]
-_, counts_mlx = np.unique(np.array(row_mlx, copy=False), return_counts=True)
-cum_sum_mlx = counts_mlx.cumsum()
-row_ptr_mlx = mx.concatenate([mx.array([0]), mx.array(cum_sum_mlx)])
-start_indices = mx.array(start_indices.numpy())
-
-rand_data = mx.random.uniform(shape=[start_indices.shape[0], 5])
-start_time = time.time()
-
-node_sequence = random_walk(
-    row_ptr_mlx, col_mlx, start_indices, rand_data, 5, stream=mx.cpu
-)
-print("Time taken to complete 1000 random walks : ", time.time() - start_time)
-print("MLX random walks are", node_sequence)
mlx_cluster-0.0.3/tests/test_rejection_sampling.py (deleted)

@@ -1,34 +0,0 @@
-import mlx.core as mx
-import numpy as np
-import time
-
-# Torch dataloader
-from torch.utils.data import DataLoader
-
-
-loader = DataLoader(range(2708), batch_size=2000)
-start_indices = next(iter(loader))
-# random_walks = torch.ops.torch_cluster.random_walk(
-#     row_ptr, col, start_indices, 5, 1.0, 3.0
-# )
-
-from mlx_graphs.datasets import PlanetoidDataset
-from mlx_graphs.utils.sorting import sort_edge_index
-from mlx_cluster import rejection_sampling
-
-cora_dataset = PlanetoidDataset(name="cora", base_dir="~")
-edge_index = cora_dataset.graphs[0].edge_index.astype(mx.int64)
-sorted_edge_index = sort_edge_index(edge_index=edge_index)
-row_mlx = sorted_edge_index[0][0]
-col_mlx = sorted_edge_index[0][1]
-_, counts_mlx = np.unique(np.array(row_mlx, copy=False), return_counts=True)
-cum_sum_mlx = counts_mlx.cumsum()
-row_ptr_mlx = mx.concatenate([mx.array([0]), mx.array(cum_sum_mlx)])
-start_indices = mx.array(start_indices.numpy())
-rand_data = mx.random.uniform(shape=[start_indices.shape[0], 5])
-start_time = time.time()
-node_sequence = rejection_sampling(
-    row_ptr_mlx, col_mlx, start_indices, 5, 1.0, 3.0, stream=mx.cpu
-)
-print("Time taken to complete random walks : ", time.time() - start_time)
-print(node_sequence)
Files without changes between 0.0.3 and 0.0.5: CMakeLists.txt, LICENSE, MANIFEST.in, mlx_cluster/__init__.py, mlx_cluster.egg-info/dependency_links.txt, mlx_cluster.egg-info/not-zip-safe, mlx_cluster.egg-info/top_level.txt, random_walks/random_walk.metal, setup.cfg.