xoscar 0.9.0__cp312-cp312-macosx_10_13_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xoscar/__init__.py +61 -0
- xoscar/_utils.cpython-312-darwin.so +0 -0
- xoscar/_utils.pxd +36 -0
- xoscar/_utils.pyx +246 -0
- xoscar/_version.py +693 -0
- xoscar/aio/__init__.py +16 -0
- xoscar/aio/base.py +86 -0
- xoscar/aio/file.py +59 -0
- xoscar/aio/lru.py +228 -0
- xoscar/aio/parallelism.py +39 -0
- xoscar/api.py +527 -0
- xoscar/backend.py +67 -0
- xoscar/backends/__init__.py +14 -0
- xoscar/backends/allocate_strategy.py +160 -0
- xoscar/backends/communication/__init__.py +30 -0
- xoscar/backends/communication/base.py +315 -0
- xoscar/backends/communication/core.py +69 -0
- xoscar/backends/communication/dummy.py +253 -0
- xoscar/backends/communication/errors.py +20 -0
- xoscar/backends/communication/socket.py +444 -0
- xoscar/backends/communication/ucx.py +538 -0
- xoscar/backends/communication/utils.py +97 -0
- xoscar/backends/config.py +157 -0
- xoscar/backends/context.py +437 -0
- xoscar/backends/core.py +352 -0
- xoscar/backends/indigen/__init__.py +16 -0
- xoscar/backends/indigen/__main__.py +19 -0
- xoscar/backends/indigen/backend.py +51 -0
- xoscar/backends/indigen/driver.py +26 -0
- xoscar/backends/indigen/fate_sharing.py +221 -0
- xoscar/backends/indigen/pool.py +515 -0
- xoscar/backends/indigen/shared_memory.py +548 -0
- xoscar/backends/message.cpython-312-darwin.so +0 -0
- xoscar/backends/message.pyi +255 -0
- xoscar/backends/message.pyx +646 -0
- xoscar/backends/pool.py +1630 -0
- xoscar/backends/router.py +285 -0
- xoscar/backends/test/__init__.py +16 -0
- xoscar/backends/test/backend.py +38 -0
- xoscar/backends/test/pool.py +233 -0
- xoscar/batch.py +256 -0
- xoscar/collective/__init__.py +27 -0
- xoscar/collective/backend/__init__.py +13 -0
- xoscar/collective/backend/nccl_backend.py +160 -0
- xoscar/collective/common.py +102 -0
- xoscar/collective/core.py +737 -0
- xoscar/collective/process_group.py +687 -0
- xoscar/collective/utils.py +41 -0
- xoscar/collective/xoscar_pygloo.cpython-312-darwin.so +0 -0
- xoscar/collective/xoscar_pygloo.pyi +239 -0
- xoscar/constants.py +23 -0
- xoscar/context.cpython-312-darwin.so +0 -0
- xoscar/context.pxd +21 -0
- xoscar/context.pyx +368 -0
- xoscar/core.cpython-312-darwin.so +0 -0
- xoscar/core.pxd +51 -0
- xoscar/core.pyx +664 -0
- xoscar/debug.py +188 -0
- xoscar/driver.py +42 -0
- xoscar/errors.py +63 -0
- xoscar/libcpp.pxd +31 -0
- xoscar/metrics/__init__.py +21 -0
- xoscar/metrics/api.py +288 -0
- xoscar/metrics/backends/__init__.py +13 -0
- xoscar/metrics/backends/console/__init__.py +13 -0
- xoscar/metrics/backends/console/console_metric.py +82 -0
- xoscar/metrics/backends/metric.py +149 -0
- xoscar/metrics/backends/prometheus/__init__.py +13 -0
- xoscar/metrics/backends/prometheus/prometheus_metric.py +70 -0
- xoscar/nvutils.py +717 -0
- xoscar/profiling.py +260 -0
- xoscar/serialization/__init__.py +20 -0
- xoscar/serialization/aio.py +141 -0
- xoscar/serialization/core.cpython-312-darwin.so +0 -0
- xoscar/serialization/core.pxd +28 -0
- xoscar/serialization/core.pyi +57 -0
- xoscar/serialization/core.pyx +944 -0
- xoscar/serialization/cuda.py +111 -0
- xoscar/serialization/exception.py +48 -0
- xoscar/serialization/mlx.py +67 -0
- xoscar/serialization/numpy.py +82 -0
- xoscar/serialization/pyfury.py +37 -0
- xoscar/serialization/scipy.py +72 -0
- xoscar/serialization/torch.py +180 -0
- xoscar/utils.py +522 -0
- xoscar/virtualenv/__init__.py +34 -0
- xoscar/virtualenv/core.py +268 -0
- xoscar/virtualenv/platform.py +56 -0
- xoscar/virtualenv/utils.py +100 -0
- xoscar/virtualenv/uv.py +321 -0
- xoscar-0.9.0.dist-info/METADATA +230 -0
- xoscar-0.9.0.dist-info/RECORD +94 -0
- xoscar-0.9.0.dist-info/WHEEL +6 -0
- xoscar-0.9.0.dist-info/top_level.txt +2 -0
|
@@ -0,0 +1,538 @@
|
|
|
1
|
+
# Copyright 2022-2023 XProbe Inc.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from __future__ import annotations
|
|
16
|
+
|
|
17
|
+
import asyncio
|
|
18
|
+
import concurrent.futures as futures
|
|
19
|
+
import functools
|
|
20
|
+
import logging
|
|
21
|
+
import os
|
|
22
|
+
import weakref
|
|
23
|
+
from typing import Any, Callable, Coroutine, Dict, List, Optional, Tuple, Type
|
|
24
|
+
|
|
25
|
+
import cloudpickle
|
|
26
|
+
import numpy as np
|
|
27
|
+
|
|
28
|
+
from ...nvutils import get_cuda_context, get_index_and_uuid
|
|
29
|
+
from ...serialization import deserialize
|
|
30
|
+
from ...serialization.aio import BUFFER_SIZES_NAME, AioSerializer, get_header_length
|
|
31
|
+
from ...utils import classproperty, implements, is_cuda_buffer, is_v6_ip, lazy_import
|
|
32
|
+
from ..message import _MessageBase
|
|
33
|
+
from .base import Channel, ChannelType, Client, Server
|
|
34
|
+
from .core import register_client, register_server
|
|
35
|
+
from .errors import ChannelClosed
|
|
36
|
+
|
|
37
|
+
ucp = lazy_import("ucxx")
|
|
38
|
+
numba_cuda = lazy_import("numba.cuda")
|
|
39
|
+
rmm = lazy_import("rmm")
|
|
40
|
+
|
|
41
|
+
# Explanatory text appended to the CUDA-context warnings that
# ``UCXInitializer.init`` logs when a context already exists or sits on an
# unexpected device.
_warning_suffix = (
    "This is often the result of a CUDA-enabled library calling a CUDA runtime function before "
    "spawning worker processes. Please make sure any such function calls don't happen "
    "at import time or in the global scope of a program."
)


# Module-level logger shared by the UCX transport classes below.
logger = logging.getLogger(__name__)
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def synchronize_stream(stream: int = 0):
    """Block until all work queued on the given CUDA stream has finished.

    ``stream`` is a raw CUDA stream handle; callers in this module pass ``0``
    to synchronize the default stream before handing device memory to UCX.
    """
    context = numba_cuda.current_context()
    driver = numba_cuda.driver
    # Wrap the raw handle in a numba Stream bound to the current context.
    raw_handle = driver.drvapi.cu_stream(stream)
    numba_stream = driver.Stream(context, raw_handle, None)
    numba_stream.synchronize()  # type: ignore
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
class UCXInitializer:
    """One-shot, process-wide initialization of the UCX (``ucxx``) runtime.

    ``init`` does nothing once it has succeeded, until ``reset`` is called.
    """

    _inited = False

    @staticmethod
    def _get_options(ucx_config: dict) -> Tuple[dict, dict]:
        """
        Get options and envs from ucx options in oscar config.

        Returns ``(options, envs)``:

        * ``options`` -- high-level UCX options keyed without the ``UCX_``
          prefix (e.g. ``"TLS"``, ``"SOCKADDR_TLS_PRIORITY"``).
        * ``envs`` -- ``UCX_*`` environment variables translated from
          ``ucx_config["environment"]``; entries that clash with ``options``
          or with the real process environment are dropped (and logged).
        """
        options = dict()
        envs = dict()

        # if any of the flags are set, as long as they are not Null/None,
        # we assume we should configure basic TLS settings for UCX, otherwise we
        # leave UCX to its default configuration
        if any(ucx_config.get(name) for name in ["tcp", "nvlink", "infiniband"]):
            if ucx_config.get("rdmacm"):  # pragma: no cover
                tls = "tcp"
                tls_priority = "rdmacm"
            else:
                tls = "tcp"
                tls_priority = "tcp"

            # CUDA COPY can optionally be used with ucx -- we rely on the user
            # to define when messages will include CUDA objects. Note:
            # defining only the Infiniband flag will not enable cuda_copy
            if any(
                ucx_config.get(name) for name in ["nvlink", "cuda-copy"]
            ):  # pragma: no cover
                tls += ",cuda_copy"

            if ucx_config.get("infiniband"):  # pragma: no cover
                tls = "ib," + tls
            if ucx_config.get("nvlink"):  # pragma: no cover
                tls += ",cuda_ipc"

            options["TLS"] = tls
            options["SOCKADDR_TLS_PRIORITY"] = tls_priority
        elif "UCX_TLS" in os.environ:  # pragma: no cover
            options["TLS"] = os.environ["UCX_TLS"]

        for k, v in ucx_config.get("environment", dict()).items():  # pragma: no cover
            # {"some-name": value} is translated to {"UCX_SOME_NAME": value}
            key = f'UCX_{"_".join(s.upper() for s in k.split("-"))}'
            opt_key = key[4:]
            if opt_key in options:
                logger.warning(
                    f"Ignoring {k}={v} (key={key}) in ucx.environment, "
                    f"preferring {opt_key}={options[opt_key]} "
                    "from high level options"
                )
            elif key in os.environ:
                # This is only info because setting UCX configuration via
                # environment variables is a reasonably common approach
                logger.info(
                    f"Ignoring {k}={v} (key={key}) in ucx.environment, "
                    f"preferring {key}={os.environ[key]} from external environment"
                )
            else:
                envs[key] = v

        return options, envs

    @staticmethod
    def _requires_cuda_context(ucx_config: dict, ucx_tls: str) -> bool:
        """Return whether a CUDA context must be created before UCX init.

        True when the config explicitly asks for it via ``create-cuda-context``
        (the historically misspelled key ``create-cuda-contex`` is still
        honoured for backward compatibility), or when the effective
        ``UCX_TLS`` value mentions a CUDA transport that is not excluded
        with ``^cuda``.
        """
        explicitly_requested = (
            ucx_config.get("create-cuda-context") is True
            or ucx_config.get("create-cuda-contex") is True
        )
        # This is not foolproof, if UCX_TLS=all we might require CUDA
        # depending on configuration of UCX, but this is better than
        # nothing
        tls_uses_cuda = "cuda" in ucx_tls and "^cuda" not in ucx_tls
        return explicitly_requested or tls_uses_cuda

    @staticmethod
    def init(ucx_config: dict):
        """Initialize UCX for this process using ``ucx_config`` (once).

        If CUDA transports are requested, a CUDA context is created first
        (warnings are logged if one already exists or lands on a device other
        than the first entry of ``CUDA_VISIBLE_DEVICES``).

        Raises:
            ImportError: CUDA is required but numba is unavailable.
        """
        if UCXInitializer._inited:
            return

        options, envs = UCXInitializer._get_options(ucx_config)

        # We ensure the CUDA context is created before initializing UCX. This can't
        # be safely handled externally because communications start before
        # preload scripts run.
        # Precedence:
        # 1. external environment
        # 2. ucx_config (high level settings passed to ucp.init)
        # 3. ucx_environment (low level settings equivalent to environment variables)
        ucx_tls = os.environ.get("UCX_TLS", options.get("TLS", envs.get("UCX_TLS", "")))
        if UCXInitializer._requires_cuda_context(ucx_config, ucx_tls):
            if numba_cuda is None:  # pragma: no cover
                raise ImportError(
                    "CUDA support with UCX requires Numba for context management"
                )

            pre_existing_cuda_context = get_cuda_context()
            if pre_existing_cuda_context.has_context:
                dev = pre_existing_cuda_context.device_info
                assert dev is not None
                logger.warning(
                    f"A CUDA context for device {dev.device_index} ({str(dev.uuid)}) "
                    f"already exists on process ID {os.getpid()}. {_warning_suffix}"
                )

            numba_cuda.current_context()

            cuda_context_created = get_cuda_context()
            cuda_visible_device = get_index_and_uuid(
                os.environ.get("CUDA_VISIBLE_DEVICES", "0").split(",")[0]
            )
            if (
                cuda_context_created.has_context
                and cuda_context_created.device_info.uuid != cuda_visible_device.uuid  # type: ignore
            ):  # pragma: no cover
                cuda_context_created_dev = cuda_context_created.device_info
                assert cuda_context_created_dev is not None
                logger.warning(
                    f"Worker with process ID {os.getpid()} should have a CUDA context assigned to device "
                    f"{cuda_visible_device.device_index} ({str(cuda_visible_device.uuid)}), "  # type: ignore
                    f"but instead the CUDA context is on device {cuda_context_created_dev.device_index} "
                    f"({str(cuda_context_created_dev.uuid)}). {_warning_suffix}"
                )

        # Expose the low-level ``envs`` to UCX only for the duration of init.
        original_environ = os.environ
        new_environ = os.environ.copy()
        new_environ.update(envs)
        os.environ = new_environ  # type: ignore
        try:
            # Pass the high-level options computed above (previously they were
            # computed but discarded). ``env_takes_precedence=True`` keeps the
            # external environment first, as documented in the precedence list.
            ucp.init(options=options, env_takes_precedence=True)
        finally:
            os.environ = original_environ

        UCXInitializer._inited = True

    @staticmethod
    def reset():
        """Tear down UCX so that a later ``init`` call initializes it again."""
        ucp.reset()
        UCXInitializer._inited = False
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
class UCXChannel(Channel):
    """Remote channel that carries serialized messages over a UCX endpoint.

    Multi-buffer sends run under ``_send_lock`` and receives under
    ``_recv_lock``, so the buffers of one message are not interleaved with
    those of another concurrent caller on the same endpoint.
    """

    __slots__ = (
        "ucp_endpoint",
        "_closed",
        "_has_close_callback",
        "_send_lock",
        "_recv_lock",
        "__weakref__",
    )

    name = "ucx"

    def __init__(
        self,
        ucp_endpoint: "ucp.Endpoint",  # type: ignore
        local_address: str | None = None,
        dest_address: str | None = None,
        compression: str | None = None,
    ):
        super().__init__(
            local_address=local_address,
            dest_address=dest_address,
            compression=compression,
        )
        self.ucp_endpoint = ucp_endpoint

        self._send_lock = asyncio.Lock()
        self._recv_lock = asyncio.Lock()

        # recv()/recv_buffers() read ``_closed`` whether or not the endpoint
        # supports close callbacks, and ``__slots__`` provides no default, so
        # it must be initialized unconditionally (and before registering the
        # callback, so a callback fired during registration is not overwritten).
        self._closed = False

        # When the UCX endpoint closes or errors the registered callback
        # is called.
        if hasattr(self.ucp_endpoint, "set_close_callback"):
            # Only a weak reference is captured so the endpoint's callback
            # does not keep this channel alive.
            ref = weakref.ref(self)
            self.ucp_endpoint.set_close_callback(
                functools.partial(UCXChannel._close_channel, ref)
            )
            self._has_close_callback = True
        else:  # pragma: no cover
            self._has_close_callback = False

    @staticmethod
    def _close_channel(channel_ref: weakref.ReferenceType):
        """Endpoint close callback: mark the channel closed if it still exists."""
        channel = channel_ref()
        if channel is not None:
            channel._closed = True

    async def _serialize(self, message: Any) -> List[bytes]:
        """Serialize ``message`` into the list of buffers to put on the wire."""
        compress = self.compression or 0
        serializer = AioSerializer(message, compress=compress)
        return await serializer.run()

    @property
    @implements(Channel.type)
    def type(self) -> int:
        return ChannelType.remote

    @implements(Channel.send)
    async def send(self, message: Any):
        """Serialize and send ``message``.

        Raises:
            ChannelClosed: the channel is already closed or the send failed.
        """
        if self.closed:
            raise ChannelClosed("UCX Endpoint is closed, unable to send message")

        buffers = await self._serialize(message)
        return await self.send_buffers(buffers)

    @implements(Channel.recv)
    async def recv(self):
        """Receive and deserialize one message.

        Wire layout read here: an 11-byte info block (from which the header
        length is derived), the pickled header, then one buffer per entry of
        the header's buffer-size list; CUDA buffers are received into
        ``rmm.DeviceBuffer`` objects.

        Raises:
            ChannelClosed: receiving failed while the channel was open (the
                endpoint is aborted first).
            EOFError: receiving failed after the channel was already closed.
        """
        async with self._recv_lock:
            try:
                info_buffer = np.empty(11, dtype="u1").data
                await self.ucp_endpoint.recv(info_buffer)
                head_length = get_header_length(info_buffer)
                header_buffer = np.empty(head_length, dtype="u1").data
                await self.ucp_endpoint.recv(header_buffer)
                # NOTE(review): the header is unpickled exactly as received
                # from the peer; only connect to trusted peers.
                header = cloudpickle.loads(header_buffer)

                is_cuda_buffers = header[0].get("is_cuda_buffers")
                buffer_sizes = header[0].pop(BUFFER_SIZES_NAME)

                buffers = []
                for is_cuda_buffer, buf_size in zip(is_cuda_buffers, buffer_sizes):
                    if buf_size == 0:  # pragma: no cover
                        buffers.append(bytes())
                    elif is_cuda_buffer:
                        cuda_buffer = rmm.DeviceBuffer(size=buf_size)
                        await self.ucp_endpoint.recv(cuda_buffer)
                        buffers.append(cuda_buffer)
                    else:
                        buffer = np.empty(buf_size, dtype="u1").data
                        await self.ucp_endpoint.recv(buffer)
                        buffers.append(buffer)
            except BaseException as e:
                if not self._closed:
                    # In addition to UCX exceptions, may be CancelledError or another
                    # "low-level" exception. The only safe thing to do is to abort.
                    self.abort()
                    raise ChannelClosed(
                        f"Connection closed by writer.\nInner exception: {e!r}"
                    ) from e
                else:
                    raise EOFError("Server closed already")
        return deserialize(header, buffers)

    async def send_buffers(self, buffers: list, meta: Optional[_MessageBase] = None):
        """Send ``buffers``, preceded by the serialized ``meta`` if given.

        Raises:
            ChannelClosed: UCX reported an error while writing (the endpoint
                is aborted first).
        """
        try:
            # It is necessary to first synchronize the default stream before start
            # sending We synchronize the default stream because UCX is not
            # stream-ordered and syncing the default stream will wait for other
            # non-blocking CUDA streams. Note this is only sufficient if the memory
            # being sent is not currently in use on non-blocking CUDA streams.
            if any(is_cuda_buffer(buf) for buf in buffers):
                # has GPU buffer
                synchronize_stream(0)

            meta_buffers = None
            if meta:
                meta_buffers = await self._serialize(meta)

            async with self._send_lock:
                if meta_buffers:
                    for buf in meta_buffers:
                        await self.ucp_endpoint.send(buf)
                for buffer in buffers:
                    await self.ucp_endpoint.send(buffer)
        except ucp.exceptions.UCXError as e:  # pragma: no cover
            self.abort()
            raise ChannelClosed("While writing, the connection was closed") from e

    async def recv_buffers(self, buffers: list):
        """Receive into each of the pre-allocated ``buffers`` in order.

        Raises:
            ChannelClosed: receiving failed while the channel was open.
            EOFError: receiving failed after the channel was already closed.
        """
        async with self._recv_lock:
            try:
                for buffer in buffers:
                    await self.ucp_endpoint.recv(buffer)
            except BaseException as e:  # pragma: no cover
                if not self._closed:
                    # In addition to UCX exceptions, may be CancelledError or another
                    # "low-level" exception. The only safe thing to do is to abort.
                    self.abort()
                    raise ChannelClosed(
                        f"Connection closed by writer.\nInner exception: {e!r}"
                    ) from e
                else:
                    raise EOFError("Server closed already")

    def abort(self):
        """Mark the channel closed and abort the endpoint immediately."""
        self._closed = True
        if self.ucp_endpoint is not None:
            self.ucp_endpoint.abort()
            self.ucp_endpoint = None

    @implements(Channel.close)
    async def close(self):
        """Gracefully close the endpoint, then abort and release it.

        The endpoint is captured locally because ``self.ucp_endpoint`` may be
        cleared by a concurrent ``abort()`` while ``close()`` is awaiting; the
        abort/release runs in ``finally`` so it also happens when the graceful
        close fails or is cancelled.
        """
        self._closed = True
        endpoint = self.ucp_endpoint
        if endpoint is None:
            return
        try:
            await endpoint.close()
        finally:
            # abort
            endpoint.abort()
            self.ucp_endpoint = None

    @property
    @implements(Channel.closed)
    def closed(self):
        if self._has_close_callback:
            # The self._closed flag is separate from the endpoint's lifetime, even when
            # the endpoint has closed or errored, there may be messages on its buffer
            # still to be received, even though sending is not possible anymore.
            return self._closed
        else:
            return self.ucp_endpoint is None
|
|
362
|
+
|
|
363
|
+
|
|
364
|
+
@register_server
class UCXServer(Server):
    """Server side of the UCX transport.

    Wraps a ``ucp`` listener; every accepted endpoint is wrapped in a
    :class:`UCXChannel` and handed to ``channel_handler``.
    """

    __slots__ = "host", "port", "_ucp_listener", "_channels", "_closed"

    scheme = "ucx"

    _ucp_listener: "ucp.Listener"  # type: ignore
    _channels: set[UCXChannel]

    def __init__(
        self,
        host: str,
        port: int,
        ucp_listener: "ucp.Listener",  # type: ignore
        channel_handler: Callable[[Channel], Coroutine] | None = None,
    ):
        super().__init__(f"{UCXServer.scheme}://{host}:{port}", channel_handler)
        self.host = host
        self.port = port
        self._ucp_listener = ucp_listener
        # Channels currently being served; closed by stop().
        self._channels = set()
        # Set by stop(); awaited by join().
        self._closed = asyncio.Event()

    @classproperty
    @implements(Server.client_type)
    def client_type(self) -> Type["Client"]:
        return UCXClient

    @property
    @implements(Server.channel_type)
    def channel_type(self) -> int:
        return ChannelType.remote

    @staticmethod
    async def create(config: Dict) -> "Server":
        """Create a UCX server from ``config``.

        The listen address comes either from ``config["address"]`` (optionally
        prefixed with ``ucx://``) or from ``config["host"]`` and
        ``config["port"]``. If the port is falsy, the port picked by the
        listener is used. ``config["handle_channel"]`` becomes the channel
        handler and ``config["ucx"]`` (if any) is passed to
        :meth:`UCXInitializer.init`.
        """
        config = config.copy()
        if "address" in config:
            address = config.pop("address")
            prefix = f"{UCXServer.scheme}://"
            if address.startswith(prefix):
                address = address[len(prefix) :]
            host, port = address.rsplit(":", 1)
            port = int(port)
        else:
            host = config.pop("host")
            port = int(config.pop("port"))
        _host = host
        if config.pop("listen_elastic_ip", False):
            # The Actor.address will be announced to clients but is not on our
            # host, so we cannot actually listen on it; keep ``host`` (and
            # hence the server's advertised address) untouched.
            # NOTE(review): ``_host`` only affects the local_address given to
            # channels; ``ucp.create_listener`` below is not passed a host.
            if is_v6_ip(host):
                _host = "::"
            else:
                _host = "0.0.0.0"

        handle_channel = config.pop("handle_channel")

        # init
        UCXInitializer.init(config.get("ucx", dict()))

        async def serve_forever(client_ucp_endpoint: "ucp.Endpoint"):  # type: ignore
            # ``server`` and ``port`` are resolved when a connection arrives,
            # i.e. after both have been (re)assigned below.
            try:
                await server.on_connected(
                    client_ucp_endpoint, local_address="%s:%d" % (_host, port)
                )
            except ChannelClosed:  # pragma: no cover
                logger.exception("Connection closed before handshake completed")
                return

        ucp_listener = ucp.create_listener(serve_forever, port=port)

        # get port of the ucp listener if not specified
        if not port:
            port = ucp_listener.port

        server = UCXServer(host, port, ucp_listener, channel_handler=handle_channel)
        return server

    @classmethod
    def parse_config(cls, config: dict) -> dict:
        return config

    @implements(Server.start)
    async def start(self):
        # The listener is already accepting connections once create() returns.
        pass

    @implements(Server.join)
    async def join(self, timeout=None):
        """Wait until stop() has been called, or ``timeout`` seconds elapse."""
        wait_coro = self._closed.wait()
        try:
            await asyncio.wait_for(wait_coro, timeout=timeout)
        except (futures.TimeoutError, asyncio.TimeoutError):
            pass

    @implements(Server.on_connected)
    async def on_connected(self, *args, **kwargs):
        """Wrap an accepted endpoint in a channel and run the handler on it.

        The channel is closed (if still open) and forgotten once the handler
        returns or raises.

        Raises:
            TypeError: unexpected keyword arguments were given.
        """
        (ucp_endpoint,) = args
        local_address = kwargs.pop("local_address", None)
        dest_address = kwargs.pop("dest_address", None)
        if kwargs:  # pragma: no cover
            raise TypeError(
                f"{type(self).__name__} got unexpected "
                f'arguments: {",".join(kwargs)}'
            )
        channel = UCXChannel(
            ucp_endpoint, local_address=local_address, dest_address=dest_address
        )
        self._channels.add(channel)
        # handle over channel to some handlers
        try:
            await self.channel_handler(channel)
        finally:
            if not channel.closed:
                await channel.close()
            # Remove channel if channel exit
            self._channels.discard(channel)
            logger.debug("Channel exit: %s", channel.info)

    @implements(Server.stop)
    async def stop(self):
        """Close the listener and all open channels, then wake up join().

        Safe to call more than once: the listener is only closed if a previous
        call has not already released it.
        """
        listener = self._ucp_listener
        if listener is not None:
            listener.close()
        # close all channels
        await asyncio.gather(
            *(channel.close() for channel in self._channels if not channel.closed)
        )
        self._channels.clear()
        self._ucp_listener = None
        self._closed.set()

    @property
    @implements(Server.stopped)
    def stopped(self) -> bool:
        return self._ucp_listener is None
|
|
498
|
+
|
|
499
|
+
|
|
500
|
+
@register_client
class UCXClient(Client):
    """Client side of the UCX transport; owns a single :class:`UCXChannel`."""

    __slots__ = ()

    scheme = UCXServer.scheme
    channel: UCXChannel

    @classmethod
    def parse_config(cls, config: dict) -> dict:
        # The UCX client needs no config translation.
        return config

    @staticmethod
    @implements(Client.connect)
    async def connect(
        dest_address: str, local_address: str | None = None, **kwargs
    ) -> "Client":
        """Open a UCX endpoint to ``dest_address`` (``[ucx://]host:port``).

        ``kwargs["config"]["ucx"]``, when present, is used to initialize UCX.

        Raises:
            ChannelClosed: UCX failed to create the endpoint.
        """
        scheme_prefix = f"{UCXClient.scheme}://"
        if dest_address.startswith(scheme_prefix):
            dest_address = dest_address[len(scheme_prefix) :]
        host, port_text = dest_address.rsplit(":", 1)
        port = int(port_text)

        # ``kwargs`` is a fresh dict local to this call, so reading the
        # config from it directly has no effect on the caller.
        full_config = kwargs.get("config", dict())
        UCXInitializer.init(full_config.get("ucx", dict()))

        try:
            endpoint = await ucp.create_endpoint(host, port)
        except ucp.exceptions.UCXError as exc:  # pragma: no cover
            raise ChannelClosed(
                f"Connection closed before handshake completed, "
                f"local address: {local_address}, dest address: {dest_address}"
            ) from exc

        new_channel = UCXChannel(
            endpoint, local_address=local_address, dest_address=dest_address
        )
        return UCXClient(local_address, dest_address, new_channel)

    async def send_buffers(self, buffers: list, meta: _MessageBase):
        """Forward ``buffers`` (and ``meta``) to the underlying channel."""
        return await self.channel.send_buffers(buffers, meta)
|
|
@@ -0,0 +1,97 @@
|
|
|
1
|
+
# Copyright 2022-2023 XProbe Inc.
|
|
2
|
+
# derived from copyright 1999-2021 Alibaba Group Holding Ltd.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from asyncio import StreamReader, StreamWriter
|
|
17
|
+
from typing import Dict, List, Union
|
|
18
|
+
|
|
19
|
+
import numpy as np
|
|
20
|
+
|
|
21
|
+
from ...serialization.aio import BUFFER_SIZES_NAME
|
|
22
|
+
from ...utils import lazy_import
|
|
23
|
+
|
|
24
|
+
cupy = lazy_import("cupy")
|
|
25
|
+
cudf = lazy_import("cudf")
|
|
26
|
+
rmm = lazy_import("rmm")
|
|
27
|
+
|
|
28
|
+
# Chunk size in bytes (16 MiB) used by write_buffers / read_buffers when
# moving CUDA buffer contents through host memory piece by piece.
CUDA_CHUNK_SIZE = 16 * 1024**2
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def _convert_to_cupy_ndarray(
    cuda_buffer: Union["cupy.ndarray", "rmm.DeviceBuffer"]  # type: ignore
) -> "cupy.ndarray":  # type: ignore
    """Return a 1-D uint8 cupy array viewing ``cuda_buffer``'s device memory.

    A cupy array is returned unchanged. Any other object is wrapped without
    copying, using the pointer from its ``__cuda_array_interface__`` and its
    ``nbytes``; the original object is passed to ``UnownedMemory`` as the
    owner of the allocation.
    """
    if not isinstance(cuda_buffer, cupy.ndarray):
        byte_count = cuda_buffer.nbytes
        device_ptr = cuda_buffer.__cuda_array_interface__["data"][0]
        unowned = cupy.cuda.UnownedMemory(device_ptr, byte_count, cuda_buffer)
        return cupy.ndarray(
            shape=byte_count,
            dtype="u1",
            memptr=cupy.cuda.MemoryPointer(unowned, 0),
        )
    return cuda_buffer
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def write_buffers(writer: StreamWriter, buffers: List):
    """Write every buffer in ``buffers`` to ``writer``, in order.

    Host buffers are passed to ``writer.write`` directly. Buffers exposing
    ``__cuda_array_interface__`` are copied to host memory and written in
    chunks of at most ``CUDA_CHUNK_SIZE`` bytes.
    """

    def _write_cuda_buffer(cuda_buffer: Union["cupy.ndarray", "rmm.DeviceBuffer"]):  # type: ignore
        # convert cuda buffer to cupy ndarray
        cuda_array = _convert_to_cupy_ndarray(cuda_buffer)
        # Size comes from this function's own argument, not from the
        # enclosing loop variable (the original read ``buffer.nbytes``).
        nbytes = cuda_array.nbytes
        for offset in range(0, nbytes, CUDA_CHUNK_SIZE):
            # slice on cupy ndarray (the final slice is clamped at nbytes)
            chunk_buffer = cuda_array[offset : offset + CUDA_CHUNK_SIZE]
            # `get` will return numpy ndarray,
            # write its data which is a memoryview into writer
            writer.write(chunk_buffer.get().data)

    for buffer in buffers:
        if hasattr(buffer, "__cuda_array_interface__"):
            # GPU buffer
            _write_cuda_buffer(buffer)
        else:
            writer.write(buffer)
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
async def read_buffers(header: Dict, reader: StreamReader):
    """Read the payload buffers described by ``header`` from ``reader``.

    ``header[0]`` supplies ``"is_cuda_buffers"`` (one flag per buffer) and the
    list of buffer sizes under ``BUFFER_SIZES_NAME``; the latter key is popped
    from ``header[0]``. Host buffers are returned as ``bytes``; CUDA buffers
    are returned as ``rmm.DeviceBuffer`` objects filled in chunks of at most
    ``CUDA_CHUNK_SIZE`` bytes.
    """
    meta = header[0]
    cuda_flags = meta.get("is_cuda_buffers")
    sizes = meta.pop(BUFFER_SIZES_NAME)

    result = []
    for on_gpu, nbytes in zip(cuda_flags, sizes):
        if not on_gpu:
            result.append(await reader.readexactly(nbytes))
            continue

        # GPU path (not covered by CPU-only tests).
        # uniformly use rmm.DeviceBuffer for cuda's deserialization
        device_buffer = rmm.DeviceBuffer(size=nbytes)
        if nbytes != 0:
            target = _convert_to_cupy_ndarray(device_buffer)
            start = 0
            while start < nbytes:
                step = min(CUDA_CHUNK_SIZE, nbytes - start)
                payload = await reader.readexactly(step)
                host_chunk = np.frombuffer(payload, dtype="u1")
                target[start : start + len(payload)].set(host_chunk)
                start += step
        result.append(device_buffer)
    return result
|