mlx-cluster 0.0.4__tar.gz → 0.0.5__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. {mlx_cluster-0.0.4/mlx_cluster.egg-info → mlx_cluster-0.0.5}/PKG-INFO +35 -17
  2. {mlx_cluster-0.0.4 → mlx_cluster-0.0.5}/README.md +22 -9
  3. mlx_cluster-0.0.5/bindings.cpp +81 -0
  4. {mlx_cluster-0.0.4 → mlx_cluster-0.0.5}/mlx_cluster/mlx_cluster.metallib +0 -0
  5. {mlx_cluster-0.0.4 → mlx_cluster-0.0.5/mlx_cluster.egg-info}/PKG-INFO +35 -17
  6. {mlx_cluster-0.0.4 → mlx_cluster-0.0.5}/mlx_cluster.egg-info/SOURCES.txt +0 -3
  7. mlx_cluster-0.0.5/mlx_cluster.egg-info/requires.txt +12 -0
  8. {mlx_cluster-0.0.4 → mlx_cluster-0.0.5}/pyproject.toml +20 -13
  9. {mlx_cluster-0.0.4 → mlx_cluster-0.0.5}/random_walks/BiasedRandomWalk.cpp +24 -30
  10. mlx_cluster-0.0.5/random_walks/BiasedRandomWalk.h +65 -0
  11. {mlx_cluster-0.0.4 → mlx_cluster-0.0.5}/random_walks/RandomWalk.cpp +42 -37
  12. mlx_cluster-0.0.5/random_walks/RandomWalk.h +62 -0
  13. {mlx_cluster-0.0.4 → mlx_cluster-0.0.5}/setup.py +1 -1
  14. mlx_cluster-0.0.5/tests/test_random_walk.py +72 -0
  15. mlx_cluster-0.0.5/tests/test_rejection_sampling.py +62 -0
  16. mlx_cluster-0.0.4/bindings.cpp +0 -65
  17. mlx_cluster-0.0.4/mlx_cluster/_ext.cpython-311-darwin.so +0 -0
  18. mlx_cluster-0.0.4/mlx_cluster/libmlx.dylib +0 -0
  19. mlx_cluster-0.0.4/mlx_cluster/libmlx_cluster.dylib +0 -0
  20. mlx_cluster-0.0.4/mlx_cluster.egg-info/requires.txt +0 -9
  21. mlx_cluster-0.0.4/random_walks/BiasedRandomWalk.h +0 -66
  22. mlx_cluster-0.0.4/random_walks/RandomWalk.h +0 -63
  23. mlx_cluster-0.0.4/tests/test_random_walk.py +0 -38
  24. mlx_cluster-0.0.4/tests/test_rejection_sampling.py +0 -35
  25. {mlx_cluster-0.0.4 → mlx_cluster-0.0.5}/CMakeLists.txt +0 -0
  26. {mlx_cluster-0.0.4 → mlx_cluster-0.0.5}/LICENSE +0 -0
  27. {mlx_cluster-0.0.4 → mlx_cluster-0.0.5}/MANIFEST.in +0 -0
  28. {mlx_cluster-0.0.4 → mlx_cluster-0.0.5}/mlx_cluster/__init__.py +0 -0
  29. {mlx_cluster-0.0.4 → mlx_cluster-0.0.5}/mlx_cluster.egg-info/dependency_links.txt +0 -0
  30. {mlx_cluster-0.0.4 → mlx_cluster-0.0.5}/mlx_cluster.egg-info/not-zip-safe +0 -0
  31. {mlx_cluster-0.0.4 → mlx_cluster-0.0.5}/mlx_cluster.egg-info/top_level.txt +0 -0
  32. {mlx_cluster-0.0.4 → mlx_cluster-0.0.5}/random_walks/random_walk.metal +0 -0
  33. {mlx_cluster-0.0.4 → mlx_cluster-0.0.5}/setup.cfg +0 -0
@@ -1,7 +1,7 @@
- Metadata-Version: 2.1
+ Metadata-Version: 2.4
  Name: mlx_cluster
- Version: 0.0.4
- Summary: C++ and Metal extensions for MLX CTC Loss
+ Version: 0.0.5
+ Summary: C++ extension for generating random graphs
  Author-email: Vinay Pandya <vinayharshadpandya27@gmail.com>
  Project-URL: Homepage, https://github.com/vinayhpandya/mlx_cluster
  Project-URL: Issues, https://github.com/vinayhpandya/mlx_cluster/Issues
@@ -9,17 +9,22 @@ Classifier: Development Status :: 3 - Alpha
  Classifier: Programming Language :: Python :: 3
  Classifier: Programming Language :: C++
  Classifier: License :: OSI Approved :: MIT License
- Classifier: Operating System :: OS Independent
+ Classifier: Operating System :: MacOS
  Requires-Python: >=3.8
  Description-Content-Type: text/markdown
  License-File: LICENSE
  Provides-Extra: dev
  Provides-Extra: test
- Requires-Dist: mlx_graphs==0.0.7; extra == "test"
- Requires-Dist: torch==2.2.0; extra == "test"
- Requires-Dist: mlx>=0.17.0; extra == "test"
+ Requires-Dist: mlx-graphs>=0.0.8; extra == "test"
+ Requires-Dist: torch>=2.2.0; extra == "test"
+ Requires-Dist: mlx>=0.26.0; extra == "test"
  Requires-Dist: pytest==7.4.4; extra == "test"
- Requires-Dist: scipy==1.12.0; extra == "test"
+ Requires-Dist: scipy>=1.13.0; extra == "test"
+ Requires-Dist: requests==2.31.0; extra == "test"
+ Requires-Dist: fsspec[http]==2024.2.0; extra == "test"
+ Requires-Dist: tqdm==4.66.1; extra == "test"
+ Dynamic: license-file
+ Dynamic: requires-python

  # mlx_cluster

@@ -52,24 +57,37 @@ for testing purposes you need to have `mlx-graphs` and `torch_geometric` instal


  ```
- from mlx_graphs.utils.sorting import sort_edge_index
- from mlx_graphs.loaders import Dataloader
- from mlx_graphs_extension import random_walk
+ # Can also use mlx for generating starting indices
+ import torch
+ from torch.utils.data import DataLoader
+
+ loader = DataLoader(range(2708), batch_size=2000)
+ start_indices = next(iter(loader))


+ from mlx_graphs.datasets import PlanetoidDataset
+ from mlx_graphs.utils.sorting import sort_edge_index
+ from torch.utils.data import DataLoader
+ from mlx_cluster import random_walk
+
  cora_dataset = PlanetoidDataset(name="cora", base_dir="~")
- start = mx.arange(0, 1000)
- start_time = time.time()
+ # For some reason int_64t and int_32t are not compatible
  edge_index = cora_dataset.graphs[0].edge_index.astype(mx.int64)
- num_nodes = cora_dataset.graphs[0].num_nodes
+
+ # Convert edge index into a CSR matrix
  sorted_edge_index = sort_edge_index(edge_index=edge_index)
  row_mlx = sorted_edge_index[0][0]
  col_mlx = sorted_edge_index[0][1]
- unique_vals, counts_mlx = np.unique(np.array(row_mlx, copy=False), return_counts=True)
+ _, counts_mlx = np.unique(np.array(row_mlx, copy=False), return_counts=True)
  cum_sum_mlx = counts_mlx.cumsum()
- rand = mx.random.uniform(shape=[start.shape[0], 100])
  row_ptr_mlx = mx.concatenate([mx.array([0]), mx.array(cum_sum_mlx)])
- random_walk(row_ptr_mlx, col_mlx, start, rand, 1000, stream = mx.gpu)
+ start_indices = mx.array(start_indices.numpy())
+
+ rand_data = mx.random.uniform(shape=[start_indices.shape[0], 5])
+
+ node_sequence = random_walk(
+     row_ptr_mlx, col_mlx, start_indices, rand_data, 5, stream=mx.cpu
+ )
  ```

  ## TODO
@@ -29,24 +29,37 @@ for testing purposes you need to have `mlx-graphs` and `torch_geometric` instal


  ```
- from mlx_graphs.utils.sorting import sort_edge_index
- from mlx_graphs.loaders import Dataloader
- from mlx_graphs_extension import random_walk
+ # Can also use mlx for generating starting indices
+ import torch
+ from torch.utils.data import DataLoader
+
+ loader = DataLoader(range(2708), batch_size=2000)
+ start_indices = next(iter(loader))


+ from mlx_graphs.datasets import PlanetoidDataset
+ from mlx_graphs.utils.sorting import sort_edge_index
+ from torch.utils.data import DataLoader
+ from mlx_cluster import random_walk
+
  cora_dataset = PlanetoidDataset(name="cora", base_dir="~")
- start = mx.arange(0, 1000)
- start_time = time.time()
+ # For some reason int_64t and int_32t are not compatible
  edge_index = cora_dataset.graphs[0].edge_index.astype(mx.int64)
- num_nodes = cora_dataset.graphs[0].num_nodes
+
+ # Convert edge index into a CSR matrix
  sorted_edge_index = sort_edge_index(edge_index=edge_index)
  row_mlx = sorted_edge_index[0][0]
  col_mlx = sorted_edge_index[0][1]
- unique_vals, counts_mlx = np.unique(np.array(row_mlx, copy=False), return_counts=True)
+ _, counts_mlx = np.unique(np.array(row_mlx, copy=False), return_counts=True)
  cum_sum_mlx = counts_mlx.cumsum()
- rand = mx.random.uniform(shape=[start.shape[0], 100])
  row_ptr_mlx = mx.concatenate([mx.array([0]), mx.array(cum_sum_mlx)])
- random_walk(row_ptr_mlx, col_mlx, start, rand, 1000, stream = mx.gpu)
+ start_indices = mx.array(start_indices.numpy())
+
+ rand_data = mx.random.uniform(shape=[start_indices.shape[0], 5])
+
+ node_sequence = random_walk(
+     row_ptr_mlx, col_mlx, start_indices, rand_data, 5, stream=mx.cpu
+ )
  ```

  ## TODO
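For reference, the added README snippet above can be assembled into a single runnable sketch; this assumes mlx, mlx-graphs, numpy, and torch are installed, and that `random_walk` returns a `(nodes, edges)` pair as in the new bindings.

```
# Sketch assembled from the README lines added in 0.0.5 (not part of the package itself)
import mlx.core as mx
import numpy as np
from torch.utils.data import DataLoader

from mlx_graphs.datasets import PlanetoidDataset
from mlx_graphs.utils.sorting import sort_edge_index
from mlx_cluster import random_walk

# Starting indices (an mx.arange would work just as well)
loader = DataLoader(range(2708), batch_size=2000)
start_indices = mx.array(next(iter(loader)).numpy())

# Cora edge index as int64, converted to CSR (rowptr, col)
cora = PlanetoidDataset(name="cora", base_dir="~")
edge_index = cora.graphs[0].edge_index.astype(mx.int64)
sorted_edge_index = sort_edge_index(edge_index=edge_index)
row_mlx, col_mlx = sorted_edge_index[0][0], sorted_edge_index[0][1]
_, counts = np.unique(np.array(row_mlx, copy=False), return_counts=True)
row_ptr_mlx = mx.concatenate([mx.array([0]), mx.array(counts.cumsum())])

# One uniform random number per (start node, step)
rand_data = mx.random.uniform(shape=[start_indices.shape[0], 5])

# nodes has shape (num_starts, walk_length + 1), edges has shape (num_starts, walk_length)
nodes, edges = random_walk(
    row_ptr_mlx, col_mlx, start_indices, rand_data, 5, stream=mx.cpu
)
```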
@@ -0,0 +1,81 @@
+ #include <nanobind/nanobind.h>
+ #include <nanobind/stl/variant.h>
+ #include <random_walks/RandomWalk.h>
+ #include <random_walks/BiasedRandomWalk.h>
+
+ namespace nb = nanobind;
+ using namespace nb::literals;
+ using namespace mlx::core;
+
+ NB_MODULE(_ext, m){
+
+ m.def(
+ "random_walk",
+ [](const mx::array& rowptr,
+ const mx::array& col,
+ const mx::array& start,
+ const mx::array& rand,
+ int walk_length,
+ nb::object stream = nb::none()) {
+
+ // call the real C++ implementation
+ auto outs = mlx_random_walk::random_walk(
+ rowptr, col, start, rand, walk_length,
+ stream.is_none() ? mx::StreamOrDevice{}
+ : nb::cast<mx::StreamOrDevice>(stream));
+
+ // vector -> tuple (move avoids a copy)
+ return nb::make_tuple(std::move(outs[0]), std::move(outs[1]));
+ },
+ "rowptr"_a, "col"_a, "start"_a, "rand"_a, "walk_length"_a,
+ nb::kw_only(), "stream"_a = nb::none(),
+ R"(
+ Uniform random walks.
+
+ Returns:
+ (nodes, edges) tuple of arrays
+ )",
+ nb::rv_policy::move);
+
+ m.def(
+ "rejection_sampling",
+ [](const mx::array& rowptr,
+ const mx::array& col,
+ const mx::array& start,
+ int walk_length,
+ float p,
+ float q,
+ nb::object stream = nb::none()
+ ){
+ auto outs = mlx_biased_random_walk::rejection_sampling(
+ rowptr, col, start, walk_length, p, q,
+ stream.is_none() ? mx::StreamOrDevice{}
+ : nb::cast<mx::StreamOrDevice>(stream));
+ return nb::make_tuple(std::move(outs[0]), std::move(outs[1]));
+ },
+ "rowptr"_a,
+ "col"_a,
+ "start"_a,
+ "walk_length"_a,
+ "p"_a,
+ "q"_a,
+ nb::kw_only(), "stream"_a = nb::none(),
+ R"(
+ Sample nodes from the graph by sampling neighbors based
+ on probablity p and q
+
+ Args:
+ rowptr (array): rowptr of graph in csr format.
+ col (array): edges in csr format.
+ start (array): starting node of graph from which
+ biased sampling will be performed.
+ walk_length (int) : walk length of random graph
+ p : Likelihood of immediately revisiting a node in the walk.
+ q : Control parameter to interpolate between
+ breadth-first strategy and depth-first strategy
+
+ Returns:
+ (nodes, edges) tuple of arrays
+ )",
+ nb::rv_policy::move);
+ }
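A minimal call sketch for the two rebound entry points, using a hypothetical toy CSR graph rather than Cora; both now return a `(nodes, edges)` tuple, and `stream` is keyword-only.

```
# Hypothetical toy example (not from the package docs): a 3-node triangle in CSR form
import mlx.core as mx
from mlx_cluster import random_walk, rejection_sampling

rowptr = mx.array([0, 2, 4, 6], dtype=mx.int64)   # node i's neighbors live in col[rowptr[i]:rowptr[i+1]]
col = mx.array([1, 2, 0, 2, 0, 1], dtype=mx.int64)
start = mx.array([0, 1, 2], dtype=mx.int64)

walk_length = 4
rand = mx.random.uniform(shape=[start.shape[0], walk_length])

# Uniform walks: nodes is (num_starts, walk_length + 1), edges is (num_starts, walk_length)
nodes, edges = random_walk(rowptr, col, start, rand, walk_length, stream=mx.cpu)

# Biased (node2vec-style) walks: p is the return parameter, q the in-out parameter
nodes_b, edges_b = rejection_sampling(rowptr, col, start, walk_length, 1.0, 3.0, stream=mx.cpu)
```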
@@ -1,7 +1,7 @@
- Metadata-Version: 2.1
+ Metadata-Version: 2.4
  Name: mlx_cluster
- Version: 0.0.4
- Summary: C++ and Metal extensions for MLX CTC Loss
+ Version: 0.0.5
+ Summary: C++ extension for generating random graphs
  Author-email: Vinay Pandya <vinayharshadpandya27@gmail.com>
  Project-URL: Homepage, https://github.com/vinayhpandya/mlx_cluster
  Project-URL: Issues, https://github.com/vinayhpandya/mlx_cluster/Issues
@@ -9,17 +9,22 @@ Classifier: Development Status :: 3 - Alpha
  Classifier: Programming Language :: Python :: 3
  Classifier: Programming Language :: C++
  Classifier: License :: OSI Approved :: MIT License
- Classifier: Operating System :: OS Independent
+ Classifier: Operating System :: MacOS
  Requires-Python: >=3.8
  Description-Content-Type: text/markdown
  License-File: LICENSE
  Provides-Extra: dev
  Provides-Extra: test
- Requires-Dist: mlx_graphs==0.0.7; extra == "test"
- Requires-Dist: torch==2.2.0; extra == "test"
- Requires-Dist: mlx>=0.17.0; extra == "test"
+ Requires-Dist: mlx-graphs>=0.0.8; extra == "test"
+ Requires-Dist: torch>=2.2.0; extra == "test"
+ Requires-Dist: mlx>=0.26.0; extra == "test"
  Requires-Dist: pytest==7.4.4; extra == "test"
- Requires-Dist: scipy==1.12.0; extra == "test"
+ Requires-Dist: scipy>=1.13.0; extra == "test"
+ Requires-Dist: requests==2.31.0; extra == "test"
+ Requires-Dist: fsspec[http]==2024.2.0; extra == "test"
+ Requires-Dist: tqdm==4.66.1; extra == "test"
+ Dynamic: license-file
+ Dynamic: requires-python

  # mlx_cluster

@@ -52,24 +57,37 @@ for testing purposes you need to have `mlx-graphs` and `torch_geometric` instal


  ```
- from mlx_graphs.utils.sorting import sort_edge_index
- from mlx_graphs.loaders import Dataloader
- from mlx_graphs_extension import random_walk
+ # Can also use mlx for generating starting indices
+ import torch
+ from torch.utils.data import DataLoader
+
+ loader = DataLoader(range(2708), batch_size=2000)
+ start_indices = next(iter(loader))


+ from mlx_graphs.datasets import PlanetoidDataset
+ from mlx_graphs.utils.sorting import sort_edge_index
+ from torch.utils.data import DataLoader
+ from mlx_cluster import random_walk
+
  cora_dataset = PlanetoidDataset(name="cora", base_dir="~")
- start = mx.arange(0, 1000)
- start_time = time.time()
+ # For some reason int_64t and int_32t are not compatible
  edge_index = cora_dataset.graphs[0].edge_index.astype(mx.int64)
- num_nodes = cora_dataset.graphs[0].num_nodes
+
+ # Convert edge index into a CSR matrix
  sorted_edge_index = sort_edge_index(edge_index=edge_index)
  row_mlx = sorted_edge_index[0][0]
  col_mlx = sorted_edge_index[0][1]
- unique_vals, counts_mlx = np.unique(np.array(row_mlx, copy=False), return_counts=True)
+ _, counts_mlx = np.unique(np.array(row_mlx, copy=False), return_counts=True)
  cum_sum_mlx = counts_mlx.cumsum()
- rand = mx.random.uniform(shape=[start.shape[0], 100])
  row_ptr_mlx = mx.concatenate([mx.array([0]), mx.array(cum_sum_mlx)])
- random_walk(row_ptr_mlx, col_mlx, start, rand, 1000, stream = mx.gpu)
+ start_indices = mx.array(start_indices.numpy())
+
+ rand_data = mx.random.uniform(shape=[start_indices.shape[0], 5])
+
+ node_sequence = random_walk(
+     row_ptr_mlx, col_mlx, start_indices, rand_data, 5, stream=mx.cpu
+ )
  ```

  ## TODO
@@ -6,9 +6,6 @@ bindings.cpp
  pyproject.toml
  setup.py
  mlx_cluster/__init__.py
- mlx_cluster/_ext.cpython-311-darwin.so
- mlx_cluster/libmlx.dylib
- mlx_cluster/libmlx_cluster.dylib
  mlx_cluster/mlx_cluster.metallib
  mlx_cluster.egg-info/PKG-INFO
  mlx_cluster.egg-info/SOURCES.txt
@@ -0,0 +1,12 @@
+
+ [dev]
+
+ [test]
+ mlx-graphs>=0.0.8
+ torch>=2.2.0
+ mlx>=0.26.0
+ pytest==7.4.4
+ scipy>=1.13.0
+ requests==2.31.0
+ fsspec[http]==2024.2.0
+ tqdm==4.66.1
@@ -1,41 +1,48 @@
  [project]
  name = "mlx_cluster"
- version = "0.0.4"
+ version = "0.0.5"
  authors = [
  { name = "Vinay Pandya", email = "vinayharshadpandya27@gmail.com" },
  ]
- description = "C++ and Metal extensions for MLX CTC Loss"
+ description = "C++ extension for generating random graphs"
  readme = "README.md"
- requires-python = ">=3.8"
+ requires-python = ">=3.10"
  classifiers = [
  "Development Status :: 3 - Alpha",
  "Programming Language :: Python :: 3",
  "Programming Language :: C++",
  "License :: OSI Approved :: MIT License",
- "Operating System :: OS Independent",
+ "Operating System :: MacOS",
  ]

  [project.optional-dependencies]
  dev = []
  test = [
- "mlx_graphs==0.0.7",
- "torch==2.2.0",
- "mlx>=0.17.0",
+ "mlx-graphs>=0.0.8",
+ "torch>=2.2.0",
+ "mlx>=0.26.0",
  "pytest==7.4.4",
- "scipy==1.12.0",
+ "scipy>=1.13.0",
+ "requests==2.31.0",
+ "fsspec[http]==2024.2.0",
+ "tqdm==4.66.1",
  ]
+
  [project.urls]
  Homepage = "https://github.com/vinayhpandya/mlx_cluster"
  Issues = "https://github.com/vinayhpandya/mlx_cluster/Issues"

+ [tool.pytest.ini_options]
+ addopts = "-ra"
+ markers = [
+ "slow: marks tests that download data, compile kernels, or are otherwise time-consuming (deselect with -m 'not slow')",
+ ]

  [build-system]
  requires = [
+ "nanobind==2.4.0",
  "setuptools>=42",
  "cmake>=3.24",
- "mlx==0.18.0",
- "nanobind@git+https://github.com/wjakob/nanobind.git@2f04eac452a6d9142dedb957701bdb20125561e4",
+ "mlx>=0.26.0",
  ]
-
-
- build-backend = "setuptools.build_meta"
+ build-backend = "setuptools.build_meta"
@@ -15,7 +15,8 @@
  #endif
  #include "random_walks/BiasedRandomWalk.h"

- namespace mlx::core {
+
+ namespace mlx_biased_random_walk {

  bool inline is_neighbor(const int64_t *rowptr, const int64_t *col, int64_t v,
  int64_t w) {
@@ -27,21 +28,20 @@ namespace mlx::core {
  return false;
  }

- void BiasedRandomWalk::eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs) {
+ void BiasedRandomWalk::eval_cpu(const std::vector<mx::array>& inputs, std::vector<mx::array>& outputs) {
  auto& rowptr = inputs[0];
  auto& col = inputs[1];
  auto& start = inputs[2];
  auto& rand = inputs[3];
  int numel = start.size();
-
+ std::cout<<"Inside biased random walk"<<std::endl;
  // Initialize outputs
  assert(outputs.size() == 2);
  // Allocate memory for outputs if not already allocated
- outputs[0].set_data(allocator::malloc_or_wait(numel*(walk_length_+1)*sizeof(int64_t)));
- outputs[1].set_data(allocator::malloc_or_wait(numel*walk_length_*sizeof(int64_t)));
+ outputs[0].set_data(mx::allocator::malloc(numel*(walk_length_+1)*sizeof(int64_t)));
+ outputs[1].set_data(mx::allocator::malloc(numel*walk_length_*sizeof(int64_t)));
  auto& n_out = outputs[0];
  auto& e_out = outputs[1];
-
  auto* n_out_ptr = n_out.data<int64_t>();
  auto* e_out_ptr = e_out.data<int64_t>();
  auto* start_values = start.data<int64_t>();
@@ -53,7 +53,7 @@ namespace mlx::core {
  double prob_0 = 1. / p_ / max_prob;
  double prob_1 = 1. / max_prob;
  double prob_2 = 1. / q_ / max_prob;
-
+
  for (int64_t n = 0; n < numel; n++) {
  int64_t t = start_values[n], v, x, e_cur, row_start, row_end;
  n_out_ptr[n * (walk_length_ + 1)] = t;
@@ -91,7 +91,6 @@ namespace mlx::core {
  break;
  }
  }
-
  n_out_ptr[n * (walk_length_ + 1) + (l + 1)] = x;
  e_out_ptr[n * walk_length_ + l] = e_cur;
  t = v;
@@ -101,9 +100,9 @@ namespace mlx::core {

  };

- std::vector<array> BiasedRandomWalk::jvp(
- const std::vector<array>& primals,
- const std::vector<array>& tangents,
+ std::vector<mx::array> BiasedRandomWalk::jvp(
+ const std::vector<mx::array>& primals,
+ const std::vector<mx::array>& tangents,
  const std::vector<int>& argnums)
  {
  // Random walk is not differentiable, so we return zero tangents
@@ -121,8 +120,8 @@ namespace mlx::core {
  // int numel = start.size();

  // assert(outputs.size() == 2);
- // outputs[0].set_data(allocator::malloc_or_wait(numel * (walk_length_ + 1) * sizeof(int64_t)));
- // outputs[1].set_data(allocator::malloc_or_wait(numel * walk_length_ * sizeof(int64_t)));
+ // outputs[0].set_data(allocator::malloc(numel * (walk_length_ + 1) * sizeof(int64_t)));
+ // outputs[1].set_data(allocator::malloc(numel * walk_length_ * sizeof(int64_t)));
  // std::cout<<"after setting data"<<std::endl;
  // auto& s = stream();
  // auto& d = metal::device(s.device);
@@ -148,47 +147,42 @@ namespace mlx::core {
  // }
  // #endif
  void BiasedRandomWalk::eval_gpu(
- const std::vector<array>& inputs, std::vector<array>& outputs
+ const std::vector<mx::array>& inputs, std::vector<mx::array>& outputs
  )
  {
  throw std::runtime_error("Random walk has no GPU implementation.");
  }
- std::vector<array> BiasedRandomWalk::vjp(
- const std::vector<array>& primals,
- const std::vector<array>& cotangents,
+ std::vector<mx::array> BiasedRandomWalk::vjp(
+ const std::vector<mx::array>& primals,
+ const std::vector<mx::array>& cotangents,
  const std::vector<int>& argnums,
- const std::vector<array>& outputs)
+ const std::vector<mx::array>& outputs)
  {
  // Random walk is not differentiable, so we return zero gradients
  throw std::runtime_error("Random walk has no JVP implementation.");
  }

- std::pair<std::vector<array>, std::vector<int>> BiasedRandomWalk::vmap(
- const std::vector<array>& inputs,
+ std::pair<std::vector<mx::array>, std::vector<int>> BiasedRandomWalk::vmap(
+ const std::vector<mx::array>& inputs,
  const std::vector<int>& axes)
  {
  throw std::runtime_error("vmap not implemented for biasedRandomWalk");
  }

- bool BiasedRandomWalk::is_equivalent(const Primitive& other) const
- {
- throw std::runtime_error("biased Random walk has no GPU implementation.");
- }
-
- std::vector<std::vector<int>> BiasedRandomWalk::output_shapes(const std::vector<array>& inputs)
+ bool BiasedRandomWalk::is_equivalent(const mx::Primitive& other) const
  {
  throw std::runtime_error("biased Random walk has no GPU implementation.");
  }

- array rejection_sampling(const array& rowptr, const array& col, const array& start, int walk_length, const double p,
- const double q, StreamOrDevice s)
+ std::vector<mx::array> rejection_sampling(const mx::array& rowptr, const mx::array& col, const mx::array& start, int walk_length, const double p,
+ const double q, mx::StreamOrDevice s)
  {
  int nodes = start.size();
  auto primitive = std::make_shared<BiasedRandomWalk>(to_stream(s), walk_length, p, q);
- return array::make_arrays({{nodes,walk_length+1},{nodes, walk_length}},
+ return mx::array::make_arrays({{nodes,walk_length+1},{nodes, walk_length}},
  {rowptr.dtype(), rowptr.dtype()},
  primitive,
  {rowptr, col, start}
- )[0];
+ );
  }
  }
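The `prob_0` / `prob_1` / `prob_2` values above follow the usual node2vec rejection-sampling scheme; the following is a rough Python rendering of that acceptance rule, offered as an interpretation rather than a transcript of the C++ loop.

```
import random

def accept_candidate(x, t, t_neighbors, p, q):
    """Accept or reject candidate x as the next step, where t is the previously
    visited node and t_neighbors are t's neighbors (x is a neighbor of the
    current node). Mirrors prob_0 = 1/p/max_prob, prob_1 = 1/max_prob,
    prob_2 = 1/q/max_prob from the source above."""
    max_prob = max(1.0 / p, 1.0, 1.0 / q)
    if x == t:                  # walking straight back to the previous node
        threshold = (1.0 / p) / max_prob
    elif x in t_neighbors:      # staying at distance 1 from the previous node
        threshold = 1.0 / max_prob
    else:                       # moving further away (BFS/DFS trade-off via q)
        threshold = (1.0 / q) / max_prob
    return random.random() < threshold
```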
@@ -0,0 +1,65 @@
+ #pragma once
+
+ #include <mlx/array.h>
+ #include <mlx/ops.h>
+ #include <mlx/primitives.h>
+
+ namespace mx = mlx::core;
+ namespace mlx_biased_random_walk{
+
+ class BiasedRandomWalk : public mx::Primitive {
+ public:
+ BiasedRandomWalk(mx::Stream stream, int walk_length, double p, double q)
+ : mx::Primitive(stream), walk_length_(walk_length), p_(p), q_(q) {}
+ void eval_cpu(const std::vector<mx::array>& inputs, std::vector<mx::array>& outputs)
+ override;
+ void eval_gpu(const std::vector<mx::array>& inputs, std::vector<mx::array>& outputs)
+ override;
+
+ /** The Jacobian-vector product. */
+ std::vector<mx::array> jvp(
+ const std::vector<mx::array>& primals,
+ const std::vector<mx::array>& tangents,
+ const std::vector<int>& argnums) override;
+
+ /** The vector-Jacobian product. */
+ std::vector<mx::array> vjp(
+ const std::vector<mx::array>& primals,
+ const std::vector<mx::array>& cotangents,
+ const std::vector<int>& argnums,
+ const std::vector<mx::array>& outputs) override;
+
+ /**
+ * The primitive must know how to vectorize itself across
+ * the given axes. The output is a pair containing the array
+ * representing the vectorized computation and the axis which
+ * corresponds to the output vectorized dimension.
+ */
+ std::pair<std::vector<mx::array>, std::vector<int>> vmap(
+ const std::vector<mx::array>& inputs,
+ const std::vector<int>& axes) override;
+
+ /** Print the primitive. */
+ virtual const char* name() const override {
+ return "biased random walk implementation";
+ }
+
+ /** Equivalence check **/
+ bool is_equivalent(const mx::Primitive& other) const override;
+
+ private:
+ int walk_length_;
+ double p_;
+ double q_;
+
+ };
+
+ std::vector<mx::array> rejection_sampling(const mx::array& rowptr,
+ const mx::array& col,
+ const mx::array& start,
+ int walk_length,
+ const double p,
+ const double q,
+ mx::StreamOrDevice s = {});
+
+ };
@@ -1,7 +1,7 @@
  #include <cassert>
  #include <iostream>
  #include <sstream>
-
+ #include <dlfcn.h>
  #include "mlx/backend/common/copy.h"
  #include "mlx/backend/common/utils.h"
  #include "mlx/utils.h"
@@ -14,19 +14,31 @@
  #endif
  #include "random_walks/RandomWalk.h"

- namespace mlx::core {
- void RandomWalk::eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs) {
+
+ namespace mlx_random_walk {
+ std::string current_binary_dir() {
+ static std::string binary_dir = []() {
+ Dl_info info;
+ if (!dladdr(reinterpret_cast<void*>(&current_binary_dir), &info)) {
+ throw std::runtime_error("Unable to get current binary dir.");
+ }
+ return std::filesystem::path(info.dli_fname).parent_path().string();
+ }();
+ return binary_dir;
+ }
+
+ void RandomWalk::eval_cpu(const std::vector<mx::array>& inputs, std::vector<mx::array>& outputs) {
  auto& rowptr = inputs[0];
  auto& col = inputs[1];
  auto& start = inputs[2];
  auto& rand = inputs[3];
  int numel = start.size();
-
+ std::cout<<"Its really inside cpu"<<std::endl;
  // Initialize outputs
  assert(outputs.size() == 2);
  // Allocate memory for outputs if not already allocated
- outputs[0].set_data(allocator::malloc_or_wait(numel*(walk_length_+1)*sizeof(int64_t)));
- outputs[1].set_data(allocator::malloc_or_wait(numel*walk_length_*sizeof(int64_t)));
+ outputs[0].set_data(mx::allocator::malloc(numel*(walk_length_+1)*sizeof(int64_t)));
+ outputs[1].set_data(mx::allocator::malloc(numel*walk_length_*sizeof(int64_t)));
  auto& n_out = outputs[0];
  auto& e_out = outputs[1];

@@ -37,7 +49,6 @@ namespace mlx::core {
  auto* col_values = col.data<int64_t>();
  auto* rand_values = rand.data<float>();

- std::cout<<"After evaluating outputs"<<std::endl;
  for (int64_t n = 0; n < numel; n++) {
  int64_t n_cur = start_values[n];
  n_out_ptr[n * (walk_length_ + 1)] = n_cur;
@@ -61,9 +72,9 @@ namespace mlx::core {

  };

- std::vector<array> RandomWalk::jvp(
- const std::vector<array>& primals,
- const std::vector<array>& tangents,
+ std::vector<mx::array> RandomWalk::jvp(
+ const std::vector<mx::array>& primals,
+ const std::vector<mx::array>& tangents,
  const std::vector<int>& argnums)
  {
  // Random walk is not differentiable, so we return zero tangents
@@ -71,8 +82,8 @@ namespace mlx::core {
  }
  #ifdef _METAL_
  void RandomWalk::eval_gpu(
- const std::vector<array>& inputs,
- std::vector<array>& outputs
+ const std::vector<mx::array>& inputs,
+ std::vector<mx::array>& outputs
  ){
  auto& rowptr = inputs[0];
  auto& col = inputs[1];
@@ -81,17 +92,16 @@ void RandomWalk::eval_gpu(
  int numel = start.size();

  assert(outputs.size() == 2);
- outputs[0].set_data(allocator::malloc_or_wait(numel * (walk_length_ + 1) * sizeof(int64_t)));
- outputs[1].set_data(allocator::malloc_or_wait(numel * walk_length_ * sizeof(int64_t)));
- std::cout<<"after setting data"<<std::endl;
+ outputs[0].set_data(mx::allocator::malloc(numel * (walk_length_ + 1) * sizeof(int64_t)));
+ outputs[1].set_data(mx::allocator::malloc(numel * walk_length_ * sizeof(int64_t)));
  auto& s = stream();
- auto& d = metal::device(s.device);
-
- d.register_library("mlx_cluster");
- auto kernel = d.get_kernel("random_walk", "mlx_cluster");
+ auto& d = mx::metal::device(s.device);
+ std::cout<<"Its really inside gpu"<<std::endl;
+ auto lib = d.get_library("mlx_cluster", current_binary_dir());
+ auto kernel = d.get_kernel("random_walk", lib);

  auto& compute_encoder = d.get_command_encoder(s.index);
- compute_encoder->setComputePipelineState(kernel);
+ compute_encoder.set_compute_pipeline_state(kernel);

  compute_encoder.set_input_array(rowptr, 0);
  compute_encoder.set_input_array(col, 1);
@@ -99,51 +109,46 @@ void RandomWalk::eval_gpu(
  compute_encoder.set_input_array(rand, 3);
  compute_encoder.set_output_array(outputs[0], 4);
  compute_encoder.set_output_array(outputs[1], 5);
- compute_encoder->setBytes(&walk_length_, sizeof(int32), 6);
+ compute_encoder.set_bytes(&walk_length_, sizeof(walk_length_), 6);

  MTL::Size grid_size = MTL::Size(numel, 1, 1);
  MTL::Size thread_group_size = MTL::Size(kernel->maxTotalThreadsPerThreadgroup(), 1, 1);

- compute_encoder.dispatchThreads(grid_size, thread_group_size);
+ compute_encoder.dispatch_threads(grid_size, thread_group_size);
  }
  #endif

- std::vector<array> RandomWalk::vjp(
- const std::vector<array>& primals,
- const std::vector<array>& cotangents,
+ std::vector<mx::array> RandomWalk::vjp(
+ const std::vector<mx::array>& primals,
+ const std::vector<mx::array>& cotangents,
  const std::vector<int>& argnums,
- const std::vector<array>& outputs)
+ const std::vector<mx::array>& outputs)
  {
  // Random walk is not differentiable, so we return zero gradients
  throw std::runtime_error("Random walk has no GPU implementation.");
  }

- std::pair<std::vector<array>, std::vector<int>> RandomWalk::vmap(
- const std::vector<array>& inputs,
+ std::pair<std::vector<mx::array>, std::vector<int>> RandomWalk::vmap(
+ const std::vector<mx::array>& inputs,
  const std::vector<int>& axes)
  {
  throw std::runtime_error("vmap not implemented for RandomWalk");
  }

- bool RandomWalk::is_equivalent(const Primitive& other) const
- {
- throw std::runtime_error("Random walk has no GPU implementation.");
- }
-
- std::vector<std::vector<int>> RandomWalk::output_shapes(const std::vector<array>& inputs)
+ bool RandomWalk::is_equivalent(const mx::Primitive& other) const
  {
  throw std::runtime_error("Random walk has no GPU implementation.");
  }

- array random_walk(const array& rowptr, const array& col, const array& start, const array& rand, int walk_length, StreamOrDevice s)
+ std::vector<mx::array> random_walk(const mx::array& rowptr, const mx::array& col, const mx::array& start, const mx::array& rand, int walk_length, mx::StreamOrDevice s)
  {
  std::cout<<"Inside random walk"<<std::endl;
  int nodes = start.size();
  auto primitive = std::make_shared<RandomWalk>(walk_length, to_stream(s));
- return array::make_arrays({{nodes,walk_length+1},{nodes, walk_length}},
+ return mx::array::make_arrays({{nodes,walk_length+1},{nodes, walk_length}},
  {start.dtype(), start.dtype()},
  primitive,
  {rowptr, col, start, rand}
- )[0];
+ );
  }
  }
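For intuition about what the uniform walk computes over the CSR inputs, here is a small NumPy sketch under the same `rowptr`/`col` conventions and output shapes; it is an illustration, not the extension's code path.

```
# Illustrative NumPy sketch of a uniform CSR random walk (assumption, not the kernel)
import numpy as np

def uniform_random_walk(rowptr, col, start, rand):
    """rand has shape (num_starts, walk_length); each entry in [0, 1) picks a neighbor."""
    num_starts, walk_length = rand.shape
    nodes = np.empty((num_starts, walk_length + 1), dtype=np.int64)
    nodes[:, 0] = start
    for n in range(num_starts):
        cur = start[n]
        for l in range(walk_length):
            row_start, row_end = rowptr[cur], rowptr[cur + 1]
            degree = row_end - row_start
            if degree > 0:  # isolated nodes simply stay in place
                cur = col[row_start + int(rand[n, l] * degree)]
            nodes[n, l + 1] = cur
    return nodes
```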
@@ -0,0 +1,62 @@
+ #pragma once
+
+ #include <mlx/array.h>
+ #include <mlx/ops.h>
+ #include <mlx/primitives.h>
+
+ namespace mx = mlx::core;
+ namespace mlx_random_walk{
+
+ class RandomWalk : public mx::Primitive {
+ public:
+ explicit RandomWalk(int walk_length, mx::Stream stream):
+ mx::Primitive(stream), walk_length_(walk_length) {};
+ void eval_cpu(const std::vector<mx::array>& inputs, std::vector<mx::array>& outputs)
+ override;
+ void eval_gpu(const std::vector<mx::array>& inputs, std::vector<mx::array>& outputs)
+ override;
+
+ /** The Jacobian-vector product. */
+ std::vector<mx::array> jvp(
+ const std::vector<mx::array>& primals,
+ const std::vector<mx::array>& tangents,
+ const std::vector<int>& argnums) override;
+
+ /** The vector-Jacobian product. */
+ std::vector<mx::array> vjp(
+ const std::vector<mx::array>& primals,
+ const std::vector<mx::array>& cotangents,
+ const std::vector<int>& argnums,
+ const std::vector<mx::array>& outputs) override;
+
+ /**
+ * The primitive must know how to vectorize itself across
+ * the given axes. The output is a pair containing the array
+ * representing the vectorized computation and the axis which
+ * corresponds to the output vectorized dimension.
+ */
+ std::pair<std::vector<mx::array>, std::vector<int>> vmap(
+ const std::vector<mx::array>& inputs,
+ const std::vector<int>& axes) override;
+
+ /** Print the primitive. */
+ virtual const char* name() const override {
+ return "Random walk implementation";
+ }
+
+ /** Equivalence check **/
+ bool is_equivalent(const mx::Primitive& other) const override;
+
+ private:
+ int walk_length_;
+
+ };
+
+ std::vector<mx::array> random_walk(const mx::array& rowptr,
+ const mx::array& col,
+ const mx::array& start,
+ const mx::array& rand,
+ int walk_length,
+ mx::StreamOrDevice s = {});
+
+ };
@@ -4,7 +4,7 @@ from mlx import extension
  if __name__ == "__main__":
  setup(
  name="mlx_cluster",
- version="0.0.4",
+ version="0.0.5",
  description="Sample C++ and Metal extensions for MLX primitives.",
  ext_modules=[extension.CMakeExtension("mlx_cluster._ext")],
  cmdclass={"build_ext": extension.CMakeBuild},
@@ -0,0 +1,72 @@
+ import mlx.core as mx
+ import numpy as np
+ import time
+
+ # Torch dataset
+ import torch
+ from torch.utils.data import DataLoader
+
+ loader = DataLoader(range(2708), batch_size=2000)
+ start_indices = next(iter(loader))
+
+
+ from mlx_graphs.datasets import PlanetoidDataset
+ from mlx_graphs.utils.sorting import sort_edge_index
+ from torch.utils.data import DataLoader
+ from mlx_cluster import random_walk
+ import pytest
+ import time
+
+ @pytest.mark.slow # give download/compile plenty of time on CI
+ def test_random_walk(tmp_path):
+ """
+ Runs 1 000 random walks of length 10 on the Cora graph and checks:
+ 1. output tensor shape == (num_start_nodes, walk_length + 1)
+ 2. all returned node indices are valid ( < num_nodes )
+ """
+
+ # ---------- Dataset (downloaded to the temp dir) ----------
+ data_dir = tmp_path / "mlx_datasets"
+ cora = PlanetoidDataset(name="cora", base_dir=data_dir)
+
+ edge_index = cora.graphs[0].edge_index.astype(mx.int64)
+
+ # CSR conversion
+ sorted_edge_index = sort_edge_index(edge_index=edge_index)
+ row = sorted_edge_index[0][0]
+ col = sorted_edge_index[0][1]
+ _, counts = np.unique(np.array(row, copy=False), return_counts=True)
+ row_ptr = mx.concatenate([mx.array([0]), mx.array(counts.cumsum())])
+
+ # pick 1 000 random start nodes
+ num_starts = 1_000
+ rng = np.random.default_rng(42)
+ start_idx = mx.array(rng.integers(low=0, high=row.max().item() + 1,
+ size=num_starts, dtype=np.int64))
+
+ # random numbers for the kernel (shape [num_starts, walk_length])
+ walk_len = 10
+ rand_data = mx.random.uniform(shape=[num_starts, walk_len])
+
+ # ---------- Warm-up ----------
+ mx.eval(row_ptr, col, start_idx, rand_data)
+
+ # ---------- Run kernel ----------
+ t0 = time.time()
+ node_seq = random_walk(row_ptr, col, start_idx, rand_data,
+ walk_len, stream=mx.cpu)
+ elapsed = time.time() - t0
+ print(f"Random-walk kernel took {elapsed:.3f} s")
+ print("Node sequence is ", node_seq[0])
+ # ---------- Assertions ----------
+ assert node_seq[0].shape == (num_starts, walk_len + 1)
+
+ # num_nodes = cora.graphs[0].num_nodes
+ # assert (node_seq < num_nodes).all().item(), \
+ # "Random walk produced invalid node indices"
+ t0 = time.time()
+ node_seq_gpu = random_walk(row_ptr, col, start_idx, rand_data,
+ walk_len, stream=mx.gpu)
+ elapsed = time.time() - t0
+ print(f"Random-walk kernel on gpu took {elapsed:.3f} s")
+ print("Node sequence is ", node_seq_gpu[0])
@@ -0,0 +1,62 @@
+ import mlx.core as mx
+ import numpy as np
+ import time
+ import pytest
+
+ # Torch dataset
+ import torch
+ from torch.utils.data import DataLoader
+
+ loader = DataLoader(range(2708), batch_size=2000)
+ start_indices = next(iter(loader))
+ # random_walks = torch.ops.torch_cluster.random_walk(
+ # row_ptr, col, start_indices, 5, 1.0, 3.0
+ # )
+
+ from mlx_graphs.datasets import PlanetoidDataset
+ from mlx_graphs.utils.sorting import sort_edge_index
+ from torch.utils.data import DataLoader
+ from mlx_cluster import rejection_sampling
+
+ @pytest.mark.slow # give download/compile plenty of time on CI
+ def test_random_walk(tmp_path):
+ """
+ Runs 1 000 random walks of length 10 on the Cora graph and checks:
+ 1. output tensor shape == (num_start_nodes, walk_length + 1)
+ 2. all returned node indices are valid ( < num_nodes )
+ """
+
+ # ---------- Dataset (downloaded to the temp dir) ----------
+ data_dir = tmp_path / "mlx_datasets"
+ cora = PlanetoidDataset(name="cora", base_dir=data_dir)
+
+ edge_index = cora.graphs[0].edge_index.astype(mx.int64)
+
+ # CSR conversion
+ sorted_edge_index = sort_edge_index(edge_index=edge_index)
+ row = sorted_edge_index[0][0]
+ col = sorted_edge_index[0][1]
+ _, counts = np.unique(np.array(row, copy=False), return_counts=True)
+ row_ptr = mx.concatenate([mx.array([0]), mx.array(counts.cumsum())])
+
+ # pick 1 000 random start nodes
+ num_starts = 1_000
+ rng = np.random.default_rng(42)
+ start_idx = mx.array(rng.integers(low=0, high=row.max().item() + 1,
+ size=num_starts, dtype=np.int64))
+
+ # random numbers for the kernel (shape [num_starts, walk_length])
+ walk_len = 10
+ rand_data = mx.random.uniform(shape=[num_starts, walk_len])
+
+ # ---------- Warm-up ----------
+ mx.eval(row_ptr, col, start_idx, rand_data)
+
+ # ---------- Run kernel ----------
+ t0 = time.time()
+ node_seq = rejection_sampling(row_ptr, col, start_idx, walk_len, 1.0, 3.0, stream=mx.cpu)
+ elapsed = time.time() - t0
+ print(f"Random-walk kernel took {elapsed:.3f} s")
+ print("Node sequence is ", node_seq)
+ # ---------- Assertions ----------
+ assert node_seq[0].shape == (num_starts, walk_len + 1)
@@ -1,65 +0,0 @@
- #include <nanobind/nanobind.h>
- #include <nanobind/stl/variant.h>
- #include <random_walks/RandomWalk.h>
- #include <random_walks/BiasedRandomWalk.h>
-
- namespace nb = nanobind;
- using namespace nb::literals;
- using namespace mlx::core;
-
- NB_MODULE(_ext, m){
-
- m.def(
- "random_walk",
- &random_walk,
- "rowptr"_a,
- "col"_a,
- "start"_a,
- "rand"_a,
- "walk_length"_a,
- nb::kw_only(),
- "stream"_a = nb::none(),
- R"(
- uniformly sample a graph
-
-
- Args:
- rowptr (array): rowptr of graph in csr format.
- col (array): edges in csr format.
- walk_length (int) : walk length of random graph
-
- Returns:
- array: consisting of nodes visited on random walk
- )");
-
- m.def(
- "rejection_sampling",
- &rejection_sampling,
- "rowptr"_a,
- "col"_a,
- "start"_a,
- "walk_length"_a,
- "p"_a,
- "q"_a,
- nb::kw_only(),
- "stream"_a = nb::none(),
- R"(
- Sample nodes from the graph by sampling neighbors based
- on probablity p and q
-
-
- Args:
- rowptr (array): rowptr of graph in csr format.
- col (array): edges in csr format.
- start (array): starting node of graph from which
- biased sampling will be performed.
- walk_length (int) : walk length of random graph
- p : Likelihood of immediately revisiting a node in the walk.
- q : Control parameter to interpolate between
- breadth-first strategy and depth-first strategy
-
- Returns:
- array: consisting of nodes visited on random walk
- )");
- }
-
Binary file
@@ -1,9 +0,0 @@
-
- [dev]
-
- [test]
- mlx_graphs==0.0.7
- torch==2.2.0
- mlx>=0.17.0
- pytest==7.4.4
- scipy==1.12.0
@@ -1,66 +0,0 @@
- #pragma once
-
- #include <mlx/array.h>
- #include <mlx/ops.h>
- #include <mlx/primitives.h>
-
- namespace mlx::core{
-
- class BiasedRandomWalk : public Primitive {
- public:
- BiasedRandomWalk(Stream stream, int walk_length, double p, double q)
- : Primitive(stream), walk_length_(walk_length), p_(p), q_(q) {}
- void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
- override;
- void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
- override;
-
- /** The Jacobian-vector product. */
- std::vector<array> jvp(
- const std::vector<array>& primals,
- const std::vector<array>& tangents,
- const std::vector<int>& argnums) override;
-
- /** The vector-Jacobian product. */
- std::vector<array> vjp(
- const std::vector<array>& primals,
- const std::vector<array>& cotangents,
- const std::vector<int>& argnums,
- const std::vector<array>& outputs) override;
-
- /**
- * The primitive must know how to vectorize itself across
- * the given axes. The output is a pair containing the array
- * representing the vectorized computation and the axis which
- * corresponds to the output vectorized dimension.
- */
- std::pair<std::vector<array>, std::vector<int>> vmap(
- const std::vector<array>& inputs,
- const std::vector<int>& axes) override;
-
- /** Print the primitive. */
- void print(std::ostream& os) override {
- os << "biased random walk implementation";
- }
-
- /** Equivalence check **/
- bool is_equivalent(const Primitive& other) const override;
-
- std::vector<std::vector<int>> output_shapes(const std::vector<array>& inputs) override;
-
- private:
- int walk_length_;
- double p_;
- double q_;
-
- };
-
- array rejection_sampling(const array& rowptr,
- const array& col,
- const array& start,
- int walk_length,
- const double p,
- const double q,
- StreamOrDevice s = {});
-
- };
@@ -1,63 +0,0 @@
- #pragma once
-
- #include <mlx/array.h>
- #include <mlx/ops.h>
- #include <mlx/primitives.h>
-
- namespace mlx::core{
-
- class RandomWalk : public Primitive {
- public:
- explicit RandomWalk(int walk_length, Stream stream):
- Primitive(stream), walk_length_(walk_length) {};
- void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
- override;
- void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
- override;
-
- /** The Jacobian-vector product. */
- std::vector<array> jvp(
- const std::vector<array>& primals,
- const std::vector<array>& tangents,
- const std::vector<int>& argnums) override;
-
- /** The vector-Jacobian product. */
- std::vector<array> vjp(
- const std::vector<array>& primals,
- const std::vector<array>& cotangents,
- const std::vector<int>& argnums,
- const std::vector<array>& outputs) override;
-
- /**
- * The primitive must know how to vectorize itself across
- * the given axes. The output is a pair containing the array
- * representing the vectorized computation and the axis which
- * corresponds to the output vectorized dimension.
- */
- std::pair<std::vector<array>, std::vector<int>> vmap(
- const std::vector<array>& inputs,
- const std::vector<int>& axes) override;
-
- /** Print the primitive. */
- void print(std::ostream& os) override {
- os << "Random walk implementation";
- }
-
- /** Equivalence check **/
- bool is_equivalent(const Primitive& other) const override;
-
- std::vector<std::vector<int>> output_shapes(const std::vector<array>& inputs) override;
-
- private:
- int walk_length_;
-
- };
-
- array random_walk(const array& rowptr,
- const array& col,
- const array& start,
- const array& rand,
- int walk_length,
- StreamOrDevice s = {});
-
- };
@@ -1,38 +0,0 @@
- import mlx.core as mx
- import numpy as np
- import time
-
- # Torch dataset
- import torch
- from torch.utils.data import DataLoader
-
- loader = DataLoader(range(2708), batch_size=2000)
- start_indices = next(iter(loader))
-
-
- from mlx_graphs.datasets import PlanetoidDataset
- from mlx_graphs.utils.sorting import sort_edge_index
- from torch.utils.data import DataLoader
- from mlx_cluster import random_walk
-
- cora_dataset = PlanetoidDataset(name="cora", base_dir="~")
- # For some reason int_64t and int_32t are not compatible
- edge_index = cora_dataset.graphs[0].edge_index.astype(mx.int64)
-
- # Convert edge index into a CSR matrix
- sorted_edge_index = sort_edge_index(edge_index=edge_index)
- row_mlx = sorted_edge_index[0][0]
- col_mlx = sorted_edge_index[0][1]
- _, counts_mlx = np.unique(np.array(row_mlx, copy=False), return_counts=True)
- cum_sum_mlx = counts_mlx.cumsum()
- row_ptr_mlx = mx.concatenate([mx.array([0]), mx.array(cum_sum_mlx)])
- start_indices = mx.array(start_indices.numpy())
-
- rand_data = mx.random.uniform(shape=[start_indices.shape[0], 5])
- start_time = time.time()
-
- node_sequence = random_walk(
- row_ptr_mlx, col_mlx, start_indices, rand_data, 5, stream=mx.cpu
- )
- print("Time taken to complete 1000 random walks : ", time.time() - start_time)
- print("MLX random walks are", node_sequence)
@@ -1,35 +0,0 @@
- import mlx.core as mx
- import numpy as np
- import time
-
- # Torch dataset
- import torch
- from torch.utils.data import DataLoader
-
- loader = DataLoader(range(2708), batch_size=2000)
- start_indices = next(iter(loader))
- # random_walks = torch.ops.torch_cluster.random_walk(
- # row_ptr, col, start_indices, 5, 1.0, 3.0
- # )
-
- from mlx_graphs.datasets import PlanetoidDataset
- from mlx_graphs.utils.sorting import sort_edge_index
- from torch.utils.data import DataLoader
- from mlx_cluster import rejection_sampling
-
- cora_dataset = PlanetoidDataset(name="cora", base_dir="~")
- edge_index = cora_dataset.graphs[0].edge_index.astype(mx.int64)
- sorted_edge_index = sort_edge_index(edge_index=edge_index)
- row_mlx = sorted_edge_index[0][0]
- col_mlx = sorted_edge_index[0][1]
- _, counts_mlx = np.unique(np.array(row_mlx, copy=False), return_counts=True)
- cum_sum_mlx = counts_mlx.cumsum()
- row_ptr_mlx = mx.concatenate([mx.array([0]), mx.array(cum_sum_mlx)])
- start_indices = mx.array(start_indices.numpy())
- rand_data = mx.random.uniform(shape=[start_indices.shape[0], 5])
- start_time = time.time()
- node_sequence = rejection_sampling(
- row_ptr_mlx, col_mlx, start_indices, 5, 1.0, 3.0, stream=mx.cpu
- )
- print("Time taken to complete 1000 random walks : ", time.time() - start_time)
- print(node_sequence)
File without changes
File without changes
File without changes
File without changes