mlx-cluster 0.0.5__tar.gz → 0.0.6__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mlx_cluster-0.0.5 → mlx_cluster-0.0.6}/CMakeLists.txt +11 -1
- mlx_cluster-0.0.6/PKG-INFO +243 -0
- mlx_cluster-0.0.6/README.md +206 -0
- {mlx_cluster-0.0.5 → mlx_cluster-0.0.6}/bindings.cpp +57 -5
- {mlx_cluster-0.0.5 → mlx_cluster-0.0.6}/mlx_cluster/__init__.py +1 -0
- mlx_cluster-0.0.6/mlx_cluster.egg-info/PKG-INFO +243 -0
- {mlx_cluster-0.0.5 → mlx_cluster-0.0.6}/mlx_cluster.egg-info/SOURCES.txt +3 -0
- mlx_cluster-0.0.6/mlx_cluster.egg-info/requires.txt +22 -0
- {mlx_cluster-0.0.5 → mlx_cluster-0.0.6}/pyproject.toml +12 -1
- {mlx_cluster-0.0.5 → mlx_cluster-0.0.6}/random_walks/BiasedRandomWalk.cpp +3 -5
- mlx_cluster-0.0.6/random_walks/NeighborSample.cpp +127 -0
- mlx_cluster-0.0.6/random_walks/NeighborSample.h +10 -0
- {mlx_cluster-0.0.5 → mlx_cluster-0.0.6}/random_walks/RandomWalk.cpp +4 -5
- {mlx_cluster-0.0.5 → mlx_cluster-0.0.6}/setup.py +1 -1
- mlx_cluster-0.0.6/tests/test_neighbor_sample.py +300 -0
- mlx_cluster-0.0.5/PKG-INFO +0 -101
- mlx_cluster-0.0.5/README.md +0 -73
- mlx_cluster-0.0.5/mlx_cluster.egg-info/PKG-INFO +0 -101
- mlx_cluster-0.0.5/mlx_cluster.egg-info/requires.txt +0 -12
- {mlx_cluster-0.0.5 → mlx_cluster-0.0.6}/LICENSE +0 -0
- {mlx_cluster-0.0.5 → mlx_cluster-0.0.6}/MANIFEST.in +0 -0
- {mlx_cluster-0.0.5 → mlx_cluster-0.0.6}/mlx_cluster/mlx_cluster.metallib +0 -0
- {mlx_cluster-0.0.5 → mlx_cluster-0.0.6}/mlx_cluster.egg-info/dependency_links.txt +0 -0
- {mlx_cluster-0.0.5 → mlx_cluster-0.0.6}/mlx_cluster.egg-info/not-zip-safe +0 -0
- {mlx_cluster-0.0.5 → mlx_cluster-0.0.6}/mlx_cluster.egg-info/top_level.txt +0 -0
- {mlx_cluster-0.0.5 → mlx_cluster-0.0.6}/random_walks/BiasedRandomWalk.h +0 -0
- {mlx_cluster-0.0.5 → mlx_cluster-0.0.6}/random_walks/RandomWalk.h +0 -0
- {mlx_cluster-0.0.5 → mlx_cluster-0.0.6}/random_walks/random_walk.metal +0 -0
- {mlx_cluster-0.0.5 → mlx_cluster-0.0.6}/setup.cfg +0 -0
- {mlx_cluster-0.0.5 → mlx_cluster-0.0.6}/tests/test_random_walk.py +0 -0
- {mlx_cluster-0.0.5 → mlx_cluster-0.0.6}/tests/test_rejection_sampling.py +0 -0
|
@@ -17,6 +17,15 @@ execute_process(
|
|
|
17
17
|
list(APPEND CMAKE_PREFIX_PATH "${NB_DIR}")
|
|
18
18
|
find_package(nanobind CONFIG REQUIRED)
|
|
19
19
|
|
|
20
|
+
include(FetchContent)
|
|
21
|
+
|
|
22
|
+
FetchContent_Declare(
|
|
23
|
+
parallel-hashmap
|
|
24
|
+
GIT_REPOSITORY https://github.com/greg7mdp/parallel-hashmap.git
|
|
25
|
+
GIT_TAG v1.4.1 # Use latest stable version
|
|
26
|
+
)
|
|
27
|
+
|
|
28
|
+
FetchContent_MakeAvailable(parallel-hashmap)
|
|
20
29
|
# ------ Adding extensions to the library -----
|
|
21
30
|
|
|
22
31
|
# Add library
|
|
@@ -26,13 +35,14 @@ target_sources(mlx_cluster
|
|
|
26
35
|
PUBLIC
|
|
27
36
|
${CMAKE_CURRENT_LIST_DIR}/random_walks/RandomWalk.cpp
|
|
28
37
|
${CMAKE_CURRENT_LIST_DIR}/random_walks/BiasedRandomWalk.cpp
|
|
38
|
+
${CMAKE_CURRENT_LIST_DIR}/random_walks/NeighborSample.cpp
|
|
29
39
|
)
|
|
30
40
|
|
|
31
41
|
target_include_directories(mlx_cluster
|
|
32
42
|
PUBLIC
|
|
33
43
|
${CMAKE_CURRENT_LIST_DIR})
|
|
34
44
|
|
|
35
|
-
target_link_libraries(mlx_cluster PUBLIC mlx)
|
|
45
|
+
target_link_libraries(mlx_cluster PUBLIC mlx phmap)
|
|
36
46
|
|
|
37
47
|
|
|
38
48
|
if(MLX_BUILD_METAL)
|
|
@@ -0,0 +1,243 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: mlx_cluster
|
|
3
|
+
Version: 0.0.6
|
|
4
|
+
Summary: C++ extension for generating random graphs
|
|
5
|
+
Author-email: Vinay Pandya <vinayharshadpandya27@gmail.com>
|
|
6
|
+
Project-URL: Homepage, https://github.com/vinayhpandya/mlx_cluster
|
|
7
|
+
Project-URL: Issues, https://github.com/vinayhpandya/mlx_cluster/issues
|
|
8
|
+
Classifier: Development Status :: 3 - Alpha
|
|
9
|
+
Classifier: Programming Language :: Python :: 3
|
|
10
|
+
Classifier: Programming Language :: C++
|
|
11
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
12
|
+
Classifier: Operating System :: MacOS
|
|
13
|
+
Requires-Python: >=3.8
|
|
14
|
+
Description-Content-Type: text/markdown
|
|
15
|
+
License-File: LICENSE
|
|
16
|
+
Provides-Extra: dev
|
|
17
|
+
Provides-Extra: docs
|
|
18
|
+
Requires-Dist: mlx>=0.27.1; extra == "docs"
|
|
19
|
+
Requires-Dist: mlx-graphs>=0.0.8; extra == "docs"
|
|
20
|
+
Requires-Dist: ipython==8.21.0; extra == "docs"
|
|
21
|
+
Requires-Dist: sphinx>=7.2.6; extra == "docs"
|
|
22
|
+
Requires-Dist: sphinx-book-theme==1.1.0; extra == "docs"
|
|
23
|
+
Requires-Dist: sphinx-autodoc-typehints==1.25.2; extra == "docs"
|
|
24
|
+
Requires-Dist: nbsphinx==0.9.3; extra == "docs"
|
|
25
|
+
Requires-Dist: sphinx-gallery==0.15.0; extra == "docs"
|
|
26
|
+
Provides-Extra: test
|
|
27
|
+
Requires-Dist: mlx-graphs>=0.0.8; extra == "test"
|
|
28
|
+
Requires-Dist: torch>=2.2.0; extra == "test"
|
|
29
|
+
Requires-Dist: mlx>=0.26.0; extra == "test"
|
|
30
|
+
Requires-Dist: pytest==7.4.4; extra == "test"
|
|
31
|
+
Requires-Dist: scipy>=1.13.0; extra == "test"
|
|
32
|
+
Requires-Dist: requests==2.31.0; extra == "test"
|
|
33
|
+
Requires-Dist: fsspec[http]==2024.2.0; extra == "test"
|
|
34
|
+
Requires-Dist: tqdm==4.66.1; extra == "test"
|
|
35
|
+
Dynamic: license-file
|
|
36
|
+
Dynamic: requires-python
|
|
37
|
+
|
|
38
|
+
# MLX-Cluster
|
|
39
|
+
|
|
40
|
+
High-performance graph algorithms optimized for Apple's MLX framework, featuring random walks, biased random walks, and neighbor sampling.
|
|
41
|
+
|
|
42
|
+
[](https://badge.fury.io/py/mlx-cluster)
|
|
43
|
+
[](https://opensource.org/licenses/MIT)
|
|
44
|
+
[](https://www.python.org/downloads/)
|
|
45
|
+
|
|
46
|
+
**[Documentation](https://vinayhpandya.github.io/mlx_cluster/)** | **[Quickstart](https://vinayhpandya.github.io/mlx_cluster/)** |
|
|
47
|
+
|
|
48
|
+
## 🚀 Features
|
|
49
|
+
|
|
50
|
+
- **🔥 MLX Optimized**: Built specifically for Apple's MLX framework with GPU acceleration
|
|
51
|
+
- **⚡ High Performance**: Optimized C++ implementations with Metal shaders for Apple Silicon
|
|
52
|
+
- **🎯 Graph Algorithms**:
|
|
53
|
+
- Uniform random walks
|
|
54
|
+
- Biased random walks (Node2Vec style with p/q parameters)
|
|
55
|
+
- Multi-hop neighbor sampling (GraphSAGE style)
|
|
56
|
+
|
|
57
|
+
## 📦 Installation
|
|
58
|
+
|
|
59
|
+
### From PyPI (Recommended)
|
|
60
|
+
|
|
61
|
+
```bash
|
|
62
|
+
pip install mlx-cluster
|
|
63
|
+
```
|
|
64
|
+
|
|
65
|
+
### From Source
|
|
66
|
+
|
|
67
|
+
```bash
|
|
68
|
+
git clone https://github.com/vinayhpandya/mlx_cluster.git
|
|
69
|
+
cd mlx_cluster
|
|
70
|
+
pip install -e .
|
|
71
|
+
```
|
|
72
|
+
|
|
73
|
+
### Development Installation
|
|
74
|
+
|
|
75
|
+
```bash
|
|
76
|
+
git clone https://github.com/vinayhpandya/mlx_cluster.git
|
|
77
|
+
cd mlx_cluster
|
|
78
|
+
pip install -e . --verbose
|
|
79
|
+
```
|
|
80
|
+
|
|
81
|
+
### Dependencies
|
|
82
|
+
|
|
83
|
+
Required:
|
|
84
|
+
- Python 3.8+
|
|
85
|
+
- MLX framework
|
|
86
|
+
- NumPy
|
|
87
|
+
|
|
88
|
+
Optional (for examples and testing):
|
|
89
|
+
- MLX-Graphs
|
|
90
|
+
- PyTorch (for dataset utilities)
|
|
91
|
+
- pytest
|
|
92
|
+
|
|
93
|
+
## 🔧 Quick Start
|
|
94
|
+
|
|
95
|
+
### Random Walks
|
|
96
|
+
|
|
97
|
+
```python
|
|
98
|
+
import mlx.core as mx
|
|
99
|
+
import numpy as np
|
|
100
|
+
from mlx_cluster import random_walk
|
|
101
|
+
from mlx_graphs.datasets import PlanetoidDataset
|
|
102
|
+
from mlx_graphs.utils.sorting import sort_edge_index
|
|
103
|
+
|
|
104
|
+
# Load dataset
|
|
105
|
+
cora = PlanetoidDataset(name="cora")
|
|
106
|
+
edge_index = cora.graphs[0].edge_index.astype(mx.int64)
|
|
107
|
+
|
|
108
|
+
# Convert to CSR format
|
|
109
|
+
sorted_edge_index = sort_edge_index(edge_index=edge_index)
|
|
110
|
+
row = sorted_edge_index[0][0]
|
|
111
|
+
col = sorted_edge_index[0][1]
|
|
112
|
+
_, counts = np.unique(np.array(row, copy=False), return_counts=True)
|
|
113
|
+
row_ptr = mx.concatenate([mx.array([0]), mx.array(counts.cumsum())])
|
|
114
|
+
|
|
115
|
+
# Generate random walks
|
|
116
|
+
num_walks = 1000
|
|
117
|
+
walk_length = 10
|
|
118
|
+
start_nodes = mx.array(np.random.randint(0, cora.graphs[0].num_nodes, num_walks))
|
|
119
|
+
rand_values = mx.random.uniform(shape=[num_walks, walk_length])
|
|
120
|
+
|
|
121
|
+
mx.eval(row_ptr, col, start_nodes, rand_values)
|
|
122
|
+
# Perform walks
|
|
123
|
+
node_sequences, edge_sequences = random_walk(
|
|
124
|
+
row_ptr, col, start_nodes, rand_values, walk_length, stream=mx.gpu
|
|
125
|
+
)
|
|
126
|
+
|
|
127
|
+
print(f"Generated {num_walks} walks of length {walk_length + 1}")
|
|
128
|
+
print(f"Shape: {node_sequences.shape}")
|
|
129
|
+
```
|
|
130
|
+
|
|
131
|
+
### Biased Random Walks (Node2Vec)
|
|
132
|
+
|
|
133
|
+
```python
|
|
134
|
+
from mlx_cluster import rejection_sampling
|
|
135
|
+
|
|
136
|
+
# Biased random walks with p/q parameters
|
|
137
|
+
node_sequences, edge_sequences = rejection_sampling(
|
|
138
|
+
row_ptr, col, start_nodes, walk_length,
|
|
139
|
+
p=1.0, # Return parameter
|
|
140
|
+
q=2.0, # In-out parameter
|
|
141
|
+
stream=mx.gpu
|
|
142
|
+
)
|
|
143
|
+
```
|
|
144
|
+
|
|
145
|
+
### Neighbor Sampling
|
|
146
|
+
|
|
147
|
+
```python
|
|
148
|
+
from mlx_cluster import neighbor_sample
|
|
149
|
+
|
|
150
|
+
# Convert to CSC format (required for neighbor sampling)
|
|
151
|
+
def create_csc_format(edge_index, num_nodes):
|
|
152
|
+
sources, targets = edge_index[0].tolist(), edge_index[1].tolist()
|
|
153
|
+
edges = sorted(zip(sources, targets), key=lambda x: x[1])
|
|
154
|
+
|
|
155
|
+
colptr = np.zeros(num_nodes + 1, dtype=np.int64)
|
|
156
|
+
for _, target in edges:
|
|
157
|
+
colptr[target + 1] += 1
|
|
158
|
+
colptr = np.cumsum(colptr)
|
|
159
|
+
|
|
160
|
+
sorted_sources = [source for source, _ in edges]
|
|
161
|
+
return mx.array(colptr), mx.array(sorted_sources, dtype=mx.int64)
|
|
162
|
+
|
|
163
|
+
colptr, row = create_csc_format(edge_index, cora.graphs[0].num_nodes)
|
|
164
|
+
|
|
165
|
+
# Multi-hop neighbor sampling
|
|
166
|
+
input_nodes = mx.array([0, 1, 2], dtype=mx.int64)
|
|
167
|
+
num_neighbors = [10, 5] # 10 neighbors in first hop, 5 in second
|
|
168
|
+
mx.eval(colptr, row, input_nodes)
|
|
169
|
+
samples, rows, cols, edges = neighbor_sample(
|
|
170
|
+
colptr, row, input_nodes, num_neighbors,
|
|
171
|
+
replace=True, directed=True
|
|
172
|
+
)
|
|
173
|
+
|
|
174
|
+
print(f"Sampled {len(samples)} nodes and {len(edges)} edges")
|
|
175
|
+
```
|
|
176
|
+
|
|
177
|
+
## 📚 Documentation
|
|
178
|
+
|
|
179
|
+
For comprehensive documentation, examples, and API reference, visit:
|
|
180
|
+
[Documentation](https://vinayhpandya.github.io/mlx_cluster/)
|
|
181
|
+
|
|
182
|
+
## 🧪 Testing
|
|
183
|
+
|
|
184
|
+
Run the test suite:
|
|
185
|
+
|
|
186
|
+
```bash
|
|
187
|
+
# Install test dependencies
|
|
188
|
+
pip install pytest mlx-graphs torch
|
|
189
|
+
|
|
190
|
+
# Run tests
|
|
191
|
+
pytest -s -v
|
|
192
|
+
```
|
|
193
|
+
|
|
194
|
+
## ⚡ Performance
|
|
195
|
+
|
|
196
|
+
MLX-Cluster is optimized for Apple Silicon and shows significant speedups:
|
|
197
|
+
|
|
198
|
+
- **Apple M1/M2/M3**: 2-5x faster than CPU-only implementations
|
|
199
|
+
- **GPU Acceleration**: Automatic optimization for Metal Performance Shaders
|
|
200
|
+
- **Memory Efficient**: Optimized sparse graph representations
|
|
201
|
+
- **Batch Processing**: Efficient handling of thousands of concurrent walks
|
|
202
|
+
|
|
203
|
+
## 🤝 Contributing
|
|
204
|
+
|
|
205
|
+
We welcome contributions!
|
|
206
|
+
1. Fork the repository
|
|
207
|
+
2. Create your feature branch (`git checkout -b feature/new-feature`)
|
|
208
|
+
3. Commit your changes (`git commit -m 'Add new algorithm'`)
|
|
209
|
+
4. Push to the branch (`git push origin feature/new-feature`)
|
|
210
|
+
5. Open a Pull Request
|
|
211
|
+
For installation and test instructions, please visit the documentation.
|
|
212
|
+
|
|
213
|
+
## 📄 License
|
|
214
|
+
|
|
215
|
+
This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
|
|
216
|
+
|
|
217
|
+
## 🙏 Acknowledgments
|
|
218
|
+
|
|
219
|
+
- [PyTorch Cluster](https://github.com/rusty1s/pytorch_cluster) for everything
|
|
220
|
+
- [MLX](https://github.com/ml-explore/mlx) for the foundational framework
|
|
221
|
+
- [MLX-Graphs](https://github.com/mlx-graphs/mlx-graphs) for graph utilities and datasets
|
|
222
|
+
|
|
223
|
+
## 📊 Citation
|
|
224
|
+
|
|
225
|
+
If you use MLX-Cluster in your research, please cite:
|
|
226
|
+
|
|
227
|
+
```bibtex
|
|
228
|
+
@software{mlx_cluster,
|
|
229
|
+
author = {Vinay Pandya},
|
|
230
|
+
title = {MLX-Cluster: High-Performance Graph Algorithms for Apple MLX},
|
|
231
|
+
url = {https://github.com/vinayhpandya/mlx_cluster},
|
|
232
|
+
version = {0.0.6},
|
|
233
|
+
year = {2025}
|
|
234
|
+
}
|
|
235
|
+
```
|
|
236
|
+
|
|
237
|
+
## 🔗 Related Projects
|
|
238
|
+
|
|
239
|
+
- [MLX](https://github.com/ml-explore/mlx) - Apple's machine learning framework
|
|
240
|
+
- [MLX-Graphs](https://github.com/mlx-graphs/mlx-graphs) - Graph neural networks for MLX
|
|
241
|
+
- [PyTorch Geometric](https://github.com/pyg-team/pytorch_geometric) - Graph deep learning library
|
|
242
|
+
|
|
243
|
+
---
|
|
@@ -0,0 +1,206 @@
|
|
|
1
|
+
# MLX-Cluster
|
|
2
|
+
|
|
3
|
+
High-performance graph algorithms optimized for Apple's MLX framework, featuring random walks, biased random walks, and neighbor sampling.
|
|
4
|
+
|
|
5
|
+
[](https://badge.fury.io/py/mlx-cluster)
|
|
6
|
+
[](https://opensource.org/licenses/MIT)
|
|
7
|
+
[](https://www.python.org/downloads/)
|
|
8
|
+
|
|
9
|
+
**[Documentation](https://vinayhpandya.github.io/mlx_cluster/)** | **[Quickstart](https://vinayhpandya.github.io/mlx_cluster/)** |
|
|
10
|
+
|
|
11
|
+
## 🚀 Features
|
|
12
|
+
|
|
13
|
+
- **🔥 MLX Optimized**: Built specifically for Apple's MLX framework with GPU acceleration
|
|
14
|
+
- **⚡ High Performance**: Optimized C++ implementations with Metal shaders for Apple Silicon
|
|
15
|
+
- **🎯 Graph Algorithms**:
|
|
16
|
+
- Uniform random walks
|
|
17
|
+
- Biased random walks (Node2Vec style with p/q parameters)
|
|
18
|
+
- Multi-hop neighbor sampling (GraphSAGE style)
|
|
19
|
+
|
|
20
|
+
## 📦 Installation
|
|
21
|
+
|
|
22
|
+
### From PyPI (Recommended)
|
|
23
|
+
|
|
24
|
+
```bash
|
|
25
|
+
pip install mlx-cluster
|
|
26
|
+
```
|
|
27
|
+
|
|
28
|
+
### From Source
|
|
29
|
+
|
|
30
|
+
```bash
|
|
31
|
+
git clone https://github.com/vinayhpandya/mlx_cluster.git
|
|
32
|
+
cd mlx_cluster
|
|
33
|
+
pip install -e .
|
|
34
|
+
```
|
|
35
|
+
|
|
36
|
+
### Development Installation
|
|
37
|
+
|
|
38
|
+
```bash
|
|
39
|
+
git clone https://github.com/vinayhpandya/mlx_cluster.git
|
|
40
|
+
cd mlx_cluster
|
|
41
|
+
pip install -e . --verbose
|
|
42
|
+
```
|
|
43
|
+
|
|
44
|
+
### Dependencies
|
|
45
|
+
|
|
46
|
+
Required:
|
|
47
|
+
- Python 3.8+
|
|
48
|
+
- MLX framework
|
|
49
|
+
- NumPy
|
|
50
|
+
|
|
51
|
+
Optional (for examples and testing):
|
|
52
|
+
- MLX-Graphs
|
|
53
|
+
- PyTorch (for dataset utilities)
|
|
54
|
+
- pytest
|
|
55
|
+
|
|
56
|
+
## 🔧 Quick Start
|
|
57
|
+
|
|
58
|
+
### Random Walks
|
|
59
|
+
|
|
60
|
+
```python
|
|
61
|
+
import mlx.core as mx
|
|
62
|
+
import numpy as np
|
|
63
|
+
from mlx_cluster import random_walk
|
|
64
|
+
from mlx_graphs.datasets import PlanetoidDataset
|
|
65
|
+
from mlx_graphs.utils.sorting import sort_edge_index
|
|
66
|
+
|
|
67
|
+
# Load dataset
|
|
68
|
+
cora = PlanetoidDataset(name="cora")
|
|
69
|
+
edge_index = cora.graphs[0].edge_index.astype(mx.int64)
|
|
70
|
+
|
|
71
|
+
# Convert to CSR format
|
|
72
|
+
sorted_edge_index = sort_edge_index(edge_index=edge_index)
|
|
73
|
+
row = sorted_edge_index[0][0]
|
|
74
|
+
col = sorted_edge_index[0][1]
|
|
75
|
+
_, counts = np.unique(np.array(row, copy=False), return_counts=True)
|
|
76
|
+
row_ptr = mx.concatenate([mx.array([0]), mx.array(counts.cumsum())])
|
|
77
|
+
|
|
78
|
+
# Generate random walks
|
|
79
|
+
num_walks = 1000
|
|
80
|
+
walk_length = 10
|
|
81
|
+
start_nodes = mx.array(np.random.randint(0, cora.graphs[0].num_nodes, num_walks))
|
|
82
|
+
rand_values = mx.random.uniform(shape=[num_walks, walk_length])
|
|
83
|
+
|
|
84
|
+
mx.eval(row_ptr, col, start_nodes, rand_values)
|
|
85
|
+
# Perform walks
|
|
86
|
+
node_sequences, edge_sequences = random_walk(
|
|
87
|
+
row_ptr, col, start_nodes, rand_values, walk_length, stream=mx.gpu
|
|
88
|
+
)
|
|
89
|
+
|
|
90
|
+
print(f"Generated {num_walks} walks of length {walk_length + 1}")
|
|
91
|
+
print(f"Shape: {node_sequences.shape}")
|
|
92
|
+
```
|
|
93
|
+
|
|
94
|
+
### Biased Random Walks (Node2Vec)
|
|
95
|
+
|
|
96
|
+
```python
|
|
97
|
+
from mlx_cluster import rejection_sampling
|
|
98
|
+
|
|
99
|
+
# Biased random walks with p/q parameters
|
|
100
|
+
node_sequences, edge_sequences = rejection_sampling(
|
|
101
|
+
row_ptr, col, start_nodes, walk_length,
|
|
102
|
+
p=1.0, # Return parameter
|
|
103
|
+
q=2.0, # In-out parameter
|
|
104
|
+
stream=mx.gpu
|
|
105
|
+
)
|
|
106
|
+
```
|
|
107
|
+
|
|
108
|
+
### Neighbor Sampling
|
|
109
|
+
|
|
110
|
+
```python
|
|
111
|
+
from mlx_cluster import neighbor_sample
|
|
112
|
+
|
|
113
|
+
# Convert to CSC format (required for neighbor sampling)
|
|
114
|
+
def create_csc_format(edge_index, num_nodes):
|
|
115
|
+
sources, targets = edge_index[0].tolist(), edge_index[1].tolist()
|
|
116
|
+
edges = sorted(zip(sources, targets), key=lambda x: x[1])
|
|
117
|
+
|
|
118
|
+
colptr = np.zeros(num_nodes + 1, dtype=np.int64)
|
|
119
|
+
for _, target in edges:
|
|
120
|
+
colptr[target + 1] += 1
|
|
121
|
+
colptr = np.cumsum(colptr)
|
|
122
|
+
|
|
123
|
+
sorted_sources = [source for source, _ in edges]
|
|
124
|
+
return mx.array(colptr), mx.array(sorted_sources, dtype=mx.int64)
|
|
125
|
+
|
|
126
|
+
colptr, row = create_csc_format(edge_index, cora.graphs[0].num_nodes)
|
|
127
|
+
|
|
128
|
+
# Multi-hop neighbor sampling
|
|
129
|
+
input_nodes = mx.array([0, 1, 2], dtype=mx.int64)
|
|
130
|
+
num_neighbors = [10, 5] # 10 neighbors in first hop, 5 in second
|
|
131
|
+
mx.eval(colptr, row, input_nodes)
|
|
132
|
+
samples, rows, cols, edges = neighbor_sample(
|
|
133
|
+
colptr, row, input_nodes, num_neighbors,
|
|
134
|
+
replace=True, directed=True
|
|
135
|
+
)
|
|
136
|
+
|
|
137
|
+
print(f"Sampled {len(samples)} nodes and {len(edges)} edges")
|
|
138
|
+
```
|
|
139
|
+
|
|
140
|
+
## 📚 Documentation
|
|
141
|
+
|
|
142
|
+
For comprehensive documentation, examples, and API reference, visit:
|
|
143
|
+
[Documentation](https://vinayhpandya.github.io/mlx_cluster/)
|
|
144
|
+
|
|
145
|
+
## 🧪 Testing
|
|
146
|
+
|
|
147
|
+
Run the test suite:
|
|
148
|
+
|
|
149
|
+
```bash
|
|
150
|
+
# Install test dependencies
|
|
151
|
+
pip install pytest mlx-graphs torch
|
|
152
|
+
|
|
153
|
+
# Run tests
|
|
154
|
+
pytest -s -v
|
|
155
|
+
```
|
|
156
|
+
|
|
157
|
+
## ⚡ Performance
|
|
158
|
+
|
|
159
|
+
MLX-Cluster is optimized for Apple Silicon and shows significant speedups:
|
|
160
|
+
|
|
161
|
+
- **Apple M1/M2/M3**: 2-5x faster than CPU-only implementations
|
|
162
|
+
- **GPU Acceleration**: Automatic optimization for Metal Performance Shaders
|
|
163
|
+
- **Memory Efficient**: Optimized sparse graph representations
|
|
164
|
+
- **Batch Processing**: Efficient handling of thousands of concurrent walks
|
|
165
|
+
|
|
166
|
+
## 🤝 Contributing
|
|
167
|
+
|
|
168
|
+
We welcome contributions!
|
|
169
|
+
1. Fork the repository
|
|
170
|
+
2. Create your feature branch (`git checkout -b feature/new-feature`)
|
|
171
|
+
3. Commit your changes (`git commit -m 'Add new algorithm'`)
|
|
172
|
+
4. Push to the branch (`git push origin feature/new-feature`)
|
|
173
|
+
5. Open a Pull Request
|
|
174
|
+
For installation and test instructions, please visit the documentation.
|
|
175
|
+
|
|
176
|
+
## 📄 License
|
|
177
|
+
|
|
178
|
+
This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
|
|
179
|
+
|
|
180
|
+
## 🙏 Acknowledgments
|
|
181
|
+
|
|
182
|
+
- [PyTorch Cluster](https://github.com/rusty1s/pytorch_cluster) for everything
|
|
183
|
+
- [MLX](https://github.com/ml-explore/mlx) for the foundational framework
|
|
184
|
+
- [MLX-Graphs](https://github.com/mlx-graphs/mlx-graphs) for graph utilities and datasets
|
|
185
|
+
|
|
186
|
+
## 📊 Citation
|
|
187
|
+
|
|
188
|
+
If you use MLX-Cluster in your research, please cite:
|
|
189
|
+
|
|
190
|
+
```bibtex
|
|
191
|
+
@software{mlx_cluster,
|
|
192
|
+
author = {Vinay Pandya},
|
|
193
|
+
title = {MLX-Cluster: High-Performance Graph Algorithms for Apple MLX},
|
|
194
|
+
url = {https://github.com/vinayhpandya/mlx_cluster},
|
|
195
|
+
version = {0.0.6},
|
|
196
|
+
year = {2025}
|
|
197
|
+
}
|
|
198
|
+
```
|
|
199
|
+
|
|
200
|
+
## 🔗 Related Projects
|
|
201
|
+
|
|
202
|
+
- [MLX](https://github.com/ml-explore/mlx) - Apple's machine learning framework
|
|
203
|
+
- [MLX-Graphs](https://github.com/mlx-graphs/mlx-graphs) - Graph neural networks for MLX
|
|
204
|
+
- [PyTorch Geometric](https://github.com/pyg-team/pytorch_geometric) - Graph deep learning library
|
|
205
|
+
|
|
206
|
+
---
|
|
@@ -2,7 +2,8 @@
|
|
|
2
2
|
#include <nanobind/stl/variant.h>
|
|
3
3
|
#include <random_walks/RandomWalk.h>
|
|
4
4
|
#include <random_walks/BiasedRandomWalk.h>
|
|
5
|
-
|
|
5
|
+
#include <nanobind/stl/vector.h>
|
|
6
|
+
#include <random_walks/NeighborSample.h>
|
|
6
7
|
namespace nb = nanobind;
|
|
7
8
|
using namespace nb::literals;
|
|
8
9
|
using namespace mlx::core;
|
|
@@ -32,6 +33,14 @@ NB_MODULE(_ext, m){
|
|
|
32
33
|
R"(
|
|
33
34
|
Uniform random walks.
|
|
34
35
|
|
|
36
|
+
Args:
|
|
37
|
+
rowptr (mlx.core.array): rowptr of graph in csr format.
|
|
38
|
+
col (mlx.core.array): edges(col) in csr format.
|
|
39
|
+
start_indices (mlx.core.array): starting nodes of graph from which
|
|
40
|
+
sampling will be performed.
|
|
41
|
+
random_values (mlx.core.array): random values (between 0 and 1)
|
|
42
|
+
walk_length (int) : walk length of random graph
|
|
43
|
+
|
|
35
44
|
Returns:
|
|
36
45
|
(nodes, edges) tuple of arrays
|
|
37
46
|
)",
|
|
@@ -65,9 +74,9 @@ NB_MODULE(_ext, m){
|
|
|
65
74
|
on probability p and q
|
|
66
75
|
|
|
67
76
|
Args:
|
|
68
|
-
rowptr (array): rowptr of graph in csr format.
|
|
69
|
-
col (array): edges in csr format.
|
|
70
|
-
start (array): starting node of graph from which
|
|
77
|
+
rowptr (mlx.core.array): rowptr of graph in csr format.
|
|
78
|
+
col (mlx.core.array): edges in csr format.
|
|
79
|
+
start (mlx.core.array): starting node of graph from which
|
|
71
80
|
biased sampling will be performed.
|
|
72
81
|
walk_length (int) : walk length of random graph
|
|
73
82
|
p : Likelihood of immediately revisiting a node in the walk.
|
|
@@ -78,4 +87,47 @@ NB_MODULE(_ext, m){
|
|
|
78
87
|
(nodes, edges) tuple of arrays
|
|
79
88
|
)",
|
|
80
89
|
nb::rv_policy::move);
|
|
81
|
-
|
|
90
|
+
|
|
91
|
+
m.def(
|
|
92
|
+
"neighbor_sample",
|
|
93
|
+
[](const mx::array& colptr,
|
|
94
|
+
const mx::array& row,
|
|
95
|
+
const mx::array& input_node,
|
|
96
|
+
const std::vector<int64_t>& num_neighbors,
|
|
97
|
+
bool replace = false,
|
|
98
|
+
bool directed = true) {
|
|
99
|
+
|
|
100
|
+
// Call your C++ function
|
|
101
|
+
auto result = neighbor_sample(colptr, row, input_node, num_neighbors, replace, directed);
|
|
102
|
+
|
|
103
|
+
// Convert std::tuple to nanobind tuple with move semantics
|
|
104
|
+
return nb::make_tuple(
|
|
105
|
+
std::move(std::get<0>(result)), // samples
|
|
106
|
+
std::move(std::get<1>(result)), // rows
|
|
107
|
+
std::move(std::get<2>(result)), // cols
|
|
108
|
+
std::move(std::get<3>(result)) // edges
|
|
109
|
+
);
|
|
110
|
+
},
|
|
111
|
+
"colptr"_a,
|
|
112
|
+
"row"_a,
|
|
113
|
+
"input_node"_a,
|
|
114
|
+
"num_neighbors"_a,
|
|
115
|
+
"replace"_a = false,
|
|
116
|
+
"directed"_a = true,
|
|
117
|
+
R"(
|
|
118
|
+
Simple neighbor sampling without primitives.
|
|
119
|
+
|
|
120
|
+
Args:
|
|
121
|
+
colptr: Column pointers (CSC format)
|
|
122
|
+
row: Row indices (CSC format)
|
|
123
|
+
input_node: Input nodes to sample from
|
|
124
|
+
num_neighbors: Number of neighbors per hop
|
|
125
|
+
replace: Sample with replacement
|
|
126
|
+
directed: Directed graph
|
|
127
|
+
|
|
128
|
+
Returns:
|
|
129
|
+
tuple: (samples, rows, cols, edges)
|
|
130
|
+
)",
|
|
131
|
+
nb::rv_policy::move // Add this return value policy
|
|
132
|
+
);
|
|
133
|
+
}
|