xoscar 0.9.0__cp312-cp312-macosx_10_13_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (94) hide show
  1. xoscar/__init__.py +61 -0
  2. xoscar/_utils.cpython-312-darwin.so +0 -0
  3. xoscar/_utils.pxd +36 -0
  4. xoscar/_utils.pyx +246 -0
  5. xoscar/_version.py +693 -0
  6. xoscar/aio/__init__.py +16 -0
  7. xoscar/aio/base.py +86 -0
  8. xoscar/aio/file.py +59 -0
  9. xoscar/aio/lru.py +228 -0
  10. xoscar/aio/parallelism.py +39 -0
  11. xoscar/api.py +527 -0
  12. xoscar/backend.py +67 -0
  13. xoscar/backends/__init__.py +14 -0
  14. xoscar/backends/allocate_strategy.py +160 -0
  15. xoscar/backends/communication/__init__.py +30 -0
  16. xoscar/backends/communication/base.py +315 -0
  17. xoscar/backends/communication/core.py +69 -0
  18. xoscar/backends/communication/dummy.py +253 -0
  19. xoscar/backends/communication/errors.py +20 -0
  20. xoscar/backends/communication/socket.py +444 -0
  21. xoscar/backends/communication/ucx.py +538 -0
  22. xoscar/backends/communication/utils.py +97 -0
  23. xoscar/backends/config.py +157 -0
  24. xoscar/backends/context.py +437 -0
  25. xoscar/backends/core.py +352 -0
  26. xoscar/backends/indigen/__init__.py +16 -0
  27. xoscar/backends/indigen/__main__.py +19 -0
  28. xoscar/backends/indigen/backend.py +51 -0
  29. xoscar/backends/indigen/driver.py +26 -0
  30. xoscar/backends/indigen/fate_sharing.py +221 -0
  31. xoscar/backends/indigen/pool.py +515 -0
  32. xoscar/backends/indigen/shared_memory.py +548 -0
  33. xoscar/backends/message.cpython-312-darwin.so +0 -0
  34. xoscar/backends/message.pyi +255 -0
  35. xoscar/backends/message.pyx +646 -0
  36. xoscar/backends/pool.py +1630 -0
  37. xoscar/backends/router.py +285 -0
  38. xoscar/backends/test/__init__.py +16 -0
  39. xoscar/backends/test/backend.py +38 -0
  40. xoscar/backends/test/pool.py +233 -0
  41. xoscar/batch.py +256 -0
  42. xoscar/collective/__init__.py +27 -0
  43. xoscar/collective/backend/__init__.py +13 -0
  44. xoscar/collective/backend/nccl_backend.py +160 -0
  45. xoscar/collective/common.py +102 -0
  46. xoscar/collective/core.py +737 -0
  47. xoscar/collective/process_group.py +687 -0
  48. xoscar/collective/utils.py +41 -0
  49. xoscar/collective/xoscar_pygloo.cpython-312-darwin.so +0 -0
  50. xoscar/collective/xoscar_pygloo.pyi +239 -0
  51. xoscar/constants.py +23 -0
  52. xoscar/context.cpython-312-darwin.so +0 -0
  53. xoscar/context.pxd +21 -0
  54. xoscar/context.pyx +368 -0
  55. xoscar/core.cpython-312-darwin.so +0 -0
  56. xoscar/core.pxd +51 -0
  57. xoscar/core.pyx +664 -0
  58. xoscar/debug.py +188 -0
  59. xoscar/driver.py +42 -0
  60. xoscar/errors.py +63 -0
  61. xoscar/libcpp.pxd +31 -0
  62. xoscar/metrics/__init__.py +21 -0
  63. xoscar/metrics/api.py +288 -0
  64. xoscar/metrics/backends/__init__.py +13 -0
  65. xoscar/metrics/backends/console/__init__.py +13 -0
  66. xoscar/metrics/backends/console/console_metric.py +82 -0
  67. xoscar/metrics/backends/metric.py +149 -0
  68. xoscar/metrics/backends/prometheus/__init__.py +13 -0
  69. xoscar/metrics/backends/prometheus/prometheus_metric.py +70 -0
  70. xoscar/nvutils.py +717 -0
  71. xoscar/profiling.py +260 -0
  72. xoscar/serialization/__init__.py +20 -0
  73. xoscar/serialization/aio.py +141 -0
  74. xoscar/serialization/core.cpython-312-darwin.so +0 -0
  75. xoscar/serialization/core.pxd +28 -0
  76. xoscar/serialization/core.pyi +57 -0
  77. xoscar/serialization/core.pyx +944 -0
  78. xoscar/serialization/cuda.py +111 -0
  79. xoscar/serialization/exception.py +48 -0
  80. xoscar/serialization/mlx.py +67 -0
  81. xoscar/serialization/numpy.py +82 -0
  82. xoscar/serialization/pyfury.py +37 -0
  83. xoscar/serialization/scipy.py +72 -0
  84. xoscar/serialization/torch.py +180 -0
  85. xoscar/utils.py +522 -0
  86. xoscar/virtualenv/__init__.py +34 -0
  87. xoscar/virtualenv/core.py +268 -0
  88. xoscar/virtualenv/platform.py +56 -0
  89. xoscar/virtualenv/utils.py +100 -0
  90. xoscar/virtualenv/uv.py +321 -0
  91. xoscar-0.9.0.dist-info/METADATA +230 -0
  92. xoscar-0.9.0.dist-info/RECORD +94 -0
  93. xoscar-0.9.0.dist-info/WHEEL +6 -0
  94. xoscar-0.9.0.dist-info/top_level.txt +2 -0
xoscar/batch.py ADDED
@@ -0,0 +1,256 @@
1
+ # Copyright 2022-2023 XProbe Inc.
2
+ # derived from copyright 1999-2021 Alibaba Group Holding Ltd.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from __future__ import annotations
17
+
18
+ import asyncio
19
+ import inspect
20
+ import textwrap
21
+ from collections import namedtuple
22
+ from dataclasses import dataclass
23
+ from typing import Any, Callable
24
+
25
+ from .core import NO_LOCK_ATTRIBUTE_HINT
26
+
27
+
28
def build_args_binder(func, remove_self: bool = True) -> Callable | None:
    """
    Build a function that binds call arguments of ``func`` into a namedtuple.

    The generated binder accepts the same parameters as ``func`` (defaults
    and keyword-only defaults are copied over) and returns a namedtuple whose
    fields are named after those parameters.

    Parameters
    ----------
    func : callable
        Callable whose signature is mirrored.
    remove_self : bool
        If True, the first positional parameter is still accepted by the
        binder but is left out of the resulting namedtuple.

    Returns
    -------
    binder : callable or None
        The generated binder, or None when ``inspect.getfullargspec``
        raises ``TypeError`` for ``func``.
    """
    try:
        spec = inspect.getfullargspec(func)
    except TypeError:  # pragma: no cover
        return None

    # sig_list: parameters of the generated function
    # args_list: fields of the namedtuple it returns
    sig_list = list(spec.args)
    args_list = list(spec.args)
    if remove_self:
        args_list = args_list[1:]

    if spec.varargs:
        sig_list.append(f"*{spec.varargs}")
        args_list.append(spec.varargs)
    elif spec.kwonlyargs:
        # a bare "*" is required to introduce keyword-only parameters
        sig_list.append("*")

    sig_list.extend(spec.kwonlyargs)
    args_list.extend(spec.kwonlyargs)

    if spec.varkw:
        sig_list.append(f"**{spec.varkw}")
        args_list.append(spec.varkw)

    # Some inspectable callables (e.g. functools.partial objects) carry no
    # __name__ at all, and lambdas carry a non-identifier one; both fall back
    # to an id-based name.
    func_name = getattr(func, "__name__", None)
    if isinstance(func_name, str) and func_name.isidentifier():
        ret_func_name = f"{func_name}_binder"
        ret_type_name = f"_Args_{func_name}"
    else:
        ret_func_name = f"anon_{id(func)}_binder"
        ret_type_name = f"_ArgsAnon_{id(func)}"

    func_str = textwrap.dedent(
        f"""
    def {ret_func_name}({', '.join(sig_list)}):
        return {ret_type_name}({', '.join(args_list)})
    """
    )

    # The source executed below is assembled only from parameter names
    # reported by inspect and from the names generated above.
    glob_vars = globals().copy()
    glob_vars[ret_type_name] = namedtuple(ret_type_name, args_list)
    loc_vars: dict[str, Any] = dict()
    exec(func_str, glob_vars, loc_vars)
    ext_func = loc_vars[ret_func_name]
    ext_func.__defaults__ = spec.defaults
    ext_func.__kwdefaults__ = spec.kwonlydefaults

    return ext_func
75
+
76
+
77
@dataclass
class _DelayedArgument:
    """Arguments captured for one call that is to be executed later in a batch."""

    # positional arguments of the captured call
    args: tuple
    # keyword arguments of the captured call
    kwargs: dict
81
+
82
+
83
+ class _ExtensibleCallable:
84
+ func: Callable
85
+ batch_func: Callable | None
86
+ is_async: bool
87
+ has_single_func: bool
88
+
89
+ def __call__(self, *args, **kwargs):
90
+ if self.is_async:
91
+ return self._async_call(*args, **kwargs)
92
+ else:
93
+ return self._sync_call(*args, **kwargs)
94
+
95
+ async def _async_call(self, *args, **kwargs):
96
+ try:
97
+ if self.has_single_func:
98
+ return await self.func(*args, **kwargs)
99
+ except NotImplementedError:
100
+ self.has_single_func = False
101
+
102
+ if self.batch_func is not None:
103
+ ret = await self.batch_func([args], [kwargs])
104
+ return None if ret is None else ret[0]
105
+ raise NotImplementedError
106
+
107
+ def _sync_call(self, *args, **kwargs):
108
+ try:
109
+ if self.has_single_func:
110
+ return self.func(*args, **kwargs)
111
+ except NotImplementedError:
112
+ self.has_single_func = False
113
+
114
+ if self.batch_func is not None:
115
+ return self.batch_func([args], [kwargs])[0]
116
+ raise NotImplementedError
117
+
118
+
119
class _ExtensibleWrapper(_ExtensibleCallable):
    """
    Wrapper around an already-bound single-call function and its optional
    batch and bind companions.
    """

    def __init__(
        self,
        func: Callable,
        batch_func: Callable | None = None,
        bind_func: Callable | None = None,
        is_async: bool = False,
    ):
        self.func = func
        self.batch_func = batch_func
        self.bind_func = bind_func
        self.is_async = is_async
        self.has_single_func = True

    @staticmethod
    def delay(*args, **kwargs):
        """Capture the arguments of one call for a later ``batch``."""
        return _DelayedArgument(args, kwargs)

    @staticmethod
    def _gen_args_kwargs_list(delays):
        """Split delayed calls into parallel lists of args and kwargs."""
        collected_args = [item.args for item in delays]
        collected_kwargs = [item.kwargs for item in delays]
        return collected_args, collected_kwargs

    async def _async_batch(self, args_list, kwargs_list):
        """Run a batch of calls on the coroutine path."""
        count = len(args_list)
        if count == 0:
            return []
        if count == 1:
            # a lone call is cheaper through the one-pass method
            only = await self._async_call(*args_list[0], **kwargs_list[0])
            return [only]
        if self.batch_func:
            return await self.batch_func(args_list, kwargs_list)

        # no batch implementation: schedule every call as its own task
        pending = [
            asyncio.create_task(self.func(*call_args, **call_kwargs))
            for call_args, call_kwargs in zip(args_list, kwargs_list)
        ]
        try:
            return await asyncio.gather(*pending)
        except asyncio.CancelledError:
            for task in pending:
                task.cancel()
            return await asyncio.gather(*pending)

    def _sync_batch(self, args_list, kwargs_list):
        """Run a batch of calls on the synchronous path."""
        if len(args_list) == 0:
            return []
        if self.batch_func:
            return self.batch_func(args_list, kwargs_list)

        # no batch implementation: invoke the single-call function per item
        results = []
        for call_args, call_kwargs in zip(args_list, kwargs_list):
            results.append(self.func(*call_args, **call_kwargs))
        return results

    def batch(self, *delays):
        """Execute the calls captured by ``delay`` as one batch."""
        args_list, kwargs_list = self._gen_args_kwargs_list(delays)
        return self.call_with_lists(args_list, kwargs_list)

    def call_with_lists(self, args_list, kwargs_list):
        """Execute a batch given as parallel lists of args and kwargs."""
        runner = self._async_batch if self.is_async else self._sync_batch
        return runner(args_list, kwargs_list)

    def bind(self, *args, **kwargs):
        """
        Bind call arguments through ``bind_func``.

        Raises
        ------
        TypeError
            If no bind function is available.
        """
        binder = self.bind_func
        if binder is None:
            raise TypeError(f"bind function not exist for method {self.func.__name__}")
        return binder(*args, **kwargs)
192
+
193
+
194
class _ExtensibleAccessor(_ExtensibleCallable):
    """
    Descriptor placed on a class by ``extensible``; on instance access it
    produces an ``_ExtensibleWrapper`` bound to that instance.
    """

    func: Callable
    batch_func: Callable | None

    def __init__(self, func: Callable):
        self.func = func
        self.batch_func = None
        self.bind_func = build_args_binder(func, remove_self=True)
        self.is_async = asyncio.iscoroutinefunction(self.func)
        self.has_single_func = True

    def batch(self, func: Callable):
        """Register ``func`` as the batch implementation (usable as a decorator)."""
        self.batch_func = func
        return self

    def __get__(self, instance, owner):
        # access through the class hands back the plain function
        if instance is None:
            return self.func

        bound_func = self.func.__get__(instance, owner)

        bound_batch = None
        if self.batch_func is not None:
            bound_batch = self.batch_func.__get__(instance, owner)

        bound_bind = None
        if self.bind_func is not None:
            bound_bind = self.bind_func.__get__(instance, owner)

        wrapper = _ExtensibleWrapper(
            bound_func,
            batch_func=bound_batch,
            bind_func=bound_bind,
            is_async=self.is_async,
        )

        # carry the no-lock hint over when either implementation has it set
        hinted = any(
            getattr(candidate, NO_LOCK_ATTRIBUTE_HINT, None) is True
            for candidate in (self.func, self.batch_func)
        )
        if hinted:
            setattr(wrapper, NO_LOCK_ATTRIBUTE_HINT, True)
        return wrapper
236
+
237
+
238
def extensible(func: Callable):
    """
    Make ``func`` extensible, i.e. able to receive additional
    functionality, in particular a batch implementation.

    For remote function calls, every call may involve operations such as
    opening and closing a file; batching the calls helps to reduce that
    cost.

    Parameters
    ----------
    func : callable
        Function

    Returns
    -------
    func
    """
    accessor = _ExtensibleAccessor(func)
    return accessor
@@ -0,0 +1,27 @@
1
+ # Copyright 2022-2023 XProbe Inc.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from .core import (
16
+ RankActor,
17
+ allgather,
18
+ allreduce,
19
+ alltoall,
20
+ broadcast,
21
+ gather,
22
+ init_process_group,
23
+ new_group,
24
+ reduce,
25
+ reduce_scatter,
26
+ scatter,
27
+ )
@@ -0,0 +1,13 @@
1
+ # Copyright 2022-2025 XProbe Inc.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
@@ -0,0 +1,160 @@
1
+ # Copyright 2022-2023 XProbe Inc.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # We need to extend cupy's inner classes because an actor is a daemonic process,
16
+ # which is not allowed to have children. However, the original code in cupy
17
+ # creates child processes.
18
+
19
+ import queue
20
+ import socket
21
+ import threading
22
+ from ctypes import sizeof
23
+
24
+ from ...utils import lazy_import
25
+
26
+ cupy = lazy_import("cupy")
27
+
28
+ if cupy is not None:
29
+ import cupyx.distributed
30
+ from cupy.cuda import nccl
31
+ from cupyx.distributed import _klv_utils, _store, _store_actions
32
+
33
class ExceptionAwareThreading(threading.Thread):
    """Thread that re-raises, on ``join``, an exception raised by its target."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._exception = None
        # receives None on success or the exception raised by the target
        self.q = queue.Queue()

    def run(self):
        try:
            super().run()
            self.q.put(None)
        except Exception as e:
            self.q.put(e)

    def join(self, timeout=None):
        """Wait for the thread and re-raise the exception of its target, if any.

        Args:
            timeout (float, optional): forwarded to ``threading.Thread.join``.
                Defaults to `None`, the behaviour of the former signature.

        Raises:
            Exception: the exception captured by ``run``, if one was queued.
        """
        super().join(timeout)
        if not self.q.empty():
            exception = self.q.get()
            if exception is not None:
                raise exception
52
+
53
class TCPStore:
    """In-process TCP key-value store served from a background thread.

    Args:
        world_size (int): stored on the instance as ``_world_size``.
    """

    # This is only used for initialization of nccl so we don't care
    # too much about performance
    def __init__(self, world_size):
        self.storage = {}
        self._thread = None
        self._world_size = world_size
        self._run = 1
        # For implementing a barrier
        self._lock = threading.Lock()
        self._current_barrier = None

    def __del__(self):
        if not _store._exit_mode:
            self.stop()

    @staticmethod
    def _recv_exact(c_socket, size):
        """Read up to ``size`` bytes, looping until complete or the peer closes."""
        chunks = []
        remaining = size
        while remaining > 0:
            chunk = c_socket.recv(remaining)
            if not chunk:
                # connection closed before the full message arrived
                break
            chunks.append(chunk)
            remaining -= len(chunk)
        return b"".join(chunks)

    def _thread_request(self, c_socket):
        """Serve one request on ``c_socket`` and close the socket afterwards."""
        with c_socket:
            # Receive in KLV format. A single recv() may deliver only part of
            # the fixed-size message, so keep reading until it is complete.
            action_bytes = self._recv_exact(c_socket, sizeof(_klv_utils.action_t))
            if len(action_bytes) > 0:
                action_m = _klv_utils.action_t.from_buffer_copy(action_bytes)
                if action_m.length > 256:
                    raise ValueError("Invalid length for message")
                value = bytearray(action_m.value)[: action_m.length]
                r = _store_actions.execute_action(action_m.action, value, self)
                if r is not None:
                    c_socket.sendall(r.klv())

    def _server_loop(self, host, port):
        """Accept connections until ``_run`` is cleared, one thread per request."""
        # This is for minimum info exchange during initialization
        # a single connection allows to implement locking mechanics easily
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            s.bind((host, port))
            s.listen()
            # the timeout lets the loop re-check ``_run`` periodically
            s.settimeout(0.5)
            while self._run == 1:
                try:
                    c_socket, addr = s.accept()
                except socket.timeout:
                    continue

                t = threading.Thread(
                    target=self._thread_request, args=(c_socket,), daemon=True
                )
                t.start()

    def run(self, host=_store._DEFAULT_HOST, port=_store._DEFAULT_PORT):
        """Start serving on ``(host, port)`` in a background thread."""
        # Run the TCP store in a different thread
        t = ExceptionAwareThreading(target=self._server_loop, args=(host, port))
        t.start()
        self._thread = t

    def stop(self):
        """Ask the server loop to exit and wait for its thread to finish."""
        if _store._exit_mode:
            return  # Prevent shutdown errors
        if self._thread is not None:
            with self._lock:
                self._run = 0
            self._thread.join()
116
+
117
class XoscarNCCLBackend(cupyx.distributed.NCCLBackend):
    """NCCL communication backend that uses a caller-provided TCP store.

    Args:
        n_devices (int): Total number of devices that will be used in the
            distributed execution.
        rank (int): Unique id of the GPU that the communicator is associated to
            its value needs to be `0 <= rank < n_devices`.
        tcpstore: store object; its ``run(host, port)`` is called on rank 0
            while initializing with the TCP store.
        host (str, optional): host address for the process rendezvous on
            initialization. Defaults to ``_store._DEFAULT_HOST``.
        port (int, optional): port used for the process rendezvous on
            initialization. Defaults to ``_store._DEFAULT_PORT``.
        use_mpi(bool, optional): switch between MPI and use the included TCP
            server for initialization & synchronization. Defaults to `False`.
    """

    def __init__(
        self,
        n_devices,
        rank,
        tcpstore,
        host=_store._DEFAULT_HOST,
        port=_store._DEFAULT_PORT,
        use_mpi=False,
    ):
        # assigned before delegating to the parent constructor
        # (presumably the parent triggers _init_with_tcp_store -- TODO confirm)
        self._tcpstore = tcpstore
        super().__init__(n_devices, rank, host, port, use_mpi)

    def _init_with_tcp_store(self, n_devices, rank, host, port):
        if rank != 0:
            self._store_proxy.barrier()
            received = self._store_proxy["nccl_id"]
            # undo the +128 shift applied by rank 0
            nccl_id = tuple(int(b) - 128 for b in received)
        else:
            self._tcpstore.run(host, port)
            nccl_id = nccl.get_unique_id()
            # get_unique_id return negative values due to cython issues
            # with bytes && c strings. We shift them by 128 to
            # make them positive and send them as bytes to the proxy store
            shifted_nccl_id = bytes(b + 128 for b in nccl_id)
            self._store_proxy["nccl_id"] = shifted_nccl_id
            self._store_proxy.barrier()
        self._comm = nccl.NcclCommunicator(n_devices, nccl_id, rank)
@@ -0,0 +1,102 @@
1
+ # Copyright 2022-2023 XProbe Inc.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from enum import IntEnum
15
+ from typing import Dict, Type
16
+
17
+ import numpy as np
18
+
19
+ from ..utils import lazy_import
20
+ from . import xoscar_pygloo as xp
21
+
22
# Lookup tables from this module's enums to their xoscar_pygloo counterparts;
# they start out empty and are filled by the registration helpers.
ReduceOpMappingGloo: Dict["CollectiveReduceOp", "xp.ReduceOp"] = dict()
AllReduceAlgorithmMappingGloo: Dict["AllReduceAlgorithm", "xp.AllreduceAlgorithm"] = dict()
24
+
25
+
26
def _register_reduce_op(reduce_op):
    """Record ``xp.ReduceOp(member)`` for every member of ``reduce_op``.

    Entries are stored in ``ReduceOpMappingGloo``; ``reduce_op`` itself is
    returned unchanged, so this can be applied as a class decorator.
    """
    for member in reduce_op:
        gloo_op = xp.ReduceOp(member)
        ReduceOpMappingGloo[member] = gloo_op
    return reduce_op
30
+
31
+
32
def _register_allreduce_algo(algorithms):
    """Record ``xp.AllreduceAlgorithm(member)`` for every member of ``algorithms``.

    Entries are stored in ``AllReduceAlgorithmMappingGloo``; ``algorithms``
    itself is returned unchanged, so this can be applied as a class decorator.
    """
    for member in algorithms:
        gloo_algo = xp.AllreduceAlgorithm(member)
        AllReduceAlgorithmMappingGloo[member] = gloo_algo
    return algorithms
36
+
37
+
38
@_register_reduce_op
class CollectiveReduceOp(IntEnum):
    """Reduce operations for collective calls.

    Each member's integer value is passed to ``xp.ReduceOp`` by the
    registering decorator.
    """

    SUM = 0
    PRODUCT = 1
    MIN = 2
    MAX = 3
    BAND = 4
    BOR = 5
    BXOR = 6
    UNUSED = 7
48
+
49
+
50
@_register_allreduce_algo
class AllReduceAlgorithm(IntEnum):
    """Allreduce algorithm choices.

    Each member's integer value is passed to ``xp.AllreduceAlgorithm`` by the
    registering decorator.
    """

    UNSPECIFIED = 0
    RING = 1
    BCUBE = 2
55
+
56
+
57
# NumPy scalar type -> Gloo data type enum member.
# NOTE(review): the keys are NumPy scalar types rather than np.dtype subclasses,
# which is presumably why each entry carries a "type: ignore" -- confirm.
TypeMappingGloo: Dict[Type[np.dtype], "xp.GlooDataType_t"] = {
    np.int8: xp.GlooDataType_t.glooInt8,  # type: ignore
    np.uint8: xp.GlooDataType_t.glooUint8,  # type: ignore
    np.int32: xp.GlooDataType_t.glooInt32,  # type: ignore
    np.uint32: xp.GlooDataType_t.glooUint32,  # type: ignore
    np.int64: xp.GlooDataType_t.glooInt64,  # type: ignore
    np.uint64: xp.GlooDataType_t.glooUint64,  # type: ignore
    np.float16: xp.GlooDataType_t.glooFloat16,  # type: ignore
    np.float32: xp.GlooDataType_t.glooFloat32,  # type: ignore
    np.float64: xp.GlooDataType_t.glooFloat64,  # type: ignore
}
68
cupy = lazy_import("cupy")
if cupy is not None:
    # The NCCL tables are only defined when lazy_import returned something
    # other than None for cupy.
    from cupy.cuda import nccl

    # NumPy scalar type -> NCCL data type constant.
    TypeMappingNCCL: Dict[Type[np.dtype], int] = {
        np.int8: nccl.NCCL_INT8,  # type: ignore
        np.uint8: nccl.NCCL_UINT8,  # type: ignore
        np.int32: nccl.NCCL_INT32,  # type: ignore
        np.uint32: nccl.NCCL_UINT32,  # type: ignore
        np.int64: nccl.NCCL_INT64,  # type: ignore
        np.uint64: nccl.NCCL_UINT64,  # type: ignore
        np.float16: nccl.NCCL_FLOAT16,  # type: ignore
        np.float32: nccl.NCCL_FLOAT32,  # type: ignore
        np.float64: nccl.NCCL_FLOAT64,  # type: ignore
    }

    # Reduce op -> NCCL reduce constant; only SUM, PRODUCT, MAX and MIN are
    # mapped here.
    ReduceOpMappingNCCL: Dict[CollectiveReduceOp, int] = {
        CollectiveReduceOp.SUM: nccl.NCCL_SUM,
        CollectiveReduceOp.PRODUCT: nccl.NCCL_PROD,
        CollectiveReduceOp.MAX: nccl.NCCL_MAX,
        CollectiveReduceOp.MIN: nccl.NCCL_MIN,
    }
90
+
91
# Reduce op -> short string name; only SUM, PRODUCT, MAX and MIN are mapped.
# NOTE(review): this view has lost indentation, so whether this table sits
# inside the preceding cupy guard cannot be told from here -- confirm against
# the packaged file before applying.
ReduceOpMappingNCCLStr: Dict[CollectiveReduceOp, str] = {
    CollectiveReduceOp.SUM: "sum",
    CollectiveReduceOp.PRODUCT: "prod",
    CollectiveReduceOp.MAX: "max",
    CollectiveReduceOp.MIN: "min",
}
97
# Module-level constants

# Error text for collective functions invoked from a non-participating process.
INVOKE_ERROR_MESSAGE = (
    "Collective-related functions must be called in a process "
    "that is involved in collection communication."
)

# Names of the environment variables used by the collective module.
RANK_ADDRESS_ENV_KEY = "COLLECTIVE_RANK_ADDRESS"
RENDEZVOUS_MASTER_IP_ENV_KEY = "COLLECTIVE_MASTER_IP"
RENDEZVOUS_MASTER_PORT_ENV_KEY = "COLLECTIVE_MASTER_PORT"
COLLECTIVE_DEVICE_ID_ENV_KEY = "COLLECTIVE_DEVICE_ID_FOR_AN_ACTOR"