xoscar 0.3.1__cp310-cp310-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of xoscar might be problematic.

Files changed (80)
  1. xoscar/__init__.py +60 -0
  2. xoscar/_utils.cpython-310-darwin.so +0 -0
  3. xoscar/_utils.pxd +36 -0
  4. xoscar/_utils.pyx +241 -0
  5. xoscar/_version.py +693 -0
  6. xoscar/aio/__init__.py +25 -0
  7. xoscar/aio/_threads.py +35 -0
  8. xoscar/aio/base.py +86 -0
  9. xoscar/aio/file.py +59 -0
  10. xoscar/aio/lru.py +228 -0
  11. xoscar/aio/parallelism.py +39 -0
  12. xoscar/api.py +493 -0
  13. xoscar/backend.py +67 -0
  14. xoscar/backends/__init__.py +14 -0
  15. xoscar/backends/allocate_strategy.py +160 -0
  16. xoscar/backends/communication/__init__.py +30 -0
  17. xoscar/backends/communication/base.py +315 -0
  18. xoscar/backends/communication/core.py +69 -0
  19. xoscar/backends/communication/dummy.py +242 -0
  20. xoscar/backends/communication/errors.py +20 -0
  21. xoscar/backends/communication/socket.py +375 -0
  22. xoscar/backends/communication/ucx.py +520 -0
  23. xoscar/backends/communication/utils.py +97 -0
  24. xoscar/backends/config.py +145 -0
  25. xoscar/backends/context.py +404 -0
  26. xoscar/backends/core.py +193 -0
  27. xoscar/backends/indigen/__init__.py +16 -0
  28. xoscar/backends/indigen/backend.py +51 -0
  29. xoscar/backends/indigen/driver.py +26 -0
  30. xoscar/backends/indigen/pool.py +469 -0
  31. xoscar/backends/message.cpython-310-darwin.so +0 -0
  32. xoscar/backends/message.pyx +591 -0
  33. xoscar/backends/pool.py +1593 -0
  34. xoscar/backends/router.py +207 -0
  35. xoscar/backends/test/__init__.py +16 -0
  36. xoscar/backends/test/backend.py +38 -0
  37. xoscar/backends/test/pool.py +208 -0
  38. xoscar/batch.py +256 -0
  39. xoscar/collective/__init__.py +27 -0
  40. xoscar/collective/common.py +102 -0
  41. xoscar/collective/core.py +737 -0
  42. xoscar/collective/process_group.py +687 -0
  43. xoscar/collective/utils.py +41 -0
  44. xoscar/collective/xoscar_pygloo.cpython-310-darwin.so +0 -0
  45. xoscar/constants.py +21 -0
  46. xoscar/context.cpython-310-darwin.so +0 -0
  47. xoscar/context.pxd +21 -0
  48. xoscar/context.pyx +368 -0
  49. xoscar/core.cpython-310-darwin.so +0 -0
  50. xoscar/core.pxd +50 -0
  51. xoscar/core.pyx +658 -0
  52. xoscar/debug.py +188 -0
  53. xoscar/driver.py +42 -0
  54. xoscar/errors.py +63 -0
  55. xoscar/libcpp.pxd +31 -0
  56. xoscar/metrics/__init__.py +21 -0
  57. xoscar/metrics/api.py +288 -0
  58. xoscar/metrics/backends/__init__.py +13 -0
  59. xoscar/metrics/backends/console/__init__.py +13 -0
  60. xoscar/metrics/backends/console/console_metric.py +82 -0
  61. xoscar/metrics/backends/metric.py +149 -0
  62. xoscar/metrics/backends/prometheus/__init__.py +13 -0
  63. xoscar/metrics/backends/prometheus/prometheus_metric.py +70 -0
  64. xoscar/nvutils.py +717 -0
  65. xoscar/profiling.py +260 -0
  66. xoscar/serialization/__init__.py +20 -0
  67. xoscar/serialization/aio.py +138 -0
  68. xoscar/serialization/core.cpython-310-darwin.so +0 -0
  69. xoscar/serialization/core.pxd +28 -0
  70. xoscar/serialization/core.pyx +954 -0
  71. xoscar/serialization/cuda.py +111 -0
  72. xoscar/serialization/exception.py +48 -0
  73. xoscar/serialization/numpy.py +82 -0
  74. xoscar/serialization/pyfury.py +37 -0
  75. xoscar/serialization/scipy.py +72 -0
  76. xoscar/utils.py +502 -0
  77. xoscar-0.3.1.dist-info/METADATA +225 -0
  78. xoscar-0.3.1.dist-info/RECORD +80 -0
  79. xoscar-0.3.1.dist-info/WHEEL +5 -0
  80. xoscar-0.3.1.dist-info/top_level.txt +2 -0
@@ -0,0 +1,520 @@
+ # Copyright 2022-2023 XProbe Inc.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from __future__ import annotations
+
+ import asyncio
+ import concurrent.futures as futures
+ import functools
+ import logging
+ import os
+ import weakref
+ from typing import Any, Callable, Coroutine, Dict, List, Optional, Tuple, Type
+
+ import cloudpickle
+ import numpy as np
+
+ from ...nvutils import get_cuda_context, get_index_and_uuid
+ from ...serialization import deserialize
+ from ...serialization.aio import BUFFER_SIZES_NAME, AioSerializer, get_header_length
+ from ...utils import classproperty, implements, is_cuda_buffer, lazy_import
+ from ..message import _MessageBase
+ from .base import Channel, ChannelType, Client, Server
+ from .core import register_client, register_server
+ from .errors import ChannelClosed
+
+ ucp = lazy_import("ucp")
+ numba_cuda = lazy_import("numba.cuda")
+ rmm = lazy_import("rmm")
+
+ _warning_suffix = (
+     "This is often the result of a CUDA-enabled library calling a CUDA runtime function before "
+     "spawning worker processes. Please make sure any such function calls don't happen "
+     "at import time or in the global scope of a program."
+ )
+
+
+ logger = logging.getLogger(__name__)
+
+
+ def synchronize_stream(stream: int = 0):
+     ctx = numba_cuda.current_context()
+     cu_stream = numba_cuda.driver.drvapi.cu_stream(stream)
+     stream = numba_cuda.driver.Stream(ctx, cu_stream, None)
+     stream.synchronize()  # type: ignore
+
+
+ class UCXInitializer:
+     _inited = False
+
+     @staticmethod
+     def _get_options(ucx_config: dict) -> Tuple[dict, dict]:
+         """
+         Get options and envs from ucx options in oscar config
+         """
+         options = dict()
+         envs = dict()
+
+         # if any of the flags are set, as long as they are not Null/None,
+         # we assume we should configure basic TLS settings for UCX, otherwise we
+         # leave UCX to its default configuration
+         if any(ucx_config.get(name) for name in ["tcp", "nvlink", "infiniband"]):
+             if ucx_config.get("rdmacm"):  # pragma: no cover
+                 tls = "tcp"
+                 tls_priority = "rdmacm"
+             else:
+                 tls = "tcp"
+                 tls_priority = "tcp"
+
+             # CUDA COPY can optionally be used with ucx -- we rely on the user
+             # to define when messages will include CUDA objects. Note:
+             # defining only the Infiniband flag will not enable cuda_copy
+             if any(
+                 ucx_config.get(name) for name in ["nvlink", "cuda-copy"]
+             ):  # pragma: no cover
+                 tls += ",cuda_copy"
+
+             if ucx_config.get("infiniband"):  # pragma: no cover
+                 tls = "rc," + tls
+             if ucx_config.get("nvlink"):  # pragma: no cover
+                 tls += ",cuda_ipc"
+
+             options["TLS"] = tls
+             options["SOCKADDR_TLS_PRIORITY"] = tls_priority
+         elif "UCX_TLS" in os.environ:  # pragma: no cover
+             options["TLS"] = os.environ["UCX_TLS"]
+
+         for k, v in ucx_config.get("environment", dict()).items():  # pragma: no cover
+             # {"some-name": value} is translated to {"UCX_SOME_NAME": value}
+             key = f'UCX_{"_".join(s.upper() for s in k.split("-"))}'
+             opt_key = key[4:]
+             if opt_key in options:
+                 logger.warning(
+                     f"Ignoring {k}={v} (key={key}) in ucx.environment, "
+                     f"preferring {opt_key}={options[opt_key]} "
+                     "from high level options"
+                 )
+             elif key in os.environ:
+                 # This is only info because setting UCX configuration via
+                 # environment variables is a reasonably common approach
+                 logger.info(
+                     f"Ignoring {k}={v} (key={key}) in ucx.environment, "
+                     f"preferring {key}={os.environ[key]} from external environment"
+                 )
+             else:
+                 envs[key] = v
+
+         return options, envs
+
+     @staticmethod
+     def init(ucx_config: dict):
+         if UCXInitializer._inited:
+             return
+
+         options, envs = UCXInitializer._get_options(ucx_config)
+
+         # We ensure the CUDA context is created before initializing UCX. This can't
+         # be safely handled externally because communications start before
+         # preload scripts run.
+         # Precedence:
+         # 1. external environment
+         # 2. ucx_config (high level settings passed to ucp.init)
+         # 3. ucx_environment (low level settings equivalent to environment variables)
+         ucx_tls = os.environ.get("UCX_TLS", options.get("TLS", envs.get("UCX_TLS", "")))
+         if (
+             ucx_config.get("create-cuda-context") is True
+             # This is not foolproof: if UCX_TLS=all we might still require CUDA
+             # depending on the UCX configuration, but this is better than
+             # nothing
+             or ("cuda" in ucx_tls and "^cuda" not in ucx_tls)
+         ):
+             if numba_cuda is None:  # pragma: no cover
+                 raise ImportError(
+                     "CUDA support with UCX requires Numba for context management"
+                 )
+
+             pre_existing_cuda_context = get_cuda_context()
+             if pre_existing_cuda_context.has_context:
+                 dev = pre_existing_cuda_context.device_info
+                 assert dev is not None
+                 logger.warning(
+                     f"A CUDA context for device {dev.device_index} ({str(dev.uuid)}) "
+                     f"already exists on process ID {os.getpid()}. {_warning_suffix}"
+                 )
+
+             numba_cuda.current_context()
+
+             cuda_context_created = get_cuda_context()
+             cuda_visible_device = get_index_and_uuid(
+                 os.environ.get("CUDA_VISIBLE_DEVICES", "0").split(",")[0]
+             )
+             if (
+                 cuda_context_created.has_context
+                 and cuda_context_created.device_info.uuid != cuda_visible_device.uuid  # type: ignore
+             ):  # pragma: no cover
+                 cuda_context_created_dev = cuda_context_created.device_info
+                 assert cuda_context_created_dev is not None
+                 logger.warning(
+                     f"Worker with process ID {os.getpid()} should have a CUDA context assigned to device "
+                     f"{cuda_visible_device.device_index} ({str(cuda_visible_device.uuid)}), "  # type: ignore
+                     f"but instead the CUDA context is on device {cuda_context_created_dev.device_index} "
+                     f"({str(cuda_context_created_dev.uuid)}). {_warning_suffix}"
+                 )
+
+         original_environ = os.environ
+         new_environ = os.environ.copy()
+         new_environ.update(envs)
+         os.environ = new_environ  # type: ignore
+         try:
+             ucp.init(options=options, env_takes_precedence=True)
+         finally:
+             os.environ = original_environ
+
+         UCXInitializer._inited = True
+
+     @staticmethod
+     def reset():
+         ucp.reset()
+         UCXInitializer._inited = False
+
+
+ class UCXChannel(Channel):
+     __slots__ = (
+         "ucp_endpoint",
+         "_closed",
+         "_has_close_callback",
+         "_send_lock",
+         "_recv_lock",
+         "__weakref__",
+     )
+
+     name = "ucx"
+
+     def __init__(
+         self,
+         ucp_endpoint: "ucp.Endpoint",  # type: ignore
+         local_address: str | None = None,
+         dest_address: str | None = None,
+         compression: str | None = None,
+     ):
+         super().__init__(
+             local_address=local_address,
+             dest_address=dest_address,
+             compression=compression,
+         )
+         self.ucp_endpoint = ucp_endpoint
+
+         self._send_lock = asyncio.Lock()
+         self._recv_lock = asyncio.Lock()
+
+         # When the UCX endpoint closes or errors, the registered callback
+         # is called.
+         self._closed = False
+         if hasattr(self.ucp_endpoint, "set_close_callback"):
+             ref = weakref.ref(self)
+             self.ucp_endpoint.set_close_callback(
+                 functools.partial(UCXChannel._close_channel, ref)
+             )
+             self._has_close_callback = True
+         else:  # pragma: no cover
+             self._has_close_callback = False
+
+     @staticmethod
+     def _close_channel(channel_ref: weakref.ReferenceType):
+         channel = channel_ref()
+         if channel is not None:
+             channel._closed = True
+
+     async def _serialize(self, message: Any) -> List[bytes]:
+         compress = self.compression or 0
+         serializer = AioSerializer(message, compress=compress)
+         return await serializer.run()
+
+     @property
+     @implements(Channel.type)
+     def type(self) -> int:
+         return ChannelType.remote
+
+     @implements(Channel.send)
+     async def send(self, message: Any):
+         if self.closed:
+             raise ChannelClosed("UCX Endpoint is closed, unable to send message")
+
+         buffers = await self._serialize(message)
+         return await self.send_buffers(buffers)
+
+     @implements(Channel.recv)
+     async def recv(self):
+         async with self._recv_lock:
+             try:
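+                 # The message begins with a fixed-size length preamble written
+                 # by AioSerializer; get_header_length decodes the size of the
+                 # pickled header that follows it.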
+                 info_buffer = np.empty(11, dtype="u1").data
+                 await self.ucp_endpoint.recv(info_buffer)
+                 head_length = get_header_length(info_buffer)
+                 header_buffer = np.empty(head_length, dtype="u1").data
+                 await self.ucp_endpoint.recv(header_buffer)
+                 header = cloudpickle.loads(header_buffer)
+
+                 is_cuda_buffers = header[0].get("is_cuda_buffers")
+                 buffer_sizes = header[0].pop(BUFFER_SIZES_NAME)
+
+                 buffers = []
+                 for is_cuda_buffer, buf_size in zip(is_cuda_buffers, buffer_sizes):
+                     if buf_size == 0:  # pragma: no cover
+                         buffers.append(bytes())
+                     elif is_cuda_buffer:
+                         cuda_buffer = rmm.DeviceBuffer(size=buf_size)
+                         await self.ucp_endpoint.recv(cuda_buffer)
+                         buffers.append(cuda_buffer)
+                     else:
+                         buffer = np.empty(buf_size, dtype="u1").data
+                         await self.ucp_endpoint.recv(buffer)
+                         buffers.append(buffer)
+             except BaseException as e:
+                 if not self._closed:
+                     # Besides UCX exceptions, this may be a CancelledError or another
+                     # "low-level" exception. The only safe thing to do is to abort.
+                     self.abort()
+                     raise ChannelClosed(
+                         f"Connection closed by writer.\nInner exception: {e!r}"
+                     ) from e
+                 else:
+                     raise EOFError("Server closed already")
+         return deserialize(header, buffers)
+
+     async def send_buffers(self, buffers: list, meta: Optional[_MessageBase] = None):
+         try:
+             # Synchronize the default stream before we start sending. We do this
+             # because UCX is not stream-ordered, and syncing the default stream
+             # waits for other non-blocking CUDA streams. Note this is only
+             # sufficient if the memory being sent is not currently in use on
+             # non-blocking CUDA streams.
+             if any(is_cuda_buffer(buf) for buf in buffers):
+                 # has GPU buffer
+                 synchronize_stream(0)
+
+             meta_buffers = None
+             if meta:
+                 meta_buffers = await self._serialize(meta)
+
+             async with self._send_lock:
+                 if meta_buffers:
+                     for buf in meta_buffers:
+                         await self.ucp_endpoint.send(buf)
+                 for buffer in buffers:
+                     await self.ucp_endpoint.send(buffer)
+         except ucp.exceptions.UCXBaseException:  # pragma: no cover
+             self.abort()
+             raise ChannelClosed("While writing, the connection was closed")
+
+     async def recv_buffers(self, buffers: list):
+         async with self._recv_lock:
+             try:
+                 for buffer in buffers:
+                     await self.ucp_endpoint.recv(buffer)
+             except BaseException as e:  # pragma: no cover
+                 if not self._closed:
+                     # Besides UCX exceptions, this may be a CancelledError or another
+                     # "low-level" exception. The only safe thing to do is to abort.
+                     self.abort()
+                     raise ChannelClosed(
+                         f"Connection closed by writer.\nInner exception: {e!r}"
+                     ) from e
+                 else:
+                     raise EOFError("Server closed already")
+
+     def abort(self):
+         self._closed = True
+         if self.ucp_endpoint is not None:
+             self.ucp_endpoint.abort()
+             self.ucp_endpoint = None
+
+     @implements(Channel.close)
+     async def close(self):
+         self._closed = True
+         if self.ucp_endpoint is not None:
+             await self.ucp_endpoint.close()
+             # abort the endpoint as well to release its resources immediately
+             self.ucp_endpoint.abort()
+             self.ucp_endpoint = None
+
+     @property
+     @implements(Channel.closed)
+     def closed(self):
+         if self._has_close_callback:  # pragma: no cover
+             # The self._closed flag is separate from the endpoint's lifetime: even
+             # when the endpoint has closed or errored, there may be messages in its
+             # buffer still to be received, even though sending is no longer possible.
+             return self._closed
+         else:
+             return self.ucp_endpoint is None
+
+
+ @register_server
+ class UCXServer(Server):
+     __slots__ = "host", "port", "_ucp_listener", "_channels", "_closed"
+
+     scheme = "ucx"
+
+     _ucp_listener: "ucp.Listener"  # type: ignore
+     _channels: List[UCXChannel]
+
+     def __init__(
+         self,
+         host: str,
+         port: int,
+         ucp_listener: "ucp.Listener",  # type: ignore
+         channel_handler: Callable[[Channel], Coroutine] | None = None,
+     ):
+         super().__init__(f"{UCXServer.scheme}://{host}:{port}", channel_handler)
+         self.host = host
+         self.port = port
+         self._ucp_listener = ucp_listener
+         self._channels = []
+         self._closed = asyncio.Event()
+
+     @classproperty
+     @implements(Server.client_type)
+     def client_type(self) -> Type["Client"]:
+         return UCXClient
+
+     @property
+     @implements(Server.channel_type)
+     def channel_type(self) -> int:
+         return ChannelType.remote
+
+     @staticmethod
+     async def create(config: Dict) -> "Server":
+         config = config.copy()
+         if "address" in config:
+             address = config.pop("address")
+             prefix = f"{UCXServer.scheme}://"
+             if address.startswith(prefix):
+                 address = address[len(prefix) :]
+             host, port = address.split(":", 1)
+             port = int(port)
+         else:
+             host = config.pop("host")
+             port = int(config.pop("port"))
+         handle_channel = config.pop("handle_channel")
+
+         # init
+         UCXInitializer.init(config.get("ucx", dict()))
+
+         async def serve_forever(client_ucp_endpoint: "ucp.Endpoint"):  # type: ignore
+             try:
+                 await server.on_connected(
+                     client_ucp_endpoint, local_address=server.address
+                 )
+             except ChannelClosed:  # pragma: no cover
+                 logger.exception("Connection closed before handshake completed")
+                 return
+
+         ucp_listener = ucp.create_listener(serve_forever, port=port)
+
+         # get port of the ucp listener if not specified
+         if not port:
+             port = ucp_listener.port
+
+         server = UCXServer(host, port, ucp_listener, channel_handler=handle_channel)
+         return server
+
+     @classmethod
+     def parse_config(cls, config: dict) -> dict:
+         return config
+
+     @implements(Server.start)
+     async def start(self):
+         pass
+
+     @implements(Server.join)
+     async def join(self, timeout=None):
+         wait_coro = self._closed.wait()
+         try:
+             await asyncio.wait_for(wait_coro, timeout=timeout)
+         except (futures.TimeoutError, asyncio.TimeoutError):
+             pass
+
+     @implements(Server.on_connected)
+     async def on_connected(self, *args, **kwargs):
+         (ucp_endpoint,) = args
+         local_address = kwargs.pop("local_address", None)
+         dest_address = kwargs.pop("dest_address", None)
+         if kwargs:  # pragma: no cover
+             raise TypeError(
+                 f"{type(self).__name__} got unexpected "
+                 f'arguments: {",".join(kwargs)}'
+             )
+         channel = UCXChannel(
+             ucp_endpoint, local_address=local_address, dest_address=dest_address
+         )
+         self._channels.append(channel)
+         # hand the channel over to the channel handler
+         await self.channel_handler(channel)
+
+     @implements(Server.stop)
+     async def stop(self):
+         self._ucp_listener.close()
+         # close all channels
+         await asyncio.gather(
+             *(channel.close() for channel in self._channels if not channel.closed)
+         )
+         self._channels = []
+         self._ucp_listener = None
+         self._closed.set()
+
+     @property
+     @implements(Server.stopped)
+     def stopped(self) -> bool:
+         return self._ucp_listener is None
+
+
+ @register_client
+ class UCXClient(Client):
+     __slots__ = ()
+
+     scheme = UCXServer.scheme
+     channel: UCXChannel
+
+     @classmethod
+     def parse_config(cls, config: dict) -> dict:
+         return config
+
+     @staticmethod
+     @implements(Client.connect)
+     async def connect(
+         dest_address: str, local_address: str | None = None, **kwargs
+     ) -> "Client":
+         prefix = f"{UCXClient.scheme}://"
+         if dest_address.startswith(prefix):
+             dest_address = dest_address[len(prefix) :]
+         host, port_str = dest_address.split(":", 1)
+         port = int(port_str)
+         kwargs = kwargs.copy()
+         ucx_config = kwargs.pop("config", dict()).get("ucx", dict())
+         UCXInitializer.init(ucx_config)
+
+         try:
+             ucp_endpoint = await ucp.create_endpoint(host, port)
+         except ucp.exceptions.UCXBaseException as e:  # pragma: no cover
+             raise ChannelClosed(
+                 f"Connection closed before handshake completed, "
+                 f"local address: {local_address}, dest address: {dest_address}"
+             ) from e
+         channel = UCXChannel(
+             ucp_endpoint, local_address=local_address, dest_address=dest_address
+         )
+         return UCXClient(local_address, dest_address, channel)
+
+     async def send_buffers(self, buffers: list, meta: _MessageBase):
+         return await self.channel.send_buffers(buffers, meta)
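
For orientation, here is a minimal sketch of how these classes fit together, assuming a working ucx-py installation; it sends a plain bytes message, so no GPU is involved:

    import asyncio

    from xoscar.backends.communication.ucx import UCXClient, UCXServer

    async def main():
        # Server-side handler: echo one message back on the channel.
        async def handle_channel(channel):
            message = await channel.recv()
            await channel.send(message)

        # Port 0 lets the UCP listener pick a free port; "handle_channel"
        # is required by UCXServer.create.
        server = await UCXServer.create(
            {"host": "127.0.0.1", "port": 0, "handle_channel": handle_channel}
        )
        client = await UCXClient.connect(server.address)
        await client.channel.send(b"ping")
        print(await client.channel.recv())  # b'ping'
        await server.stop()

    asyncio.run(main())
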
@@ -0,0 +1,97 @@
+ # Copyright 2022-2023 XProbe Inc.
+ # derived from copyright 1999-2021 Alibaba Group Holding Ltd.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from asyncio import StreamReader, StreamWriter
+ from typing import Dict, List, Union
+
+ import numpy as np
+
+ from ...serialization.aio import BUFFER_SIZES_NAME
+ from ...utils import lazy_import
+
+ cupy = lazy_import("cupy")
+ cudf = lazy_import("cudf")
+ rmm = lazy_import("rmm")
+
+ CUDA_CHUNK_SIZE = 16 * 1024**2
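+ # i.e. 16 MiB: CUDA buffers are staged through host memory in chunks of
+ # this size (see write_buffers/read_buffers below).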
+
+
+ def _convert_to_cupy_ndarray(
+     cuda_buffer: Union["cupy.ndarray", "rmm.DeviceBuffer"]  # type: ignore
+ ) -> "cupy.ndarray":  # type: ignore
+     if isinstance(cuda_buffer, cupy.ndarray):
+         return cuda_buffer
+
+     size = cuda_buffer.nbytes
+     data = cuda_buffer.__cuda_array_interface__["data"][0]
+     memory = cupy.cuda.UnownedMemory(data, size, cuda_buffer)
+     ptr = cupy.cuda.MemoryPointer(memory, 0)
+     return cupy.ndarray(shape=size, dtype="u1", memptr=ptr)
+
+
+ def write_buffers(writer: StreamWriter, buffers: List):
+     def _write_cuda_buffer(cuda_buffer: Union["cupy.ndarray", "rmm.DeviceBuffer"]):  # type: ignore
+         # convert cuda buffer to cupy ndarray
+         cuda_buffer = _convert_to_cupy_ndarray(cuda_buffer)
+
+         chunk_size = CUDA_CHUNK_SIZE
+         offset = 0
+         nbytes = cuda_buffer.nbytes
+         while offset < nbytes:
+             size = chunk_size if (offset + chunk_size) < nbytes else nbytes - offset
+             # slice on cupy ndarray
+             chunk_buffer = cuda_buffer[offset : offset + size]
+             # `get` copies the chunk into a numpy ndarray; write its data,
+             # which is a memoryview, into the writer
+             writer.write(chunk_buffer.get().data)
+             offset += size
+
+     for buffer in buffers:
+         if hasattr(buffer, "__cuda_array_interface__"):
+             # GPU buffer
+             _write_cuda_buffer(buffer)
+         else:
+             writer.write(buffer)
+
+
+ async def read_buffers(header: List[Dict], reader: StreamReader):
+     is_cuda_buffers = header[0].get("is_cuda_buffers")
+     buffer_sizes = header[0].pop(BUFFER_SIZES_NAME)
+
+     buffers = []
+     for is_cuda_buffer, buf_size in zip(is_cuda_buffers, buffer_sizes):
+         if is_cuda_buffer:  # pragma: no cover
+             if buf_size == 0:
+                 # uniformly use rmm.DeviceBuffer for CUDA deserialization,
+                 # even for empty buffers
+                 buffers.append(rmm.DeviceBuffer(size=buf_size))
+             else:
+                 buffer = rmm.DeviceBuffer(size=buf_size)
+                 arr = _convert_to_cupy_ndarray(buffer)
+                 offset = 0
+                 chunk_size = CUDA_CHUNK_SIZE
+                 while offset < buf_size:
+                     read_size = (
+                         chunk_size
+                         if (offset + chunk_size) < buf_size
+                         else buf_size - offset
+                     )
+                     content = await reader.readexactly(read_size)
+                     chunk_arr = np.frombuffer(content, dtype="u1")
+                     arr[offset : offset + len(content)].set(chunk_arr)
+                     offset += read_size
+                 buffers.append(buffer)
+         else:
+             buffers.append(await reader.readexactly(buf_size))
+     return buffers
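
A minimal host-memory round trip through these helpers over an asyncio stream, as a sketch: the hand-built header mimics the shape the aio serializer produces (a list whose first element maps BUFFER_SIZES_NAME to per-buffer sizes), and only the CPU path is exercised:

    import asyncio

    from xoscar.backends.communication.utils import read_buffers, write_buffers
    from xoscar.serialization.aio import BUFFER_SIZES_NAME

    async def main():
        async def handle(reader, writer):
            # One non-CUDA buffer of four bytes is expected on this connection.
            header = [{"is_cuda_buffers": [False], BUFFER_SIZES_NAME: [4]}]
            print(await read_buffers(header, reader))  # [b'ping']
            writer.close()

        server = await asyncio.start_server(handle, "127.0.0.1", 8765)
        reader, writer = await asyncio.open_connection("127.0.0.1", 8765)
        write_buffers(writer, [b"ping"])
        await writer.drain()
        await asyncio.sleep(0.1)  # let the handler finish reading
        writer.close()
        server.close()
        await server.wait_closed()

    asyncio.run(main())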