xoscar 0.9.0__cp312-cp312-macosx_10_13_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (94)
  1. xoscar/__init__.py +61 -0
  2. xoscar/_utils.cpython-312-darwin.so +0 -0
  3. xoscar/_utils.pxd +36 -0
  4. xoscar/_utils.pyx +246 -0
  5. xoscar/_version.py +693 -0
  6. xoscar/aio/__init__.py +16 -0
  7. xoscar/aio/base.py +86 -0
  8. xoscar/aio/file.py +59 -0
  9. xoscar/aio/lru.py +228 -0
  10. xoscar/aio/parallelism.py +39 -0
  11. xoscar/api.py +527 -0
  12. xoscar/backend.py +67 -0
  13. xoscar/backends/__init__.py +14 -0
  14. xoscar/backends/allocate_strategy.py +160 -0
  15. xoscar/backends/communication/__init__.py +30 -0
  16. xoscar/backends/communication/base.py +315 -0
  17. xoscar/backends/communication/core.py +69 -0
  18. xoscar/backends/communication/dummy.py +253 -0
  19. xoscar/backends/communication/errors.py +20 -0
  20. xoscar/backends/communication/socket.py +444 -0
  21. xoscar/backends/communication/ucx.py +538 -0
  22. xoscar/backends/communication/utils.py +97 -0
  23. xoscar/backends/config.py +157 -0
  24. xoscar/backends/context.py +437 -0
  25. xoscar/backends/core.py +352 -0
  26. xoscar/backends/indigen/__init__.py +16 -0
  27. xoscar/backends/indigen/__main__.py +19 -0
  28. xoscar/backends/indigen/backend.py +51 -0
  29. xoscar/backends/indigen/driver.py +26 -0
  30. xoscar/backends/indigen/fate_sharing.py +221 -0
  31. xoscar/backends/indigen/pool.py +515 -0
  32. xoscar/backends/indigen/shared_memory.py +548 -0
  33. xoscar/backends/message.cpython-312-darwin.so +0 -0
  34. xoscar/backends/message.pyi +255 -0
  35. xoscar/backends/message.pyx +646 -0
  36. xoscar/backends/pool.py +1630 -0
  37. xoscar/backends/router.py +285 -0
  38. xoscar/backends/test/__init__.py +16 -0
  39. xoscar/backends/test/backend.py +38 -0
  40. xoscar/backends/test/pool.py +233 -0
  41. xoscar/batch.py +256 -0
  42. xoscar/collective/__init__.py +27 -0
  43. xoscar/collective/backend/__init__.py +13 -0
  44. xoscar/collective/backend/nccl_backend.py +160 -0
  45. xoscar/collective/common.py +102 -0
  46. xoscar/collective/core.py +737 -0
  47. xoscar/collective/process_group.py +687 -0
  48. xoscar/collective/utils.py +41 -0
  49. xoscar/collective/xoscar_pygloo.cpython-312-darwin.so +0 -0
  50. xoscar/collective/xoscar_pygloo.pyi +239 -0
  51. xoscar/constants.py +23 -0
  52. xoscar/context.cpython-312-darwin.so +0 -0
  53. xoscar/context.pxd +21 -0
  54. xoscar/context.pyx +368 -0
  55. xoscar/core.cpython-312-darwin.so +0 -0
  56. xoscar/core.pxd +51 -0
  57. xoscar/core.pyx +664 -0
  58. xoscar/debug.py +188 -0
  59. xoscar/driver.py +42 -0
  60. xoscar/errors.py +63 -0
  61. xoscar/libcpp.pxd +31 -0
  62. xoscar/metrics/__init__.py +21 -0
  63. xoscar/metrics/api.py +288 -0
  64. xoscar/metrics/backends/__init__.py +13 -0
  65. xoscar/metrics/backends/console/__init__.py +13 -0
  66. xoscar/metrics/backends/console/console_metric.py +82 -0
  67. xoscar/metrics/backends/metric.py +149 -0
  68. xoscar/metrics/backends/prometheus/__init__.py +13 -0
  69. xoscar/metrics/backends/prometheus/prometheus_metric.py +70 -0
  70. xoscar/nvutils.py +717 -0
  71. xoscar/profiling.py +260 -0
  72. xoscar/serialization/__init__.py +20 -0
  73. xoscar/serialization/aio.py +141 -0
  74. xoscar/serialization/core.cpython-312-darwin.so +0 -0
  75. xoscar/serialization/core.pxd +28 -0
  76. xoscar/serialization/core.pyi +57 -0
  77. xoscar/serialization/core.pyx +944 -0
  78. xoscar/serialization/cuda.py +111 -0
  79. xoscar/serialization/exception.py +48 -0
  80. xoscar/serialization/mlx.py +67 -0
  81. xoscar/serialization/numpy.py +82 -0
  82. xoscar/serialization/pyfury.py +37 -0
  83. xoscar/serialization/scipy.py +72 -0
  84. xoscar/serialization/torch.py +180 -0
  85. xoscar/utils.py +522 -0
  86. xoscar/virtualenv/__init__.py +34 -0
  87. xoscar/virtualenv/core.py +268 -0
  88. xoscar/virtualenv/platform.py +56 -0
  89. xoscar/virtualenv/utils.py +100 -0
  90. xoscar/virtualenv/uv.py +321 -0
  91. xoscar-0.9.0.dist-info/METADATA +230 -0
  92. xoscar-0.9.0.dist-info/RECORD +94 -0
  93. xoscar-0.9.0.dist-info/WHEEL +6 -0
  94. xoscar-0.9.0.dist-info/top_level.txt +2 -0
@@ -0,0 +1,687 @@
1
+ # Copyright 2022-2023 XProbe Inc.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import os
15
+ from abc import ABC, abstractmethod
16
+ from typing import Any, List, Optional
17
+
18
+ from ..utils import is_linux, lazy_import
19
+ from . import xoscar_pygloo as xp
20
+ from .common import (
21
+ RENDEZVOUS_MASTER_IP_ENV_KEY,
22
+ RENDEZVOUS_MASTER_PORT_ENV_KEY,
23
+ AllReduceAlgorithm,
24
+ AllReduceAlgorithmMappingGloo,
25
+ CollectiveReduceOp,
26
+ ReduceOpMappingGloo,
27
+ TypeMappingGloo,
28
+ )
29
+ from .utils import convert_data_to_cp_array, convert_data_to_np_array
30
+
31
+ cupy = lazy_import("cupy")
32
+ if cupy is not None:
33
+ from .backend.nccl_backend import TCPStore, XoscarNCCLBackend
34
+ from .common import ReduceOpMappingNCCL, ReduceOpMappingNCCLStr, TypeMappingNCCL
35
+
36
+
37
+ class _World:
38
+ def __init__(self):
39
+ self._store = None
40
+ self._device = None
41
+ self._backend = None
42
+ self._nccl_size = 0
43
+
44
+ @property
45
+ def store(self):
46
+ return self._store
47
+
48
+ @property
49
+ def device(self):
50
+ return self._device
51
+
52
+ @store.setter # type: ignore
53
+ def store(self, store):
54
+ self._store = store
55
+
56
+ @device.setter # type: ignore
57
+ def device(self, device):
58
+ self._device = device
59
+
60
+
61
+ _world = _World()
62
+
63
+
64
class ProcessGroup(ABC):
    """Abstract interface of a collective-communication process group.

    A group knows its own ``rank``, the group ``world_size`` and an optional
    ``group_name``; concrete backends (gloo, NCCL) implement the collective
    primitives declared below.
    """

    class Options:
        """Rendezvous options; both fields default to None.

        NOTE(review): concrete subclasses read these only when an Options
        object is passed, otherwise they fall back to environment variables.
        """

        # Address/port of the rendezvous master used to bootstrap the group.
        master_ip: Optional[str] = None
        master_port: Optional[int] = None

    def __init__(
        self,
        rank: int,
        world_size: int,
        group_name: Optional[str] = None,
        pg_options: Optional[Options] = None,
    ):
        self._rank = rank
        self._world_size = world_size
        self._group_name = group_name
        self._option = pg_options

    @property
    def rank(self):
        # Rank of this process within the group.
        return self._rank

    @property
    def world_size(self):
        # Total number of processes in the group.
        return self._world_size

    @property
    def name(self):
        # Group name, or None for an anonymous group.
        return self._group_name

    @property
    def options(self):
        # The Options the group was constructed with, or None.
        return self._option

    @abstractmethod
    def allreduce(self, *args, **kwargs):
        """All reduce function"""

    @abstractmethod
    def reduce(self, *args, **kwargs):
        """Reduce function"""

    @abstractmethod
    def allgather(self, *args, **kwargs):
        """All gather function"""

    @abstractmethod
    def gather(self, *args, **kwargs):
        """Gather function"""

    @abstractmethod
    def scatter(self, *args, **kwargs):
        """Scatter function"""

    @abstractmethod
    def reduce_scatter(self, *args, **kwargs):
        """Reduce scatter function"""

    @abstractmethod
    def alltoall(self, *args, **kwargs):
        """All to all function"""

    @abstractmethod
    def broadcast(self, *args, **kwargs):
        """Broadcast function"""
128
+
129
+
130
class ProcessGroupGloo(ProcessGroup):
    """Process group backed by the gloo library via ``xoscar_pygloo`` (``xp``).

    All collectives operate on host memory: inputs are converted to numpy
    arrays and raw ``ctypes`` pointers into those arrays are handed to the
    native gloo calls, so the arrays must stay alive for the duration of
    each call (they do — locals hold them until the call returns).
    """

    def __init__(
        self,
        ip: str,
        rank: int,
        world_size: int,
        group_name: Optional[str] = None,
        pg_options: Optional[ProcessGroup.Options] = None,
    ):
        super().__init__(rank, world_size, group_name, pg_options)
        # The TCPStore/device pair is created once per process and cached in
        # the module-level _world; later groups reuse it.
        if _world.store is None:
            # NOTE(review): if pg_options is given, the env vars are ignored
            # even when its fields are None — confirm this is intended.
            master_ip = (
                pg_options.master_ip
                if pg_options is not None
                else os.environ.get(RENDEZVOUS_MASTER_IP_ENV_KEY, None)
            )
            master_port = (
                pg_options.master_port
                if pg_options is not None
                else os.environ.get(RENDEZVOUS_MASTER_PORT_ENV_KEY, None)
            )
            if master_ip is None or master_port is None:
                raise ValueError("Cannot find master ip or port for rendezvous")

            opt = xp.rendezvous.TCPStoreOptions()
            opt.port = int(master_port)
            opt.numWorkers = world_size
            # Rank 0 hosts the store; everyone else connects to it.
            opt.isServer = rank == 0

            store = xp.rendezvous.TCPStore(master_ip, opt)
            # gloo's tcp transport is Linux-only; fall back to libuv elsewhere.
            if not is_linux():
                attr = xp.transport.uv.attr(ip)  # type: ignore
                dev = xp.transport.uv.CreateDevice(attr)  # type: ignore
            else:
                attr = xp.transport.tcp.attr(ip)
                dev = xp.transport.tcp.CreateDevice(attr)  # type: ignore
            _world.store = store  # type: ignore
            _world.device = dev  # type: ignore
        else:
            store = _world.store
            dev = _world.device

        # Namespace the shared store per group so key spaces don't collide.
        prefix_store = xp.rendezvous.PrefixStore(group_name or str(world_size), store)  # type: ignore
        context = xp.rendezvous.Context(rank, world_size)
        context.connectFullMesh(prefix_store, dev)
        self._context = context

    def reduce(
        self,
        send_data: Any,
        recv_data: Any,
        op: CollectiveReduceOp = CollectiveReduceOp.SUM,
        root: Optional[int] = 0,
        tag: Optional[int] = 0,
    ):
        """Reduce ``send_data`` across the group into ``recv_data`` on ``root``."""
        send_buf = convert_data_to_np_array(send_data)
        recv_buf = convert_data_to_np_array(recv_data)
        size = send_buf.size  # element count, not bytes
        dtype = send_buf.dtype
        sendptr = send_buf.ctypes.data
        recvptr = recv_buf.ctypes.data
        gloo_type = TypeMappingGloo[dtype.type]
        xp.reduce(
            self._context,
            sendptr,
            recvptr,
            size,
            gloo_type,
            ReduceOpMappingGloo[op],
            root,
            tag,
        )

    def allreduce(
        self,
        send_data: Any,
        recv_data: Any,
        op: CollectiveReduceOp = CollectiveReduceOp.SUM,
        algorithm: AllReduceAlgorithm = AllReduceAlgorithm.RING,
        tag: Optional[int] = 0,
    ):
        """Reduce ``send_data`` across all ranks; every rank gets the result."""
        send_buf = convert_data_to_np_array(send_data)
        recv_buf = convert_data_to_np_array(recv_data)
        size = send_buf.size
        dtype = send_buf.dtype
        sendptr = send_buf.ctypes.data
        recvptr = recv_buf.ctypes.data
        gloo_type = TypeMappingGloo[dtype.type]
        xp.allreduce(
            self._context,
            sendptr,
            recvptr,
            size,
            gloo_type,
            ReduceOpMappingGloo[op],
            AllReduceAlgorithmMappingGloo[algorithm],
            tag,  # type: ignore
        )

    def gather(
        self,
        send_data: Any,
        recv_data: Any,
        root: Optional[int] = 0,
        tag: Optional[int] = 0,
    ):
        """Gather ``send_data`` from all ranks into ``recv_data`` on ``root``."""
        send_buf = convert_data_to_np_array(send_data)
        recv_buf = convert_data_to_np_array(recv_data)
        size = send_buf.size
        dtype = send_buf.dtype
        sendptr = send_buf.ctypes.data
        recvptr = recv_buf.ctypes.data
        gloo_type = TypeMappingGloo[dtype.type]
        xp.gather(self._context, sendptr, recvptr, size, gloo_type, root, tag)

    def allgather(self, send_data: Any, recv_data: Any, tag: Optional[int] = 0):
        """Gather ``send_data`` from all ranks into ``recv_data`` on every rank."""
        send_buf = convert_data_to_np_array(send_data)
        recv_buf = convert_data_to_np_array(recv_data)
        size = send_buf.size
        dtype = send_buf.dtype
        sendptr = send_buf.ctypes.data
        recvptr = recv_buf.ctypes.data
        gloo_type = TypeMappingGloo[dtype.type]
        xp.allgather(self._context, sendptr, recvptr, size, gloo_type, tag)

    def scatter(
        self,
        send_data: List[Any],
        recv_data: Any,
        root: Optional[int] = 0,
        tag: Optional[int] = 0,
    ):
        """Scatter one chunk of ``send_data`` (a list, one entry per rank) to each rank."""
        send_bufs = [convert_data_to_np_array(d) for d in send_data]
        recv_buf = convert_data_to_np_array(recv_data)
        # Total element count across all chunks, as expected by xp.scatter.
        size = sum([d.size for d in send_bufs])
        # dtype taken from the receive side; assumes all chunks share it.
        dtype = recv_buf.dtype
        sendptrs = [d.ctypes.data for d in send_bufs]
        recvptr = recv_buf.ctypes.data
        gloo_type = TypeMappingGloo[dtype.type]
        xp.scatter(self._context, sendptrs, recvptr, size, gloo_type, root, tag)  # type: ignore

    def reduce_scatter(
        self,
        send_data: Any,
        recv_data: Any,
        recv_elems: List[int],
        op: CollectiveReduceOp = CollectiveReduceOp.SUM,
    ):  # pragma: no cover
        """Reduce across ranks, then scatter; rank i keeps ``recv_elems[i]`` elements."""
        send_buf = convert_data_to_np_array(send_data)
        recv_buf = convert_data_to_np_array(recv_data)
        sendptr = send_buf.ctypes.data
        recvptr = recv_buf.ctypes.data
        size = send_buf.size
        dtype = send_buf.dtype
        gloo_type = TypeMappingGloo[dtype.type]
        xp.reduce_scatter(
            self._context,
            sendptr,
            recvptr,
            size,
            recv_elems,
            gloo_type,
            ReduceOpMappingGloo[op],
        )

    def alltoall(self, send_data: Any, recv_data: Any, tag: Optional[int] = 0):
        """Exchange data so every rank sends a slice to, and receives one from, every rank."""
        send_buf = convert_data_to_np_array(send_data)
        recv_buf = convert_data_to_np_array(recv_data)
        size = send_buf.size
        dtype = send_buf.dtype
        sendptr = send_buf.ctypes.data
        recvptr = recv_buf.ctypes.data
        gloo_type = TypeMappingGloo[dtype.type]
        xp.all_to_all(self._context, sendptr, recvptr, size, gloo_type, tag)

    def broadcast(
        self,
        send_data: Any,
        recv_data: Any,
        root: Optional[int] = 0,
        tag: Optional[int] = 0,
    ):
        """Broadcast ``send_data`` from ``root`` into ``recv_data`` on every rank.

        ``send_data`` may be None on non-root ranks; the receive buffer's
        pointer is then used as the source argument (gloo ignores the source
        on non-root ranks).
        """
        if send_data is not None:
            send_buf = convert_data_to_np_array(send_data)
            sendptr = send_buf.ctypes.data
        else:
            sendptr = None
        recv_buf = convert_data_to_np_array(recv_data)
        # Size/dtype come from the receive side so None send_data works.
        size = recv_buf.size
        dtype = recv_buf.dtype
        recvptr = recv_buf.ctypes.data
        gloo_type = TypeMappingGloo[dtype.type]
        xp.broadcast(
            self._context,
            recvptr if sendptr is None else sendptr,
            recvptr,
            size,
            gloo_type,
            root,
            tag,
        )
330
+ )
331
+
332
+
333
class ProcessGroupNCCL(ProcessGroup):
    """Process group backed by NCCL (requires cupy built with CUDA/NCCL).

    The first group created in a process bootstraps a ``XoscarNCCLBackend``
    over TCP and becomes the "world" group (``self._is_world``); subsequent
    groups exchange an NCCL unique id through that backend and then talk
    through a raw ``cupy.cuda.nccl.NcclCommunicator``. The two backends have
    different call conventions, hence the ``if self._is_world`` split in
    every collective.
    """

    def __init__(
        self,
        ip: str,
        rank: int,
        device_id: int,
        world_size: int,
        group_name: Optional[str] = None,
        pg_options: Optional[ProcessGroup.Options] = None,
    ):
        # Fixed: identity comparison (`is not None`) instead of `!= None`
        # (PEP 8 / E711).
        assert (
            cupy is not None
        ), "cupy is required when creating a group using nccl as backend."
        from cupy.cuda import nccl

        super().__init__(rank, world_size, group_name, pg_options)
        # Bind this process to its CUDA device before any NCCL call.
        cupy.cuda.Device(device_id).use()
        if _world._backend is None:
            # First NCCL group in this process: bootstrap the world backend.
            # NOTE(review): as in ProcessGroupGloo, env vars are ignored when
            # pg_options is provided, even if its fields are None.
            master_ip = (
                pg_options.master_ip
                if pg_options is not None
                else os.environ.get(RENDEZVOUS_MASTER_IP_ENV_KEY, None)
            )
            master_port = (
                pg_options.master_port
                if pg_options is not None
                else os.environ.get(RENDEZVOUS_MASTER_PORT_ENV_KEY, None)
            )
            if master_ip is None or master_port is None:
                raise ValueError("Cannot find master ip or port for rendezvous")
            store = TCPStore(world_size)
            backend = XoscarNCCLBackend(
                world_size, rank, store, master_ip, int(master_port)
            )
            _world._backend = backend
            _world._nccl_size = world_size
            self._is_world = True
            self._backend = backend
        else:
            # Sub-group: rank 0 creates an NCCL unique id and pushes it to
            # every other rank through the world backend; receivers allocate
            # a 128-element int64 placeholder of matching length first.
            self._is_world = False
            if rank == 0:
                commId = nccl.get_unique_id()
                ccid = cupy.array(commId)
                for i in range(1, world_size):
                    _world._backend.send(ccid, i, None)
            else:
                commId = tuple(range(128))  # placeholder, overwritten by recv
                ccid = cupy.array(commId, dtype="int64")
                _world._backend.recv(ccid, 0, None)
                commId = tuple(ccid.tolist())
            self._backend = nccl.NcclCommunicator(world_size, commId, rank)

    @staticmethod
    def _as_stream(stream: Optional[Any]):
        """Return ``stream`` if it is a ``cupy.cuda.Stream``, else the null stream.

        Centralizes the normalization that was previously copy-pasted into
        every collective method.
        """
        if stream is not None and isinstance(stream, cupy.cuda.Stream):
            return stream
        return cupy.cuda.Stream.null

    def reduce(
        self,
        send_buf: Any,
        recv_buf: Any,
        op: CollectiveReduceOp = CollectiveReduceOp.SUM,
        root: Optional[int] = 0,
        stream: Optional[Any] = None,
    ):
        """Reduce ``send_buf`` across the group into ``recv_buf`` on ``root``."""
        send_buf = convert_data_to_cp_array(send_buf)
        recv_buf = convert_data_to_cp_array(recv_buf)
        dtype = send_buf.dtype
        stream = self._as_stream(stream)
        if self._is_world:
            self._backend.reduce(
                send_buf, recv_buf, root, ReduceOpMappingNCCLStr[op], stream
            )
        else:
            # Raw communicator works on device pointers + element counts.
            self._backend.reduce(
                send_buf.data.ptr,
                recv_buf.data.ptr,
                send_buf.size,
                TypeMappingNCCL[dtype.type],
                ReduceOpMappingNCCL[op],
                root,
                stream.ptr,
            )

    def allreduce(
        self,
        send_buf: Any,
        recv_buf: Any,
        op: CollectiveReduceOp = CollectiveReduceOp.SUM,
        stream: Optional[Any] = None,
    ):
        """Reduce ``send_buf`` across all ranks; every rank receives the result."""
        send_buf = convert_data_to_cp_array(send_buf)
        recv_buf = convert_data_to_cp_array(recv_buf)
        dtype = send_buf.dtype
        stream = self._as_stream(stream)
        if self._is_world:
            self._backend.all_reduce(
                send_buf, recv_buf, ReduceOpMappingNCCLStr[op], stream
            )
        else:
            self._backend.allReduce(
                send_buf.data.ptr,
                recv_buf.data.ptr,
                send_buf.size,
                TypeMappingNCCL[dtype.type],
                ReduceOpMappingNCCL[op],
                stream.ptr,
            )

    def gather(
        self,
        send_buf: Any,
        recv_buf: Any,
        root: Optional[int] = 0,
        stream: Optional[Any] = None,
    ):
        """Gather ``send_buf`` from every rank into ``recv_buf`` on ``root``."""
        assert (
            send_buf.size * self.world_size == recv_buf.size
        ), "Send_size * world_number must be equal to recv_size"
        send_buf = convert_data_to_cp_array(send_buf)
        recv_buf = convert_data_to_cp_array(recv_buf)
        dtype = send_buf.dtype
        stream = self._as_stream(stream)
        if self._is_world:
            self._backend.gather(send_buf, recv_buf, root, stream)
        else:
            # The raw communicator has no gather primitive: emulate it with
            # point-to-point send/recv inside a NCCL group call.
            if self.rank == root:
                cupy.cuda.nccl.groupStart()
                for peer in range(self.world_size):
                    if peer == self.rank:
                        # Local contribution is copied, not sent.
                        recv_buf[peer : peer + 1] = send_buf.reshape(1, -1)
                    else:
                        self._backend.recv(
                            recv_buf[peer : peer + 1].data.ptr,
                            recv_buf[peer : peer + 1].size,
                            TypeMappingNCCL[dtype.type],
                            peer,
                            stream.ptr,
                        )
                cupy.cuda.nccl.groupEnd()
            else:
                send_buf = send_buf.reshape(1, -1)
                self._backend.send(
                    send_buf.data.ptr,
                    send_buf.size,
                    TypeMappingNCCL[dtype.type],
                    root,
                    stream.ptr,
                )

    def allgather(
        self,
        send_buf: Any,
        recv_buf: Any,
        stream: Optional[Any] = None,
    ):
        """Gather ``send_buf`` from all ranks into ``recv_buf`` on every rank."""
        send_buf = convert_data_to_cp_array(send_buf)
        recv_buf = convert_data_to_cp_array(recv_buf)
        stream = self._as_stream(stream)
        dtype = send_buf.dtype
        if self._is_world:
            self._backend.all_gather(send_buf, recv_buf, send_buf.size, stream)
        else:
            self._backend.allGather(
                send_buf.data.ptr,
                recv_buf.data.ptr,
                send_buf.size,
                TypeMappingNCCL[dtype.type],
                stream.ptr,
            )

    def scatter(
        self,
        send_buf: List[Any],
        recv_buf: Any,
        root: Optional[int] = 0,
        stream: Optional[Any] = None,
    ):
        """Scatter one chunk of ``send_buf`` (a list, one entry per rank) to each rank."""
        send_buf = [convert_data_to_cp_array(d) for d in send_buf]
        recv_buf = convert_data_to_cp_array(recv_buf)
        stream = self._as_stream(stream)
        if self._is_world:
            # World backend expects one (world_size, -1) array, not a list.
            send_buf = cupy.concatenate(send_buf).reshape(self.world_size, -1)
            self._backend.scatter(send_buf, recv_buf, root, stream)
        else:
            # Emulated with send/recv, mirroring gather above.
            if self.rank == root:
                assert (
                    len(send_buf) == self.world_size
                ), "Scatter size must be equal to the size of group"
                cupy.cuda.nccl.groupStart()
                for peer in range(self.world_size):
                    send_data = send_buf[peer]
                    if peer == root:
                        recv_buf[:] = send_data
                    else:
                        dtype = send_data.dtype
                        self._backend.send(
                            send_data.data.ptr,
                            send_data.size,
                            TypeMappingNCCL[dtype.type],
                            peer,
                            stream.ptr,
                        )
                cupy.cuda.nccl.groupEnd()
            else:
                dtype = recv_buf.dtype
                self._backend.recv(
                    recv_buf.data.ptr,
                    recv_buf.size,
                    TypeMappingNCCL[dtype.type],
                    root,
                    stream.ptr,
                )

    def reduce_scatter(
        self,
        send_buf: Any,
        recv_buf: Any,
        recv_elems: List[int],
        op: CollectiveReduceOp = CollectiveReduceOp.SUM,
        stream: Optional[Any] = None,
    ):
        """Reduce across ranks, then scatter; this rank keeps ``recv_elems[rank]`` elements."""
        send_buf = convert_data_to_cp_array(send_buf)
        recv_buf = convert_data_to_cp_array(recv_buf)
        dtype = send_buf.dtype
        stream = self._as_stream(stream)
        if self._is_world:
            self._backend.reduce_scatter(
                send_buf,
                recv_buf,
                recv_elems[self.rank],
                ReduceOpMappingNCCLStr[op],
                stream,
            )
        else:
            self._backend.reduceScatter(
                send_buf.data.ptr,
                recv_buf.data.ptr,
                recv_elems[self.rank],
                TypeMappingNCCL[dtype.type],
                ReduceOpMappingNCCL[op],
                stream.ptr,
            )

    def alltoall(
        self,
        send_buf: Any,
        recv_buf: Any,
        stream: Optional[Any] = None,
    ):
        """Exchange row ``i`` of ``send_buf`` with rank ``i``, for every rank."""
        assert (
            self.world_size == send_buf.shape[0]
        ), "The first dim of send data must be equal to world size."
        assert (
            recv_buf.size == send_buf.size
        ), "The size of send data must be equal to the size of recv data."
        send_buf = convert_data_to_cp_array(send_buf)
        recv_buf = convert_data_to_cp_array(recv_buf)
        dtype = send_buf.dtype
        stream = self._as_stream(stream)
        if self._is_world:
            self._backend.all_to_all(send_buf, recv_buf, stream)
        else:
            cupy.cuda.nccl.groupStart()
            for peer in range(self.world_size):
                if peer == self.rank:
                    recv_buf[peer : peer + 1] = send_buf[peer : peer + 1]
                else:
                    # Order send/recv by rank to avoid both sides blocking on
                    # a send at the same time.
                    if self.rank > peer:
                        self._backend.recv(
                            recv_buf[peer : peer + 1].data.ptr,
                            recv_buf[peer : peer + 1].size,
                            TypeMappingNCCL[dtype.type],
                            peer,
                            stream.ptr,
                        )
                        self._backend.send(
                            send_buf[peer : peer + 1].data.ptr,
                            send_buf[peer : peer + 1].size,
                            TypeMappingNCCL[dtype.type],
                            peer,
                            stream.ptr,
                        )
                    else:
                        self._backend.send(
                            send_buf[peer : peer + 1].data.ptr,
                            send_buf[peer : peer + 1].size,
                            TypeMappingNCCL[dtype.type],
                            peer,
                            stream.ptr,
                        )
                        self._backend.recv(
                            recv_buf[peer : peer + 1].data.ptr,
                            recv_buf[peer : peer + 1].size,
                            TypeMappingNCCL[dtype.type],
                            peer,
                            stream.ptr,
                        )
            cupy.cuda.nccl.groupEnd()

    def broadcast(
        self,
        send_buf: Any,
        recv_buf: Any,
        root: Optional[int] = 0,
        stream: Optional[Any] = None,
    ):
        """Broadcast ``send_buf`` from ``root`` into ``recv_buf`` on every rank."""
        send_buf = convert_data_to_cp_array(send_buf)
        recv_buf = convert_data_to_cp_array(recv_buf)
        dtype = send_buf.dtype
        stream = self._as_stream(stream)
        if self._is_world:
            if self._rank == root:
                self._backend.broadcast(send_buf, root, stream)
                # World backend broadcasts in place from send_buf; mirror the
                # result into recv_buf when the caller passed a distinct one.
                if recv_buf is not None and (recv_buf != send_buf).any():
                    recv_buf[:] = send_buf
            else:
                self._backend.broadcast(recv_buf, root, stream)
        else:
            self._backend.broadcast(
                send_buf.data.ptr,
                recv_buf.data.ptr,
                send_buf.size,
                TypeMappingNCCL[dtype.type],
                root,
                stream.ptr,
            )