xoscar 0.4.0__cp312-cp312-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of xoscar might be problematic. Click here for more details.

Files changed (82)
  1. xoscar/__init__.py +60 -0
  2. xoscar/_utils.cpython-312-darwin.so +0 -0
  3. xoscar/_utils.pxd +36 -0
  4. xoscar/_utils.pyx +241 -0
  5. xoscar/_version.py +693 -0
  6. xoscar/aio/__init__.py +16 -0
  7. xoscar/aio/base.py +86 -0
  8. xoscar/aio/file.py +59 -0
  9. xoscar/aio/lru.py +228 -0
  10. xoscar/aio/parallelism.py +39 -0
  11. xoscar/api.py +493 -0
  12. xoscar/backend.py +67 -0
  13. xoscar/backends/__init__.py +14 -0
  14. xoscar/backends/allocate_strategy.py +160 -0
  15. xoscar/backends/communication/__init__.py +30 -0
  16. xoscar/backends/communication/base.py +315 -0
  17. xoscar/backends/communication/core.py +69 -0
  18. xoscar/backends/communication/dummy.py +242 -0
  19. xoscar/backends/communication/errors.py +20 -0
  20. xoscar/backends/communication/socket.py +414 -0
  21. xoscar/backends/communication/ucx.py +531 -0
  22. xoscar/backends/communication/utils.py +97 -0
  23. xoscar/backends/config.py +145 -0
  24. xoscar/backends/context.py +404 -0
  25. xoscar/backends/core.py +193 -0
  26. xoscar/backends/indigen/__init__.py +16 -0
  27. xoscar/backends/indigen/backend.py +51 -0
  28. xoscar/backends/indigen/driver.py +26 -0
  29. xoscar/backends/indigen/pool.py +469 -0
  30. xoscar/backends/message.cpython-312-darwin.so +0 -0
  31. xoscar/backends/message.pyi +239 -0
  32. xoscar/backends/message.pyx +599 -0
  33. xoscar/backends/pool.py +1596 -0
  34. xoscar/backends/router.py +207 -0
  35. xoscar/backends/test/__init__.py +16 -0
  36. xoscar/backends/test/backend.py +38 -0
  37. xoscar/backends/test/pool.py +208 -0
  38. xoscar/batch.py +256 -0
  39. xoscar/collective/__init__.py +27 -0
  40. xoscar/collective/common.py +102 -0
  41. xoscar/collective/core.py +737 -0
  42. xoscar/collective/process_group.py +687 -0
  43. xoscar/collective/utils.py +41 -0
  44. xoscar/collective/xoscar_pygloo.cpython-312-darwin.so +0 -0
  45. xoscar/collective/xoscar_pygloo.pyi +239 -0
  46. xoscar/constants.py +21 -0
  47. xoscar/context.cpython-312-darwin.so +0 -0
  48. xoscar/context.pxd +21 -0
  49. xoscar/context.pyx +368 -0
  50. xoscar/core.cpython-312-darwin.so +0 -0
  51. xoscar/core.pxd +50 -0
  52. xoscar/core.pyx +658 -0
  53. xoscar/debug.py +188 -0
  54. xoscar/driver.py +42 -0
  55. xoscar/errors.py +63 -0
  56. xoscar/libcpp.pxd +31 -0
  57. xoscar/metrics/__init__.py +21 -0
  58. xoscar/metrics/api.py +288 -0
  59. xoscar/metrics/backends/__init__.py +13 -0
  60. xoscar/metrics/backends/console/__init__.py +13 -0
  61. xoscar/metrics/backends/console/console_metric.py +82 -0
  62. xoscar/metrics/backends/metric.py +149 -0
  63. xoscar/metrics/backends/prometheus/__init__.py +13 -0
  64. xoscar/metrics/backends/prometheus/prometheus_metric.py +70 -0
  65. xoscar/nvutils.py +717 -0
  66. xoscar/profiling.py +260 -0
  67. xoscar/serialization/__init__.py +20 -0
  68. xoscar/serialization/aio.py +138 -0
  69. xoscar/serialization/core.cpython-312-darwin.so +0 -0
  70. xoscar/serialization/core.pxd +28 -0
  71. xoscar/serialization/core.pyi +57 -0
  72. xoscar/serialization/core.pyx +944 -0
  73. xoscar/serialization/cuda.py +111 -0
  74. xoscar/serialization/exception.py +48 -0
  75. xoscar/serialization/numpy.py +82 -0
  76. xoscar/serialization/pyfury.py +37 -0
  77. xoscar/serialization/scipy.py +72 -0
  78. xoscar/utils.py +517 -0
  79. xoscar-0.4.0.dist-info/METADATA +223 -0
  80. xoscar-0.4.0.dist-info/RECORD +82 -0
  81. xoscar-0.4.0.dist-info/WHEEL +5 -0
  82. xoscar-0.4.0.dist-info/top_level.txt +2 -0
@@ -0,0 +1,242 @@
1
+ # Copyright 2022-2023 XProbe Inc.
2
+ # derived from copyright 1999-2021 Alibaba Group Holding Ltd.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from __future__ import annotations
17
+
18
+ import asyncio
19
+ import concurrent.futures as futures
20
+ import weakref
21
+ from typing import Any, Callable, Coroutine, Dict, Type
22
+ from urllib.parse import urlparse
23
+
24
+ from ...errors import ServerClosed
25
+ from ...utils import abc_type_require_weakref_slot, classproperty, implements
26
+ from .base import Channel, ChannelType, Client, Server
27
+ from .core import register_client, register_server
28
+ from .errors import ChannelClosed
29
+
30
+ DEFAULT_DUMMY_ADDRESS = "dummy://0"  # address DummyServer.create falls back to when the config has no "address"
31
+
32
+
33
class DummyChannel(Channel):
    """
    Channel for communications in same process.

    Messages travel through two ``asyncio.Queue`` objects: ``send`` puts
    onto the outgoing queue and ``recv`` awaits the incoming one. The
    ``closed`` event handed to the constructor is the single source of
    truth for the closed state of this channel.
    """

    __slots__ = "_in_queue", "_out_queue", "_closed"

    name = "dummy"

    def __init__(
        self,
        in_queue: asyncio.Queue,
        out_queue: asyncio.Queue,
        closed: asyncio.Event,
        local_address: str | None = None,
        dest_address: str | None = None,
        compression: str | None = None,
    ):
        """
        Parameters
        ----------
        in_queue
            Queue that ``recv`` reads messages from.
        out_queue
            Queue that ``send`` writes messages to.
        closed
            Event that is set once the channel is closed.
        local_address, dest_address, compression
            Forwarded unchanged to ``Channel.__init__``.
        """
        super().__init__(
            local_address=local_address,
            dest_address=dest_address,
            compression=compression,
        )
        self._in_queue = in_queue
        self._out_queue = out_queue
        self._closed = closed

    @property
    @implements(Channel.type)
    def type(self) -> int:
        return ChannelType.local

    @implements(Channel.send)
    async def send(self, message: Any):
        """Queue ``message`` for the peer; raise ``ChannelClosed`` if closed."""
        if self._closed.is_set():  # pragma: no cover
            raise ChannelClosed("Channel already closed, cannot send message")
        # put message directly into queue
        self._out_queue.put_nowait(message)

    @implements(Channel.recv)
    async def recv(self):
        """
        Wait for and return the next incoming message.

        Raises ``ChannelClosed`` if the channel is already closed when
        called. Returns ``None`` if waiting on the queue fails with a
        ``RuntimeError``.
        """
        if self._closed.is_set():  # pragma: no cover
            # The message used to say "write", which was copied from the
            # send path and misdescribed the failing operation.
            raise ChannelClosed("Channel already closed, cannot receive message")
        try:
            return await self._in_queue.get()
        except RuntimeError:
            # Best effort: the error is swallowed and ``None`` is returned.
            # NOTE(review): this happens whether or not the channel has
            # been closed -- confirm that is intended.
            return None

    @implements(Channel.close)
    async def close(self):
        """Mark the channel as closed by setting the closed event."""
        self._closed.set()

    @property
    @implements(Channel.closed)
    def closed(self) -> bool:
        return self._closed.is_set()
90
+
91
+
92
@register_server
class DummyServer(Server):
    """
    In-process server registered under the ``dummy`` scheme.

    Instances are tracked per address in a class-level
    ``WeakValueDictionary`` so that clients can look a server up by its
    address through ``get_instance``.
    """

    # Only the ``__weakref__`` slot is conditional. Without the inner
    # parentheses the conditional expression applied to the whole sum and
    # dropped every slot when ``abc_type_require_weakref_slot`` was false.
    __slots__ = ("_closed", "_channels", "_tasks") + (
        ("__weakref__",) if abc_type_require_weakref_slot else ()
    )

    # address -> live server; entries disappear when the server is collected
    _address_to_instances: weakref.WeakValueDictionary[str, "DummyServer"] = (
        weakref.WeakValueDictionary()
    )
    # channels handed to ``on_connected``; closed again in ``stop``
    _channels: list[Channel]
    # connection tasks appended by ``DummyClient.connect``; cancelled in ``stop``
    _tasks: list[asyncio.Task]
    scheme: str | None = "dummy"

    def __init__(
        self,
        address: str,
        channel_handler: Callable[[Channel], Coroutine] | None = None,
    ):
        super().__init__(address, channel_handler)
        self._closed = asyncio.Event()
        self._channels = []
        self._tasks = []

    @classmethod
    def get_instance(cls, address: str):
        """
        Return the server registered for ``address``.

        Raises ``KeyError`` when no live server is registered under it.
        """
        return cls._address_to_instances[address]

    @classproperty
    @implements(Server.client_type)
    def client_type(self) -> Type["Client"]:
        return DummyClient

    @property
    @implements(Server.channel_type)
    def channel_type(self) -> int:
        return ChannelType.local

    @staticmethod
    @implements(Server.create)
    async def create(config: Dict) -> "DummyServer":
        """
        Return the server for ``config["address"]``, creating it if needed.

        ``config`` must contain ``handle_channel`` and may contain
        ``address`` (defaults to ``DEFAULT_DUMMY_ADDRESS``). A registered
        server that is not stopped is reused; otherwise a new one is
        created and registered.

        Raises ``ValueError`` for a non-dummy address scheme and
        ``TypeError`` for any other key left in ``config``.
        """
        config = config.copy()
        address = config.pop("address", DEFAULT_DUMMY_ADDRESS)
        handle_channel = config.pop("handle_channel")
        if urlparse(address).scheme != DummyServer.scheme:  # pragma: no cover
            raise ValueError(
                f"Address for DummyServer "
                f'should be starts with "dummy://", '
                f"got {address}"
            )
        if config:  # pragma: no cover
            raise TypeError(
                f"Creating DummyServer got unexpected " f'arguments: {",".join(config)}'
            )
        try:
            server = DummyServer.get_instance(address)
            if server.stopped:
                # treat a stopped server like a missing one
                raise KeyError("server closed")
        except KeyError:
            server = DummyServer(address, handle_channel)
            DummyServer._address_to_instances[address] = server
        return server

    @implements(Server.start)
    async def start(self):
        # nothing needs to do for dummy server
        pass

    @implements(Server.join)
    async def join(self, timeout=None):
        """Wait until the server is stopped or ``timeout`` elapses."""
        wait_coro = self._closed.wait()
        try:
            await asyncio.wait_for(wait_coro, timeout=timeout)
        except (futures.TimeoutError, asyncio.TimeoutError):
            # a timeout is not an error for join
            pass

    @implements(Server.on_connected)
    async def on_connected(self, *args, **kwargs):
        """
        Record the connected channel and pass it to the channel handler.

        ``args[0]`` must be a ``DummyChannel``; keyword arguments are not
        accepted. Raises ``ServerClosed`` if the server is already closed.
        """
        if self._closed.is_set():  # pragma: no cover
            raise ServerClosed("Dummy server already closed")

        channel = args[0]
        assert isinstance(channel, DummyChannel)
        if kwargs:  # pragma: no cover
            raise TypeError(
                f"{type(self).__name__} got unexpected "
                f'arguments: {",".join(kwargs)}'
            )
        self._channels.append(channel)
        await self.channel_handler(channel)

    @implements(Server.stop)
    async def stop(self):
        """Mark the server closed, cancel connection tasks, close channels."""
        self._closed.set()
        for task in self._tasks:
            task.cancel()
        await asyncio.gather(*(channel.close() for channel in self._channels))

    @property
    @implements(Server.stopped)
    def stopped(self) -> bool:
        return self._closed.is_set()
194
+
195
+
196
@register_client
class DummyClient(Client):
    """
    Client for ``DummyServer``; both ends live in the same process.

    ``connect`` wires a pair of ``DummyChannel`` objects together through
    two queues and schedules the server's ``on_connected`` as a task.
    """

    __slots__ = ("_task",)

    scheme: str | None = DummyServer.scheme
    # task running the server-side ``on_connected``; None when not connected
    _task: asyncio.Task | None

    def __init__(
        self, local_address: str | None, dest_address: str | None, channel: Channel
    ):
        super().__init__(local_address, dest_address, channel)
        # Initialise the slot so ``close`` is safe on a client that was
        # constructed directly rather than through ``connect``.
        self._task = None

    @staticmethod
    @implements(Client.connect)
    async def connect(
        dest_address: str, local_address: str | None = None, **kwargs
    ) -> "Client":
        """
        Connect to the ``DummyServer`` registered under ``dest_address``.

        Raises ``ValueError`` if ``dest_address`` does not use the dummy
        scheme, ``KeyError`` (from ``DummyServer.get_instance``) if no
        server is registered for it, and ``ConnectionError`` if the server
        is stopped.
        """
        if urlparse(dest_address).scheme != DummyServer.scheme:  # pragma: no cover
            raise ValueError(
                f'Destination address should start with "dummy://" '
                f"for DummyClient, got {dest_address}"
            )
        server = DummyServer.get_instance(dest_address)
        # NOTE(review): ``get_instance`` indexes the registry and raises
        # KeyError for an unknown address, so this branch looks defensive
        # only. It is kept so the exception callers see does not change.
        if server is None:  # pragma: no cover
            raise RuntimeError(
                f"DummyServer {dest_address} needs to be created first before DummyClient"
            )
        if server.stopped:  # pragma: no cover
            raise ConnectionError(f"Dummy server {dest_address} closed")

        # Two queues, crossed between the ends: what one side sends the
        # other receives. Both ends share one closed event.
        q1: asyncio.Queue = asyncio.Queue()
        q2: asyncio.Queue = asyncio.Queue()
        closed = asyncio.Event()
        client_channel = DummyChannel(q1, q2, closed, local_address=local_address)
        server_channel = DummyChannel(q2, q1, closed, dest_address=local_address)

        conn_coro = server.on_connected(server_channel)
        task = asyncio.create_task(conn_coro)
        client = DummyClient(local_address, dest_address, client_channel)
        client._task = task
        # registered with the server so ``DummyServer.stop`` can cancel it
        server._tasks.append(task)
        return client

    @implements(Client.close)
    async def close(self):
        """Close the client and cancel its connection task, if any.

        Safe to call more than once.
        """
        await super().close()
        task, self._task = self._task, None
        if task is not None:
            task.cancel()
@@ -0,0 +1,20 @@
1
+ # Copyright 2022-2023 XProbe Inc.
2
+ # derived from copyright 1999-2021 Alibaba Group Holding Ltd.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from ...errors import XoscarError
17
+
18
+
19
class ChannelClosed(XoscarError):
    """Error raised when an operation is attempted on a closed channel."""
@@ -0,0 +1,414 @@
1
+ # Copyright 2022-2023 XProbe Inc.
2
+ # derived from copyright 1999-2021 Alibaba Group Holding Ltd.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from __future__ import annotations
17
+
18
+ import asyncio
19
+ import concurrent.futures as futures
20
+ import logging
21
+ import os
22
+ import socket
23
+ import sys
24
+ from abc import ABCMeta
25
+ from asyncio import AbstractServer, StreamReader, StreamWriter
26
+ from functools import lru_cache
27
+ from hashlib import md5
28
+ from typing import Any, Callable, Coroutine, Dict, Type
29
+ from urllib.parse import urlparse
30
+
31
+ from ..._utils import to_binary
32
+ from ...constants import XOSCAR_UNIX_SOCKET_DIR
33
+ from ...serialization import AioDeserializer, AioSerializer, deserialize
34
+ from ...utils import classproperty, implements, is_py_312, is_v6_ip
35
+ from .base import Channel, ChannelType, Client, Server
36
+ from .core import register_client, register_server
37
+ from .utils import read_buffers, write_buffers
38
+
39
+ _is_windows: bool = sys.platform.startswith("win")  # True when sys.platform starts with "win"
40
+
41
+
42
+ logger = logging.getLogger(__name__)  # module-level logger named after this module
43
+
44
+
45
class SocketChannel(Channel):
    """
    Channel backed by an asyncio ``StreamReader`` / ``StreamWriter`` pair.

    Outgoing messages are serialized with ``AioSerializer`` and written as
    buffers; incoming ones are read as a header plus buffers and then
    deserialized.
    """

    __slots__ = "reader", "writer", "_channel_type", "_send_lock", "_recv_lock"

    name = "socket"

    def __init__(
        self,
        reader: StreamReader,
        writer: StreamWriter,
        local_address: str | None = None,
        dest_address: str | None = None,
        compression: str | None = None,
        channel_type: int | None = None,
    ):
        super().__init__(
            local_address=local_address,
            dest_address=dest_address,
            compression=compression,
        )
        self.reader, self.writer = reader, writer
        self._channel_type = channel_type

        # one lock per direction
        self._send_lock = asyncio.Lock()
        self._recv_lock = asyncio.Lock()

    @property
    @implements(Channel.type)
    def type(self) -> int:
        return self._channel_type  # type: ignore

    @implements(Channel.send)
    async def send(self, message: Any):
        """Serialize ``message``, write its buffers and drain the writer."""
        # serialize first; no compression setting means level 0
        serializer = AioSerializer(message, compress=self.compression or 0)
        payload = await serializer.run()

        # hand the buffers to the transport
        write_buffers(self.writer, payload)
        # draining is guarded by a lock: with parallel sends an
        # assertion error may be raised otherwise
        async with self._send_lock:
            await self.writer.drain()

    @implements(Channel.recv)
    async def recv(self):
        """Read one message (header, then buffers) and return it deserialized."""
        reader = self.reader
        deserializer = AioDeserializer(reader)
        # header and buffers are read under one lock acquisition
        async with self._recv_lock:
            header = await deserializer.get_header()
            body = await read_buffers(header, reader)
        return deserialize(header, body)

    @implements(Channel.close)
    async def close(self):
        """Close the writer and wait for it to finish closing."""
        writer = self.writer
        writer.close()
        try:
            await writer.wait_closed()
        # TODO: May raise Runtime error: attach to another event loop
        except (ConnectionResetError, RuntimeError):  # pragma: no cover
            pass

    @property
    @implements(Channel.closed)
    def closed(self):
        return self.writer.is_closing()
111
+
112
+
113
class _BaseSocketServer(Server, metaclass=ABCMeta):
    """
    Shared implementation for servers that wrap an asyncio server object.

    Subclasses supply ``channel_type`` and a ``create`` factory; this base
    handles start/join/stop and turns each accepted connection into a
    ``SocketChannel``.
    """

    __slots__ = "_aio_server", "_channels"

    _channels: list[ChannelType]

    def __init__(
        self,
        address: str,
        aio_server: AbstractServer,
        channel_handler: Callable[[Channel], Coroutine] | None = None,
    ):
        super().__init__(address, channel_handler)
        # every channel created in on_connected is remembered here
        self._channels = []
        # asyncio.Server
        self._aio_server = aio_server

    @implements(Server.start)
    async def start(self):
        await self._aio_server.start_serving()

    @implements(Server.join)
    async def join(self, timeout=None):
        """Serve forever, or for at most ``timeout`` seconds when given."""
        if timeout is None:
            await self._aio_server.serve_forever()
            return

        if is_py_312():
            # For python 3.12, there's a bug for `serve_forever`:
            # https://github.com/python/cpython/issues/123720,
            # which is unable to be cancelled.
            # Here is really a simulation of `wait_for`
            serving = asyncio.create_task(self._aio_server.serve_forever())
            await asyncio.sleep(timeout)
            if not serving.done():
                serving.cancel()
            else:
                logger.warning("`serve_forever` should never be done.")
            return

        serving = asyncio.create_task(self._aio_server.serve_forever())
        try:
            await asyncio.wait_for(serving, timeout=timeout)
        except (futures.TimeoutError, asyncio.TimeoutError, TimeoutError):
            serving.cancel()

    @implements(Server.on_connected)
    async def on_connected(self, *args, **kwargs):
        """
        Wrap an accepted ``(reader, writer)`` pair in a ``SocketChannel``.

        Accepts the optional keywords ``local_address`` and
        ``dest_address``; any other keyword raises ``TypeError``. The new
        channel is recorded and then passed to the channel handler.
        """
        reader, writer = args
        local_address = kwargs.pop("local_address", None)
        dest_address = kwargs.pop("dest_address", None)
        if kwargs:  # pragma: no cover
            raise TypeError(
                f"{type(self).__name__} got unexpected "
                f'arguments: {",".join(kwargs)}'
            )
        new_channel = SocketChannel(
            reader,
            writer,
            local_address=local_address,
            dest_address=dest_address,
            channel_type=self.channel_type,
        )
        self._channels.append(new_channel)
        # handle over channel to some handlers
        await self.channel_handler(new_channel)

    @implements(Server.stop)
    async def stop(self):
        """Close the asyncio server, then every channel that is still open."""
        aio_server = self._aio_server
        aio_server.close()
        # Python 3.12: # https://github.com/python/cpython/issues/104344
        # `wait_closed` leads to hang
        if not is_py_312():
            await aio_server.wait_closed()
        # close all channels
        still_open = [ch for ch in self._channels if not ch.closed]
        await asyncio.gather(*(ch.close() for ch in still_open))

    @property
    @implements(Server.stopped)
    def stopped(self) -> bool:
        return not self._aio_server.is_serving()
193
+
194
+
195
@register_server
class SocketServer(_BaseSocketServer):
    """TCP server whose address has the form ``host:port``."""

    __slots__ = "host", "port"

    scheme = None

    def __init__(
        self,
        host: str,
        port: int,
        aio_server: AbstractServer,
        channel_handler: Callable[[Channel], Coroutine] | None = None,
    ):
        super().__init__(f"{host}:{port}", aio_server, channel_handler=channel_handler)
        self.host, self.port = host, port

    @classproperty
    @implements(Server.client_type)
    def client_type(self) -> Type["Client"]:
        return SocketClient

    @property
    @implements(Server.channel_type)
    def channel_type(self) -> int:
        return ChannelType.remote

    @classmethod
    def parse_config(cls, config: dict) -> dict:
        """Return only the entries of ``config`` this server understands."""
        if not config:
            return dict()
        # we only need the following config
        wanted = ("listen_elastic_ip",)
        return {name: config[name] for name in wanted if name in config}

    @staticmethod
    @implements(Server.create)
    async def create(config: Dict) -> "Server":
        """
        Start an asyncio TCP server and wrap it in a ``SocketServer``.

        ``config`` must contain ``handle_channel`` and either ``address``
        (``host:port``) or both ``host`` and ``port``. Port 0 lets the
        socket choose a port, which is then read back from the first
        listening socket. Remaining entries are passed to
        ``asyncio.start_server``.
        """
        config = config.copy()
        if "address" in config:
            host, port_text = config.pop("address").rsplit(":", 1)
            port = int(port_text)
        else:
            host = config.pop("host")
            port = int(config.pop("port"))

        if config.pop("listen_elastic_ip", False):
            # The Actor.address will be announce to client, and is not on our host,
            # cannot actually listen on it,
            # so we have to keep SocketServer.host untouched to make sure Actor.address not changed
            bind_host = "::" if is_v6_ip(host) else "0.0.0.0"
        else:
            bind_host = host

        handle_channel = config.pop("handle_channel")
        config.setdefault("start_serving", False)

        async def handle_connection(reader: StreamReader, writer: StreamWriter):
            # create a channel when client connected
            return await server.on_connected(
                reader, writer, local_address=server.address
            )

        # 0 means "not specified": pass None and read the real port back below
        port = None if port == 0 else port
        aio_server = await asyncio.start_server(
            handle_connection, host=bind_host, port=port, **config
        )

        # get port of the socket if not specified
        if not port:
            first_socket = aio_server.sockets[0]
            port = first_socket.getsockname()[1]

        if _is_windows:
            for listening in aio_server.sockets:
                listening.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, True)

        server = SocketServer(host, port, aio_server, channel_handler=handle_channel)
        return server
279
+
280
+
281
@register_client
class SocketClient(Client):
    """Client counterpart of ``SocketServer``; connects over TCP."""

    __slots__ = ()

    scheme = SocketServer.scheme

    @staticmethod
    @implements(Client.connect)
    async def connect(
        dest_address: str, local_address: str | None = None, **kwargs
    ) -> "Client":
        """
        Open a TCP connection to ``dest_address`` (``host:port``).

        Extra keyword arguments are passed to ``asyncio.open_connection``.
        """
        # split on the last colon only
        host, port_text = dest_address.rsplit(":", 1)
        streams = await asyncio.open_connection(
            host=host, port=int(port_text), **kwargs
        )
        reader, writer = streams
        channel = SocketChannel(
            reader, writer, local_address=local_address, dest_address=dest_address
        )
        return SocketClient(local_address, dest_address, channel)
299
+
300
+
301
def _get_or_create_default_unix_socket_dir():
    """Ensure ``XOSCAR_UNIX_SOCKET_DIR`` exists and return it."""
    socket_dir = XOSCAR_UNIX_SOCKET_DIR
    os.makedirs(socket_dir, exist_ok=True)
    return XOSCAR_UNIX_SOCKET_DIR
304
+
305
+
306
@lru_cache(100)
def _gen_unix_socket_default_path(process_index):
    """
    Return the default unix socket path for ``process_index``.

    The file name is the MD5 hex digest of the stringified index, placed
    inside the default socket directory (created if missing). Results are
    cached per index.
    """
    # resolve the directory first, then the digest
    base_dir = _get_or_create_default_unix_socket_dir()
    digest = md5(to_binary(str(process_index))).hexdigest()  # nosec
    return f"{base_dir}/{digest}"
312
+
313
+
314
@register_server
class UnixSocketServer(_BaseSocketServer):
    """
    Server listening on a unix domain socket.

    Its address has the form ``unixsocket:///<process_index>``; the socket
    file itself lives at ``path``.
    """

    __slots__ = "process_index", "path"

    scheme = "unixsocket"

    def __init__(
        self,
        process_index: int,
        aio_server: AbstractServer,
        path: str,
        channel_handler: Callable[[Channel], Coroutine] | None = None,
    ):
        address = f"{self.scheme}:///{process_index}"
        super().__init__(address, aio_server, channel_handler=channel_handler)
        self.process_index = process_index
        self.path = path

    @classproperty
    @implements(Server.client_type)
    def client_type(self) -> Type["Client"]:
        return UnixSocketClient

    @property
    @implements(Server.channel_type)
    def channel_type(self) -> int:
        return ChannelType.ipc

    @staticmethod
    @implements(Server.create)
    async def create(config: Dict) -> "Server":
        """
        Start an asyncio unix server and wrap it in a ``UnixSocketServer``.

        ``config`` must contain ``handle_channel`` and either ``address``
        (``unixsocket:///<index>``) or ``process_index``. An optional
        ``path`` overrides the default socket location. Remaining entries
        are passed to ``asyncio.start_unix_server``.
        """
        config = config.copy()
        if "address" in config:
            process_index = int(urlparse(config.pop("address")).path.lstrip("/"))
        else:
            process_index = config.pop("process_index")
        handle_channel = config.pop("handle_channel")

        # Resolve the default lazily: computing it creates the default
        # socket directory, which is unwanted when a path is supplied.
        path = config.pop("path", None)
        if path is None:
            path = _gen_unix_socket_default_path(process_index)

        # A bare file name has an empty dirname; os.makedirs("") would
        # raise, and the current directory needs no creation anyway.
        dirname = os.path.dirname(path)
        if dirname and not os.path.exists(dirname):
            os.makedirs(dirname, exist_ok=True)

        if "start_serving" not in config:
            config["start_serving"] = False

        async def handle_connection(reader, writer):
            # create a channel when client connected
            return await server.on_connected(
                reader, writer, local_address=server.address
            )

        aio_server = await asyncio.start_unix_server(
            handle_connection, path=path, **config
        )

        for sock in aio_server.sockets:
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, True)

        server = UnixSocketServer(
            process_index, aio_server, path, channel_handler=handle_channel
        )
        return server

    @implements(Server.stop)
    async def stop(self):
        """Stop the server, then remove the socket file (best effort)."""
        await super().stop()
        try:
            os.remove(self.path)
        except OSError:  # pragma: no cover
            # deliberate: a socket file that cannot be removed is not an error
            pass
385
+
386
+
387
@register_client
class UnixSocketClient(Client):
    """Client counterpart of ``UnixSocketServer``."""

    __slots__ = ()

    scheme = UnixSocketServer.scheme

    @staticmethod
    @lru_cache(100)
    def _get_process_index(addr):
        """Parse the process index out of ``unixsocket:///<index>`` (cached)."""
        return int(urlparse(addr).path.lstrip("/"))

    @staticmethod
    @implements(Client.connect)
    async def connect(
        dest_address: str, local_address: str | None = None, **kwargs
    ) -> "Client":
        """
        Connect to the unix socket server addressed by ``dest_address``.

        An optional ``path`` keyword overrides the default socket location;
        other keyword arguments are passed to
        ``asyncio.open_unix_connection``.

        Raises ``ConnectionRefusedError`` when the socket file is missing.
        """
        process_index = UnixSocketClient._get_process_index(dest_address)
        # Resolve the default lazily: computing it creates the default
        # socket directory, which is unwanted when a path is supplied.
        path = kwargs.pop("path", None)
        if path is None:
            path = _gen_unix_socket_default_path(process_index)
        try:
            (reader, writer) = await asyncio.open_unix_connection(path, **kwargs)
        except FileNotFoundError as err:
            # chain explicitly so the traceback names the missing file as the cause
            raise ConnectionRefusedError(
                "Cannot connect unix socket due to file not exists"
            ) from err
        channel = SocketChannel(
            reader, writer, local_address=local_address, dest_address=dest_address
        )
        return UnixSocketClient(local_address, dest_address, channel)