xoscar-0.9.0-cp312-cp312-macosx_10_13_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (94)
  1. xoscar/__init__.py +61 -0
  2. xoscar/_utils.cpython-312-darwin.so +0 -0
  3. xoscar/_utils.pxd +36 -0
  4. xoscar/_utils.pyx +246 -0
  5. xoscar/_version.py +693 -0
  6. xoscar/aio/__init__.py +16 -0
  7. xoscar/aio/base.py +86 -0
  8. xoscar/aio/file.py +59 -0
  9. xoscar/aio/lru.py +228 -0
  10. xoscar/aio/parallelism.py +39 -0
  11. xoscar/api.py +527 -0
  12. xoscar/backend.py +67 -0
  13. xoscar/backends/__init__.py +14 -0
  14. xoscar/backends/allocate_strategy.py +160 -0
  15. xoscar/backends/communication/__init__.py +30 -0
  16. xoscar/backends/communication/base.py +315 -0
  17. xoscar/backends/communication/core.py +69 -0
  18. xoscar/backends/communication/dummy.py +253 -0
  19. xoscar/backends/communication/errors.py +20 -0
  20. xoscar/backends/communication/socket.py +444 -0
  21. xoscar/backends/communication/ucx.py +538 -0
  22. xoscar/backends/communication/utils.py +97 -0
  23. xoscar/backends/config.py +157 -0
  24. xoscar/backends/context.py +437 -0
  25. xoscar/backends/core.py +352 -0
  26. xoscar/backends/indigen/__init__.py +16 -0
  27. xoscar/backends/indigen/__main__.py +19 -0
  28. xoscar/backends/indigen/backend.py +51 -0
  29. xoscar/backends/indigen/driver.py +26 -0
  30. xoscar/backends/indigen/fate_sharing.py +221 -0
  31. xoscar/backends/indigen/pool.py +515 -0
  32. xoscar/backends/indigen/shared_memory.py +548 -0
  33. xoscar/backends/message.cpython-312-darwin.so +0 -0
  34. xoscar/backends/message.pyi +255 -0
  35. xoscar/backends/message.pyx +646 -0
  36. xoscar/backends/pool.py +1630 -0
  37. xoscar/backends/router.py +285 -0
  38. xoscar/backends/test/__init__.py +16 -0
  39. xoscar/backends/test/backend.py +38 -0
  40. xoscar/backends/test/pool.py +233 -0
  41. xoscar/batch.py +256 -0
  42. xoscar/collective/__init__.py +27 -0
  43. xoscar/collective/backend/__init__.py +13 -0
  44. xoscar/collective/backend/nccl_backend.py +160 -0
  45. xoscar/collective/common.py +102 -0
  46. xoscar/collective/core.py +737 -0
  47. xoscar/collective/process_group.py +687 -0
  48. xoscar/collective/utils.py +41 -0
  49. xoscar/collective/xoscar_pygloo.cpython-312-darwin.so +0 -0
  50. xoscar/collective/xoscar_pygloo.pyi +239 -0
  51. xoscar/constants.py +23 -0
  52. xoscar/context.cpython-312-darwin.so +0 -0
  53. xoscar/context.pxd +21 -0
  54. xoscar/context.pyx +368 -0
  55. xoscar/core.cpython-312-darwin.so +0 -0
  56. xoscar/core.pxd +51 -0
  57. xoscar/core.pyx +664 -0
  58. xoscar/debug.py +188 -0
  59. xoscar/driver.py +42 -0
  60. xoscar/errors.py +63 -0
  61. xoscar/libcpp.pxd +31 -0
  62. xoscar/metrics/__init__.py +21 -0
  63. xoscar/metrics/api.py +288 -0
  64. xoscar/metrics/backends/__init__.py +13 -0
  65. xoscar/metrics/backends/console/__init__.py +13 -0
  66. xoscar/metrics/backends/console/console_metric.py +82 -0
  67. xoscar/metrics/backends/metric.py +149 -0
  68. xoscar/metrics/backends/prometheus/__init__.py +13 -0
  69. xoscar/metrics/backends/prometheus/prometheus_metric.py +70 -0
  70. xoscar/nvutils.py +717 -0
  71. xoscar/profiling.py +260 -0
  72. xoscar/serialization/__init__.py +20 -0
  73. xoscar/serialization/aio.py +141 -0
  74. xoscar/serialization/core.cpython-312-darwin.so +0 -0
  75. xoscar/serialization/core.pxd +28 -0
  76. xoscar/serialization/core.pyi +57 -0
  77. xoscar/serialization/core.pyx +944 -0
  78. xoscar/serialization/cuda.py +111 -0
  79. xoscar/serialization/exception.py +48 -0
  80. xoscar/serialization/mlx.py +67 -0
  81. xoscar/serialization/numpy.py +82 -0
  82. xoscar/serialization/pyfury.py +37 -0
  83. xoscar/serialization/scipy.py +72 -0
  84. xoscar/serialization/torch.py +180 -0
  85. xoscar/utils.py +522 -0
  86. xoscar/virtualenv/__init__.py +34 -0
  87. xoscar/virtualenv/core.py +268 -0
  88. xoscar/virtualenv/platform.py +56 -0
  89. xoscar/virtualenv/utils.py +100 -0
  90. xoscar/virtualenv/uv.py +321 -0
  91. xoscar-0.9.0.dist-info/METADATA +230 -0
  92. xoscar-0.9.0.dist-info/RECORD +94 -0
  93. xoscar-0.9.0.dist-info/WHEEL +6 -0
  94. xoscar-0.9.0.dist-info/top_level.txt +2 -0
xoscar/backends/communication/ucx.py
@@ -0,0 +1,538 @@
+ # Copyright 2022-2023 XProbe Inc.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from __future__ import annotations
+
+ import asyncio
+ import concurrent.futures as futures
+ import functools
+ import logging
+ import os
+ import weakref
+ from typing import Any, Callable, Coroutine, Dict, List, Optional, Tuple, Type
+
+ import cloudpickle
+ import numpy as np
+
+ from ...nvutils import get_cuda_context, get_index_and_uuid
+ from ...serialization import deserialize
+ from ...serialization.aio import BUFFER_SIZES_NAME, AioSerializer, get_header_length
+ from ...utils import classproperty, implements, is_cuda_buffer, is_v6_ip, lazy_import
+ from ..message import _MessageBase
+ from .base import Channel, ChannelType, Client, Server
+ from .core import register_client, register_server
+ from .errors import ChannelClosed
+
+ ucp = lazy_import("ucxx")
+ numba_cuda = lazy_import("numba.cuda")
+ rmm = lazy_import("rmm")
+
+ _warning_suffix = (
+     "This is often the result of a CUDA-enabled library calling a CUDA runtime function before "
+     "spawning worker processes. Please make sure any such function calls don't happen "
+     "at import time or in the global scope of a program."
+ )
+
+
+ logger = logging.getLogger(__name__)
+
+
+ def synchronize_stream(stream: int = 0):
+     ctx = numba_cuda.current_context()
+     cu_stream = numba_cuda.driver.drvapi.cu_stream(stream)
+     stream = numba_cuda.driver.Stream(ctx, cu_stream, None)
+     stream.synchronize()  # type: ignore
+
+
+ class UCXInitializer:
+     _inited = False
+
+     @staticmethod
+     def _get_options(ucx_config: dict) -> Tuple[dict, dict]:
+         """
+         Get options and envs from ucx options in oscar config
+         """
+         options = dict()
+         envs = dict()
+
+         # if any of the flags are set, as long as they are not null/None,
+         # we assume we should configure basic TLS settings for UCX; otherwise we
+         # leave UCX to its default configuration
+         if any(ucx_config.get(name) for name in ["tcp", "nvlink", "infiniband"]):
+             if ucx_config.get("rdmacm"):  # pragma: no cover
+                 tls = "tcp"
+                 tls_priority = "rdmacm"
+             else:
+                 tls = "tcp"
+                 tls_priority = "tcp"
+
+             # cuda_copy can optionally be used with UCX -- we rely on the user
+             # to define when messages will include CUDA objects. Note:
+             # defining only the InfiniBand flag will not enable cuda_copy
+             if any(
+                 ucx_config.get(name) for name in ["nvlink", "cuda-copy"]
+             ):  # pragma: no cover
+                 tls += ",cuda_copy"
+
+             if ucx_config.get("infiniband"):  # pragma: no cover
+                 tls = "ib," + tls
+             if ucx_config.get("nvlink"):  # pragma: no cover
+                 tls += ",cuda_ipc"
+
+             options["TLS"] = tls
+             options["SOCKADDR_TLS_PRIORITY"] = tls_priority
+         elif "UCX_TLS" in os.environ:  # pragma: no cover
+             options["TLS"] = os.environ["UCX_TLS"]
+
+         for k, v in ucx_config.get("environment", dict()).items():  # pragma: no cover
+             # {"some-name": value} is translated to {"UCX_SOME_NAME": value}
+             key = f'UCX_{"_".join(s.upper() for s in k.split("-"))}'
+             opt_key = key[4:]
+             if opt_key in options:
+                 logger.warning(
+                     f"Ignoring {k}={v} (key={key}) in ucx.environment, "
+                     f"preferring {opt_key}={options[opt_key]} "
+                     "from high level options"
+                 )
+             elif key in os.environ:
+                 # This is only info because setting UCX configuration via
+                 # environment variables is a reasonably common approach
+                 logger.info(
+                     f"Ignoring {k}={v} (key={key}) in ucx.environment, "
+                     f"preferring {key}={os.environ[key]} from external environment"
+                 )
+             else:
+                 envs[key] = v
+
+         return options, envs
+
+     @staticmethod
+     def init(ucx_config: dict):
+         if UCXInitializer._inited:
+             return
+
+         options, envs = UCXInitializer._get_options(ucx_config)
+
+         # We ensure the CUDA context is created before initializing UCX. This can't
+         # be safely handled externally because communications start before
+         # preload scripts run.
+         # Precedence:
+         # 1. external environment
+         # 2. ucx_config (high level settings passed to ucp.init)
+         # 3. ucx_environment (low level settings equivalent to environment variables)
+         ucx_tls = os.environ.get("UCX_TLS", options.get("TLS", envs.get("UCX_TLS", "")))
+         if (
+             ucx_config.get("create-cuda-context") is True
+             # This is not foolproof: if UCX_TLS=all we might require CUDA
+             # depending on the configuration of UCX, but this is better than
+             # nothing
+             or ("cuda" in ucx_tls and "^cuda" not in ucx_tls)
+         ):
+             if numba_cuda is None:  # pragma: no cover
+                 raise ImportError(
+                     "CUDA support with UCX requires Numba for context management"
+                 )
+
+             pre_existing_cuda_context = get_cuda_context()
+             if pre_existing_cuda_context.has_context:
+                 dev = pre_existing_cuda_context.device_info
+                 assert dev is not None
+                 logger.warning(
+                     f"A CUDA context for device {dev.device_index} ({str(dev.uuid)}) "
+                     f"already exists on process ID {os.getpid()}. {_warning_suffix}"
+                 )
+
+             numba_cuda.current_context()
+
+             cuda_context_created = get_cuda_context()
+             cuda_visible_device = get_index_and_uuid(
+                 os.environ.get("CUDA_VISIBLE_DEVICES", "0").split(",")[0]
+             )
+             if (
+                 cuda_context_created.has_context
+                 and cuda_context_created.device_info.uuid != cuda_visible_device.uuid  # type: ignore
+             ):  # pragma: no cover
+                 cuda_context_created_dev = cuda_context_created.device_info
+                 assert cuda_context_created_dev is not None
+                 logger.warning(
+                     f"Worker with process ID {os.getpid()} should have a CUDA context assigned to device "
+                     f"{cuda_visible_device.device_index} ({str(cuda_visible_device.uuid)}), "  # type: ignore
+                     f"but instead the CUDA context is on device {cuda_context_created_dev.device_index} "
+                     f"({str(cuda_context_created_dev.uuid)}). {_warning_suffix}"
+                 )
+
+         original_environ = os.environ
+         new_environ = os.environ.copy()
+         new_environ.update(envs)
+         os.environ = new_environ  # type: ignore
+         try:
+             # let UCX determine the appropriate transports
+             ucp.init()
+         finally:
+             os.environ = original_environ
+
+         UCXInitializer._inited = True
+
+     @staticmethod
+     def reset():
+         ucp.reset()
+         UCXInitializer._inited = False
+
+
+ class UCXChannel(Channel):
+     __slots__ = (
+         "ucp_endpoint",
+         "_closed",
+         "_has_close_callback",
+         "_send_lock",
+         "_recv_lock",
+         "__weakref__",
+     )
+
+     name = "ucx"
+
+     def __init__(
+         self,
+         ucp_endpoint: "ucp.Endpoint",  # type: ignore
+         local_address: str | None = None,
+         dest_address: str | None = None,
+         compression: str | None = None,
+     ):
+         super().__init__(
+             local_address=local_address,
+             dest_address=dest_address,
+             compression=compression,
+         )
+         self.ucp_endpoint = ucp_endpoint
+
+         self._send_lock = asyncio.Lock()
+         self._recv_lock = asyncio.Lock()
+
+         # when the UCX endpoint closes or errors, the registered callback
+         # is called
+         if hasattr(self.ucp_endpoint, "set_close_callback"):
+             ref = weakref.ref(self)
+             self.ucp_endpoint.set_close_callback(
+                 functools.partial(UCXChannel._close_channel, ref)
+             )
+             self._closed = False
+             self._has_close_callback = True
+         else:  # pragma: no cover
+             self._has_close_callback = False
+
+     @staticmethod
+     def _close_channel(channel_ref: weakref.ReferenceType):
+         channel = channel_ref()
+         if channel is not None:
+             channel._closed = True
+
+     async def _serialize(self, message: Any) -> List[bytes]:
+         compress = self.compression or 0
+         serializer = AioSerializer(message, compress=compress)
+         return await serializer.run()
+
+     @property
+     @implements(Channel.type)
+     def type(self) -> int:
+         return ChannelType.remote
+
+     @implements(Channel.send)
+     async def send(self, message: Any):
+         if self.closed:
+             raise ChannelClosed("UCX Endpoint is closed, unable to send message")
+
+         buffers = await self._serialize(message)
+         return await self.send_buffers(buffers)
+
+     @implements(Channel.recv)
+     async def recv(self):
+         async with self._recv_lock:
+             try:
+                 info_buffer = np.empty(11, dtype="u1").data
+                 await self.ucp_endpoint.recv(info_buffer)
+                 head_length = get_header_length(info_buffer)
+                 header_buffer = np.empty(head_length, dtype="u1").data
+                 await self.ucp_endpoint.recv(header_buffer)
+                 header = cloudpickle.loads(header_buffer)
+
+                 is_cuda_buffers = header[0].get("is_cuda_buffers")
+                 buffer_sizes = header[0].pop(BUFFER_SIZES_NAME)
+
+                 buffers = []
+                 for is_cuda_buffer, buf_size in zip(is_cuda_buffers, buffer_sizes):
+                     if buf_size == 0:  # pragma: no cover
+                         buffers.append(bytes())
+                     elif is_cuda_buffer:
+                         cuda_buffer = rmm.DeviceBuffer(size=buf_size)
+                         await self.ucp_endpoint.recv(cuda_buffer)
+                         buffers.append(cuda_buffer)
+                     else:
+                         buffer = np.empty(buf_size, dtype="u1").data
+                         await self.ucp_endpoint.recv(buffer)
+                         buffers.append(buffer)
+             except BaseException as e:
+                 if not self._closed:
+                     # besides UCX exceptions, this may be a CancelledError or
+                     # another "low-level" exception; the only safe thing to do
+                     # is to abort
+                     self.abort()
+                     raise ChannelClosed(
+                         f"Connection closed by writer.\nInner exception: {e!r}"
+                     ) from e
+                 else:
+                     raise EOFError("Server closed already")
+         return deserialize(header, buffers)
+
+     async def send_buffers(self, buffers: list, meta: Optional[_MessageBase] = None):
+         try:
+             # Synchronize the default stream before sending. We do this because
+             # UCX is not stream-ordered, and syncing the default stream waits
+             # for other non-blocking CUDA streams. Note this is only sufficient
+             # if the memory being sent is not currently in use on non-blocking
+             # CUDA streams.
+             if any(is_cuda_buffer(buf) for buf in buffers):
+                 # has GPU buffer
+                 synchronize_stream(0)
+
+             meta_buffers = None
+             if meta:
+                 meta_buffers = await self._serialize(meta)
+
+             async with self._send_lock:
+                 if meta_buffers:
+                     for buf in meta_buffers:
+                         await self.ucp_endpoint.send(buf)
+                 for buffer in buffers:
+                     await self.ucp_endpoint.send(buffer)
+         except ucp.exceptions.UCXError:  # pragma: no cover
+             self.abort()
+             raise ChannelClosed("While writing, the connection was closed")
+
+     async def recv_buffers(self, buffers: list):
+         async with self._recv_lock:
+             try:
+                 for buffer in buffers:
+                     await self.ucp_endpoint.recv(buffer)
+             except BaseException as e:  # pragma: no cover
+                 if not self._closed:
+                     # besides UCX exceptions, this may be a CancelledError or
+                     # another "low-level" exception; the only safe thing to do
+                     # is to abort
+                     self.abort()
+                     raise ChannelClosed(
+                         f"Connection closed by writer.\nInner exception: {e!r}"
+                     ) from e
+                 else:
+                     raise EOFError("Server closed already")
+
+     def abort(self):
+         self._closed = True
+         if self.ucp_endpoint is not None:
+             self.ucp_endpoint.abort()
+             self.ucp_endpoint = None
+
+     @implements(Channel.close)
+     async def close(self):
+         self._closed = True
+         if self.ucp_endpoint is not None:
+             await self.ucp_endpoint.close()
+             # abort
+             self.ucp_endpoint.abort()
+             self.ucp_endpoint = None
+
+     @property
+     @implements(Channel.closed)
+     def closed(self):
+         if self._has_close_callback:
+             # The self._closed flag is separate from the endpoint's lifetime: even when
+             # the endpoint has closed or errored, there may be messages on its buffer
+             # still to be received, even though sending is not possible anymore.
+             return self._closed
+         else:  # pragma: no cover
+             return self.ucp_endpoint is None
+
+
+ @register_server
+ class UCXServer(Server):
+     __slots__ = "host", "port", "_ucp_listener", "_channels", "_closed"
+
+     scheme = "ucx"
+
+     _ucp_listener: "ucp.Listener"  # type: ignore
+     _channels: set[UCXChannel]
+
+     def __init__(
+         self,
+         host: str,
+         port: int,
+         ucp_listener: "ucp.Listener",  # type: ignore
+         channel_handler: Callable[[Channel], Coroutine] | None = None,
+     ):
+         super().__init__(f"{UCXServer.scheme}://{host}:{port}", channel_handler)
+         self.host = host
+         self.port = port
+         self._ucp_listener = ucp_listener
+         self._channels = set()
+         self._closed = asyncio.Event()
+
+     @classproperty
+     @implements(Server.client_type)
+     def client_type(self) -> Type["Client"]:
+         return UCXClient
+
+     @property
+     @implements(Server.channel_type)
+     def channel_type(self) -> int:
+         return ChannelType.remote
+
+     @staticmethod
+     async def create(config: Dict) -> "Server":
+         config = config.copy()
+         if "address" in config:
+             address = config.pop("address")
+             prefix = f"{UCXServer.scheme}://"
+             if address.startswith(prefix):
+                 address = address[len(prefix) :]
+             host, port = address.rsplit(":", 1)
+             port = int(port)
+         else:
+             host = config.pop("host")
+             port = int(config.pop("port"))
+         _host = host
+         if config.pop("listen_elastic_ip", False):
+             # Actor.address is announced to clients and may not be an address
+             # this host can actually listen on, so we bind to a wildcard
+             # address while keeping the server host untouched to make sure
+             # Actor.address stays unchanged.
+             if is_v6_ip(host):
+                 _host = "::"
+             else:
+                 _host = "0.0.0.0"
+
+         handle_channel = config.pop("handle_channel")
+
+         # init
+         UCXInitializer.init(config.get("ucx", dict()))
+
+         async def serve_forever(client_ucp_endpoint: "ucp.Endpoint"):  # type: ignore
+             try:
+                 await server.on_connected(
+                     client_ucp_endpoint, local_address="%s:%d" % (_host, port)
+                 )
+             except ChannelClosed:  # pragma: no cover
+                 logger.exception("Connection closed before handshake completed")
+                 return
+
+         ucp_listener = ucp.create_listener(serve_forever, port=port)
+
+         # get the port of the ucp listener if not specified
+         if not port:
+             port = ucp_listener.port
+
+         server = UCXServer(host, port, ucp_listener, channel_handler=handle_channel)
+         return server
+
+     @classmethod
+     def parse_config(cls, config: dict) -> dict:
+         return config
+
+     @implements(Server.start)
+     async def start(self):
+         pass
+
+     @implements(Server.join)
+     async def join(self, timeout=None):
+         wait_coro = self._closed.wait()
+         try:
+             await asyncio.wait_for(wait_coro, timeout=timeout)
+         except (futures.TimeoutError, asyncio.TimeoutError):
+             pass
+
+     @implements(Server.on_connected)
+     async def on_connected(self, *args, **kwargs):
+         (ucp_endpoint,) = args
+         local_address = kwargs.pop("local_address", None)
+         dest_address = kwargs.pop("dest_address", None)
+         if kwargs:  # pragma: no cover
+             raise TypeError(
+                 f"{type(self).__name__} got unexpected "
+                 f'arguments: {",".join(kwargs)}'
+             )
+         channel = UCXChannel(
+             ucp_endpoint, local_address=local_address, dest_address=dest_address
+         )
+         self._channels.add(channel)
+         # hand the channel over to the handler
+         try:
+             await self.channel_handler(channel)
+         finally:
+             if not channel.closed:
+                 await channel.close()
+             # remove the channel when it exits
+             self._channels.discard(channel)
+             logger.debug("Channel exit: %s", channel.info)
+
+     @implements(Server.stop)
+     async def stop(self):
+         self._ucp_listener.close()
+         # close all channels
+         await asyncio.gather(
+             *(channel.close() for channel in self._channels if not channel.closed)
+         )
+         self._channels.clear()
+         self._ucp_listener = None
+         self._closed.set()
+
+     @property
+     @implements(Server.stopped)
+     def stopped(self) -> bool:
+         return self._ucp_listener is None
+
+
+ @register_client
+ class UCXClient(Client):
+     __slots__ = ()
+
+     scheme = UCXServer.scheme
+     channel: UCXChannel
+
+     @classmethod
+     def parse_config(cls, config: dict) -> dict:
+         return config
+
+     @staticmethod
+     @implements(Client.connect)
+     async def connect(
+         dest_address: str, local_address: str | None = None, **kwargs
+     ) -> "Client":
+         prefix = f"{UCXClient.scheme}://"
+         if dest_address.startswith(prefix):
+             dest_address = dest_address[len(prefix) :]
+         host, port_str = dest_address.rsplit(":", 1)
+         port = int(port_str)
+         kwargs = kwargs.copy()
+         ucx_config = kwargs.pop("config", dict()).get("ucx", dict())
+         UCXInitializer.init(ucx_config)
+
+         try:
+             ucp_endpoint = await ucp.create_endpoint(host, port)
+         except ucp.exceptions.UCXError as e:  # pragma: no cover
+             raise ChannelClosed(
+                 f"Connection closed before handshake completed, "
+                 f"local address: {local_address}, dest address: {dest_address}"
+             ) from e
+         channel = UCXChannel(
+             ucp_endpoint, local_address=local_address, dest_address=dest_address
+         )
+         return UCXClient(local_address, dest_address, channel)
+
+     async def send_buffers(self, buffers: list, meta: _MessageBase):
+         return await self.channel.send_buffers(buffers, meta)
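
To make the API above concrete, here is a minimal echo round trip sketched against UCXServer.create and UCXClient.connect as defined in this file. It is a sketch under assumptions, not part of the package: it assumes a working UCX stack with the ucxx package installed, and that the base Server/Client classes expose the address and channel attributes used below. The "address" and "handle_channel" config keys are the ones create() consumes; port 0 lets the UCX listener pick a free port.

import asyncio

from xoscar.backends.communication.ucx import UCXClient, UCXServer

async def main():
    async def handle_channel(channel):
        # echo one message back to the client
        await channel.send(await channel.recv())

    # "address" and "handle_channel" are the keys UCXServer.create() pops;
    # port 0 makes create() adopt the port chosen by the ucp listener
    server = await UCXServer.create(
        {"address": "ucx://127.0.0.1:0", "handle_channel": handle_channel}
    )
    try:
        client = await UCXClient.connect(server.address)
        await client.channel.send("ping")
        assert await client.channel.recv() == "ping"
        await client.channel.close()
    finally:
        await server.stop()

asyncio.run(main())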
xoscar/backends/communication/utils.py
@@ -0,0 +1,97 @@
+ # Copyright 2022-2023 XProbe Inc.
+ # derived from copyright 1999-2021 Alibaba Group Holding Ltd.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from asyncio import StreamReader, StreamWriter
+ from typing import Dict, List, Union
+
+ import numpy as np
+
+ from ...serialization.aio import BUFFER_SIZES_NAME
+ from ...utils import lazy_import
+
+ cupy = lazy_import("cupy")
+ cudf = lazy_import("cudf")
+ rmm = lazy_import("rmm")
+
+ CUDA_CHUNK_SIZE = 16 * 1024**2
+
+
+ def _convert_to_cupy_ndarray(
+     cuda_buffer: Union["cupy.ndarray", "rmm.DeviceBuffer"]  # type: ignore
+ ) -> "cupy.ndarray":  # type: ignore
+     if isinstance(cuda_buffer, cupy.ndarray):
+         return cuda_buffer
+
+     size = cuda_buffer.nbytes
+     data = cuda_buffer.__cuda_array_interface__["data"][0]
+     memory = cupy.cuda.UnownedMemory(data, size, cuda_buffer)
+     ptr = cupy.cuda.MemoryPointer(memory, 0)
+     return cupy.ndarray(shape=size, dtype="u1", memptr=ptr)
+
+
+ def write_buffers(writer: StreamWriter, buffers: List):
+     def _write_cuda_buffer(cuda_buffer: Union["cupy.ndarray", "rmm.DeviceBuffer"]):  # type: ignore
+         # convert the cuda buffer to a cupy ndarray
+         cuda_buffer = _convert_to_cupy_ndarray(cuda_buffer)
+
+         chunk_size = CUDA_CHUNK_SIZE
+         offset = 0
+         nbytes = cuda_buffer.nbytes
+         while offset < nbytes:
+             size = chunk_size if (offset + chunk_size) < nbytes else nbytes - offset
+             # slice on the cupy ndarray
+             chunk_buffer = cuda_buffer[offset : offset + size]
+             # `get` returns a numpy ndarray;
+             # write its data, which is a memoryview, into the writer
+             writer.write(chunk_buffer.get().data)
+             offset += size
+
+     for buffer in buffers:
+         if hasattr(buffer, "__cuda_array_interface__"):
+             # GPU buffer
+             _write_cuda_buffer(buffer)
+         else:
+             writer.write(buffer)
+
+
+ async def read_buffers(header: Dict, reader: StreamReader):
+     is_cuda_buffers = header[0].get("is_cuda_buffers")
+     buffer_sizes = header[0].pop(BUFFER_SIZES_NAME)
+
+     buffers = []
+     for is_cuda_buffer, buf_size in zip(is_cuda_buffers, buffer_sizes):
+         if is_cuda_buffer:  # pragma: no cover
+             if buf_size == 0:
+                 # uniformly use rmm.DeviceBuffer for CUDA deserialization
+                 buffers.append(rmm.DeviceBuffer(size=buf_size))
+             else:
+                 buffer = rmm.DeviceBuffer(size=buf_size)
+                 arr = _convert_to_cupy_ndarray(buffer)
+                 offset = 0
+                 chunk_size = CUDA_CHUNK_SIZE
+                 while offset < buf_size:
+                     read_size = (
+                         chunk_size
+                         if (offset + chunk_size) < buf_size
+                         else buf_size - offset
+                     )
+                     content = await reader.readexactly(read_size)
+                     chunk_arr = np.frombuffer(content, dtype="u1")
+                     arr[offset : offset + len(content)].set(chunk_arr)
+                     offset += read_size
+                 buffers.append(buffer)
+         else:
+             buffers.append(await reader.readexactly(buf_size))
+     return buffers
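
For the host-side path, the following sketch round-trips two plain byte buffers through write_buffers and read_buffers over a local asyncio connection; no GPU is involved, so the CUDA branches are skipped. The hand-built header tuple is an assumption for illustration only: it mimics the shape read_buffers expects, a per-buffer is_cuda_buffers flag list plus the buffer sizes stored under BUFFER_SIZES_NAME.

import asyncio

from xoscar.backends.communication.utils import read_buffers, write_buffers
from xoscar.serialization.aio import BUFFER_SIZES_NAME

async def main():
    done = asyncio.Event()

    async def handle(reader, writer):
        # hand-built header mimicking what the serializer would produce
        header = ({"is_cuda_buffers": [False, False], BUFFER_SIZES_NAME: [5, 3]},)
        assert await read_buffers(header, reader) == [b"hello", b"abc"]
        writer.close()
        done.set()

    server = await asyncio.start_server(handle, "127.0.0.1", 0)
    host, port = server.sockets[0].getsockname()
    _, writer = await asyncio.open_connection(host, port)
    write_buffers(writer, [b"hello", b"abc"])  # plain bytes take the non-CUDA path
    await writer.drain()
    await done.wait()
    writer.close()
    server.close()

asyncio.run(main())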