rixa 0.0.1__tar.gz → 0.0.2.dev0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35) hide show
  1. rixa-0.0.2.dev0/MANIFEST.in +1 -0
  2. {rixa-0.0.1 → rixa-0.0.2.dev0}/PKG-INFO +3 -3
  3. {rixa-0.0.1 → rixa-0.0.2.dev0}/README.md +1 -1
  4. {rixa-0.0.1 → rixa-0.0.2.dev0}/pyproject.toml +7 -3
  5. rixa-0.0.2.dev0/setup.py +50 -0
  6. rixa-0.0.2.dev0/src/rixa/bindings/core.c +242 -0
  7. rixa-0.0.2.dev0/src/rixa/bindings/pmix.h +681 -0
  8. rixa-0.0.2.dev0/src/rixa/bindings/pmix_abi_support.h +462 -0
  9. rixa-0.0.2.dev0/src/rixa/bindings/pmix_abi_support_bottom.h +142 -0
  10. rixa-0.0.2.dev0/src/rixa/bindings/pmix_fns.h +676 -0
  11. rixa-0.0.2.dev0/src/rixa/bindings/pmix_macros.h +1101 -0
  12. rixa-0.0.2.dev0/src/rixa/bindings/pmix_types.h +1326 -0
  13. rixa-0.0.2.dev0/src/rixa/bindings/rixa_C10.cpp +161 -0
  14. rixa-0.0.2.dev0/src/rixa/bindings/rixa_pmix_store.c +170 -0
  15. rixa-0.0.2.dev0/src/rixa/bindings/rixa_pmix_store.h +63 -0
  16. rixa-0.0.2.dev0/src/rixa/pytorch.py +83 -0
  17. {rixa-0.0.1 → rixa-0.0.2.dev0}/src/rixa.egg-info/PKG-INFO +3 -3
  18. rixa-0.0.2.dev0/src/rixa.egg-info/SOURCES.txt +28 -0
  19. {rixa-0.0.1 → rixa-0.0.2.dev0}/src/rixa.egg-info/requires.txt +1 -1
  20. rixa-0.0.2.dev0/src/rixa.egg-info/top_level.txt +1 -0
  21. {rixa-0.0.1 → rixa-0.0.2.dev0}/tests/test_pytorch_init.py +3 -1
  22. rixa-0.0.2.dev0/tests/test_pytorch_init_gpu.py +57 -0
  23. rixa-0.0.1/setup.py +0 -18
  24. rixa-0.0.1/src/bindings/core.c +0 -322
  25. rixa-0.0.1/src/rixa/pytorch.py +0 -36
  26. rixa-0.0.1/src/rixa.egg-info/SOURCES.txt +0 -17
  27. rixa-0.0.1/src/rixa.egg-info/top_level.txt +0 -2
  28. {rixa-0.0.1 → rixa-0.0.2.dev0}/LICENSE +0 -0
  29. {rixa-0.0.1 → rixa-0.0.2.dev0}/setup.cfg +0 -0
  30. {rixa-0.0.1 → rixa-0.0.2.dev0}/src/rixa/PMIx_core.pyi +0 -0
  31. {rixa-0.0.1 → rixa-0.0.2.dev0}/src/rixa/__init__.py +0 -0
  32. {rixa-0.0.1 → rixa-0.0.2.dev0}/src/rixa/nvshmem.py +0 -0
  33. {rixa-0.0.1 → rixa-0.0.2.dev0}/src/rixa.egg-info/dependency_links.txt +0 -0
  34. {rixa-0.0.1 → rixa-0.0.2.dev0}/tests/test_nvshmem_init.py +0 -0
  35. {rixa-0.0.1 → rixa-0.0.2.dev0}/tests/test_pmix.py +0 -0
@@ -0,0 +1 @@
1
+ recursive-include src/rixa/bindings *.cpp *.h
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: rixa
3
- Version: 0.0.1
3
+ Version: 0.0.2.dev0
4
4
  Summary: PMIx bootstrap method for modern AI/ML applications
5
5
  Author-email: Mateusz Kapusta <mr.kapusta@student.uw.edu.pl>
6
6
  Keywords: PMIx,HPC,distributed-computing,machine-learning,slurm
@@ -15,7 +15,7 @@ Requires-Python: >=3.10
15
15
  Description-Content-Type: text/markdown
16
16
  License-File: LICENSE
17
17
  Provides-Extra: pytorch
18
- Requires-Dist: torch>=2.0; extra == "pytorch"
18
+ Requires-Dist: torch>=2.7.0; extra == "pytorch"
19
19
  Requires-Dist: numpy>=1.26; extra == "pytorch"
20
20
  Provides-Extra: nvshmem
21
21
  Requires-Dist: nvshmem4py; extra == "nvshmem"
@@ -82,7 +82,7 @@ Remember to manually finalize the backend!
82
82
  nvshmem.finalize()
83
83
  ```
84
84
 
85
- ## Installation
85
+ ## Development
86
86
  The library can be compiled from source with standard `setuptools` and requires a PMIx library with version >5.0.
87
87
  Development of the package is managed with `pixi`, which can also be used to bring in all the libraries necessary for testing and development.
88
88
  It can be tested locally using `prrte` or, more recently, with the OpenMPI provided by `pixi`.
@@ -58,7 +58,7 @@ Remember to manually finalize the backend!
58
58
  nvshmem.finalize()
59
59
  ```
60
60
 
61
- ## Installation
61
+ ## Development
62
62
  The library can be compiled from source with standard `setuptools` and requires a PMIx library with version >5.0.
63
63
  Development of the package is managed with `pixi`, which can also be used to bring in all the libraries necessary for testing and development.
64
64
  It can be tested locally using `prrte` or, more recently, with the OpenMPI provided by `pixi`.
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "rixa"
3
- version = "0.0.1"
3
+ version = "0.0.2.dev"
4
4
  description = "PMIx bootstrap method for modern AI/ML applications"
5
5
  readme = "README.md"
6
6
  requires-python = ">=3.10"
@@ -22,10 +22,14 @@ classifiers = [
22
22
  keywords = ["PMIx", "HPC", "distributed-computing", "machine-learning", "slurm"]
23
23
 
24
24
  [project.optional-dependencies]
25
- pytorch = ["torch>=2.0","numpy>=1.26"]
25
+ pytorch = ["torch>=2.7.0","numpy>=1.26"]
26
26
  nvshmem = ["nvshmem4py","numpy>=1.26"]
27
27
 
28
28
  [build-system]
29
- requires = ["setuptools>=61.0", "wheel", "pybind11"]
29
+ requires = [
30
+ "setuptools>=61.0",
31
+ "wheel",
32
+ "pybind11",
33
+ ]
30
34
  build-backend = "setuptools.build_meta"
31
35
 
@@ -0,0 +1,50 @@
1
+ from setuptools import setup, Extension
2
+ import importlib.util
3
+
4
+ core_ext = Extension(
5
+ "rixa.PMIx_core",
6
+ sources=["src/rixa/bindings/core.c", "src/rixa/bindings/rixa_pmix_store.c"],
7
+ include_dirs=["include"],
8
+ libraries=["pmix"],
9
+ language="c",
10
+ extra_compile_args=["-O3"],
11
+ )
12
+
13
+ cmdclass = {}
14
+ ext_modules = [core_ext]
15
+
16
+ if importlib.util.find_spec("torch"):
17
+ from torch.utils.cpp_extension import CppExtension, BuildExtension
18
+ import torch
19
+
20
+ torch_lib_path = torch.utils.cpp_extension.library_paths()
21
+
22
+ torch_ext = CppExtension(
23
+ name="rixa._rixa_torch",
24
+ sources=[
25
+ "src/rixa/bindings/rixa_pmix_store.c",
26
+ "src/rixa/bindings/rixa_C10.cpp",
27
+ ],
28
+ include_dirs=["include", "src/"],
29
+ libraries=["pmix", "torch", "c10", "torch_python"],
30
+ library_dirs=torch_lib_path,
31
+ extra_compile_args={
32
+ "cxx": [
33
+ "-std=c++17",
34
+ ], # only for .cpp files
35
+ "cc": ["-O3"], # only for .c files
36
+ },
37
+ extra_link_args=[
38
+ "-Wl,--no-as-needed",
39
+ "-ltorch_python",
40
+ "-Wl,--as-needed",
41
+ ],
42
+ )
43
+ ext_modules.append(torch_ext)
44
+
45
+ cmdclass["build_ext"] = BuildExtension.with_options(use_ninja=False)
46
+
47
+ setup(
48
+ ext_modules=ext_modules,
49
+ cmdclass=cmdclass,
50
+ )
@@ -0,0 +1,242 @@
1
+ #include <Python.h>
2
+ #include <pmix.h>
3
+ #include <pmix_common.h>
4
+ #include <stdio.h>
5
+ #include <string.h>
6
+ #include <structmember.h>
7
+ #include <unistd.h>
8
+
9
+ #include "pyerrors.h"
10
+ #include "pyport.h"
11
+ #include "rixa_pmix_store.h"
12
+
13
+ typedef struct {
14
+ PyObject_HEAD;
15
+ rixa_store store;
16
+ } PyPMIx;
17
+ static GlobalPMIxState state;
18
+
19
+ Py_ssize_t get_string_from_python(PyObject *val_obj, const char **out) {
20
+ Py_ssize_t return_val;
21
+ if (PyBytes_Check(val_obj)) {
22
+ PyBytes_AsStringAndSize(val_obj, (char **)out, &return_val);
23
+ } else if (PyUnicode_Check(val_obj)) {
24
+ *out = PyUnicode_AsUTF8(val_obj);
25
+ return_val = strlen(*out);
26
+ } else {
27
+ return 0;
28
+ }
29
+ return return_val;
30
+ }
31
+
32
+ static int PMIxObjInit(PyObject *self, PyObject *args) {
33
+ PyPMIx *self_pmix = (PyPMIx *)self;
34
+ if (!PyArg_ParseTuple(args, "i", &self_pmix->store.timeout)) {
35
+ return -1;
36
+ }
37
+ // self_pmix->timeout = 30;
38
+
39
+ if (state.init) {
40
+ PyErr_SetString(PyExc_TypeError, "PMIx already started!");
41
+ return -1;
42
+ }
43
+
44
+ pmix_status_t rc = PMIx_Init(&state.proc, NULL, 0);
45
+ if (rc != PMIX_SUCCESS) {
46
+ PyErr_SetString(PyExc_TypeError, "Failed to init PMIx!");
47
+ return -1;
48
+ }
49
+
50
+ state.init = 1;
51
+ return 0;
52
+ }
53
+
54
+ static PyTypeObject PMIxType = {
55
+ PyVarObject_HEAD_INIT(NULL, 0).tp_name = "PMIx_core.PMIxStore",
56
+ .tp_doc = "Custom PMX storage for pytorch",
57
+ .tp_basicsize = sizeof(PyPMIx),
58
+ .tp_itemsize = 0,
59
+ .tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE,
60
+ .tp_new = PyType_GenericNew,
61
+ .tp_init = (initproc)PMIxObjInit, // Default constructor
62
+ };
63
+ // IMPLEMENTATIONS
64
+
65
+ // 1. GET RANK
66
+ static PyObject *get_rank_python(PyObject *self, PyObject *Py_UNUSED(ignored)) {
67
+ int rank = rixa_get_rank(&state);
68
+ if (rank < 0) {
69
+ PyErr_SetString(PyExc_RuntimeError,
70
+ "Pmix runtime not started, failed to query rank!");
71
+ return NULL;
72
+ }
73
+
74
+ PyObject *result = PyLong_FromLong((long)rank);
75
+ return result;
76
+ }
77
+
78
+ // 2. GET WORLD
79
+ static PyObject *get_world_python(PyObject *self,
80
+ PyObject *Py_UNUSED(ignored)) {
81
+
82
+ int world = rixa_get_world(&state);
83
+ if (world < 0) {
84
+ PyErr_SetString(PyExc_RuntimeError,
85
+ "Pmix runtime not started, failed to query world!");
86
+ return NULL;
87
+ }
88
+
89
+ PyObject *result = PyLong_FromLong((long)world);
90
+ return result;
91
+ }
92
+
93
+ // 3. SET
94
+ static PyObject *set(PyObject *self, PyObject *args) {
95
+
96
+ PyObject *key_obj, *val_obj;
97
+ const char *key, *val;
98
+ if (!PyArg_ParseTuple(args, "OO", &key_obj, &val_obj)) {
99
+ return NULL;
100
+ }
101
+ Py_ssize_t size_key = get_string_from_python(key_obj, &key);
102
+ if (!size_key) {
103
+ return NULL;
104
+ }
105
+
106
+ Py_ssize_t size_val = get_string_from_python(val_obj, &val);
107
+ if (!size_val) {
108
+ return NULL;
109
+ }
110
+
111
+ Rixa_Error status = rixa_set(&state, NULL, key, val, size_val);
112
+ if (status != RIXA_SUCCESS) {
113
+ PyErr_Format(PyExc_RuntimeError, "(set) failed to push key '%s': %s", key,
114
+ PMIx_Error_string(status));
115
+ return NULL;
116
+ }
117
+
118
+ Py_INCREF(Py_None);
119
+ return Py_None;
120
+ }
121
+
122
+ // 4. GET
123
+ static PyObject *get(PyObject *self, PyObject *args) {
124
+ PyPMIx *self_pmix = (PyPMIx *)self;
125
+
126
+ PyObject *key_obj;
127
+ const char *key;
128
+ if (!PyArg_ParseTuple(args, "O", &key_obj)) {
129
+ return NULL;
130
+ }
131
+ Py_ssize_t size_key = get_string_from_python(key_obj, &key);
132
+ if (!size_key) {
133
+ return NULL;
134
+ }
135
+
136
+ Rixa_bytes out;
137
+ Rixa_Error status = rixa_get(&state, &self_pmix->store, key, &out);
138
+
139
+ if (status == RIXA_TIMEOUT) {
140
+ PyErr_Format(PyExc_TimeoutError, "(get) Timeout to get key '%s'!", key);
141
+ return NULL;
142
+ }
143
+ if (status != RIXA_SUCCESS) {
144
+ PyErr_Format(PyExc_TypeError, "(get) Failed to get key '%s'!", key);
145
+ return NULL;
146
+ }
147
+
148
+ PyObject *result;
149
+ result = PyBytes_FromStringAndSize(out.bytes, out.size);
150
+ free(out.bytes);
151
+ return result;
152
+ }
153
+
154
+ // 5. WATI
155
+ static PyObject *wait_for_keys(PyObject *self, PyObject *args) {
156
+ float delta_T = 0.1; // Fraction of second for every retry
157
+ float total_sleep = 0; // total amount of time spend on sleeping
158
+ PyPMIx *self_pmix = (PyPMIx *)self;
159
+ PyObject *keys_list;
160
+ int timeout;
161
+ if (!PyArg_ParseTuple(args, "O|i", &keys_list, &timeout)) {
162
+ return NULL;
163
+ }
164
+ if (timeout < 0) {
165
+ timeout = self_pmix->store.timeout;
166
+ }
167
+ if (!PyList_Check(keys_list)) {
168
+ PyErr_SetString(PyExc_TypeError, "keys must be a list");
169
+ return NULL;
170
+ }
171
+
172
+ Py_ssize_t n = PyList_Size(keys_list);
173
+ char keys[n][PMIX_MAX_KEYLEN];
174
+
175
+ for (Py_ssize_t i = 0; i < n; i++) {
176
+ PyObject *key_obj = PyList_GetItem(keys_list, i);
177
+ const char *key;
178
+ if (PyUnicode_Check(key_obj)) {
179
+ key = PyUnicode_AsUTF8(key_obj);
180
+ } else if (PyBytes_Check(key_obj)) {
181
+ key = PyBytes_AsString(key_obj);
182
+ } else {
183
+ PyErr_SetString(PyExc_TypeError, "key must be str or bytes");
184
+ return NULL;
185
+ }
186
+ strncpy(keys[i], key, PMIX_MAX_KEYLEN);
187
+ }
188
+ Rixa_Error status = rixa_wait(&state, &self_pmix->store, keys, n, timeout);
189
+ if (status == RIXA_TIMEOUT) {
190
+ PyErr_SetString(PyExc_RuntimeError, "Timout reached!");
191
+ return NULL;
192
+ } else if (status != RIXA_SUCCESS) {
193
+ PyErr_SetString(PyExc_RuntimeError, "Error encoutered, failed!");
194
+ return NULL;
195
+ }
196
+
197
+ Py_INCREF(Py_None);
198
+ return Py_None;
199
+ }
200
+
201
+ // -1. CLEAN UP
202
+ void PMIxCleanup(void) {
203
+ if (state.init == 1) {
204
+ PMIx_Finalize(NULL, 0);
205
+ }
206
+ }
207
+
208
+ static PyMethodDef Custom_methods[] = {
209
+ {"get_rank", get_rank_python, METH_NOARGS, "Get the process rank"},
210
+ {"get_world", get_world_python, METH_NOARGS, "Get the world size"},
211
+ {"set", set, METH_VARARGS, "set a key-value pair"},
212
+ {"get", get, METH_VARARGS, "get a value for given key"},
213
+ {"wait", wait_for_keys, METH_VARARGS, "wait for arrays of keys"},
214
+ {NULL}};
215
+
216
+ static struct PyModuleDef coremodule = {
217
+ PyModuleDef_HEAD_INIT, "_core", NULL, -1, NULL, NULL, NULL, NULL, NULL};
218
+
219
+ PyMODINIT_FUNC PyInit_PMIx_core(void) {
220
+ PyObject *m;
221
+
222
+ PMIxType.tp_methods = Custom_methods;
223
+
224
+ if (PyType_Ready(&PMIxType) < 0)
225
+ return NULL;
226
+
227
+ m = PyModule_Create(&coremodule);
228
+ if (m == NULL)
229
+ return NULL;
230
+
231
+ Py_INCREF(&PMIxType);
232
+ if (PyModule_AddObject(m, "PMIxStore", (PyObject *)&PMIxType) < 0) {
233
+ Py_DECREF(&PMIxType);
234
+ Py_DECREF(m);
235
+ return NULL;
236
+ }
237
+
238
+ Py_AtExit(PMIxCleanup);
239
+
240
+ state.init = 0; // set that we can init PMIX
241
+ return m;
242
+ }