xoscar 0.3.1__cp38-cp38-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of xoscar might be problematic. Click here for more details.

Files changed (80) hide show
  1. xoscar/__init__.py +60 -0
  2. xoscar/_utils.cpython-38-darwin.so +0 -0
  3. xoscar/_utils.pxd +36 -0
  4. xoscar/_utils.pyx +241 -0
  5. xoscar/_version.py +693 -0
  6. xoscar/aio/__init__.py +25 -0
  7. xoscar/aio/_threads.py +35 -0
  8. xoscar/aio/base.py +86 -0
  9. xoscar/aio/file.py +59 -0
  10. xoscar/aio/lru.py +228 -0
  11. xoscar/aio/parallelism.py +39 -0
  12. xoscar/api.py +493 -0
  13. xoscar/backend.py +67 -0
  14. xoscar/backends/__init__.py +14 -0
  15. xoscar/backends/allocate_strategy.py +160 -0
  16. xoscar/backends/communication/__init__.py +30 -0
  17. xoscar/backends/communication/base.py +315 -0
  18. xoscar/backends/communication/core.py +69 -0
  19. xoscar/backends/communication/dummy.py +242 -0
  20. xoscar/backends/communication/errors.py +20 -0
  21. xoscar/backends/communication/socket.py +375 -0
  22. xoscar/backends/communication/ucx.py +520 -0
  23. xoscar/backends/communication/utils.py +97 -0
  24. xoscar/backends/config.py +145 -0
  25. xoscar/backends/context.py +404 -0
  26. xoscar/backends/core.py +193 -0
  27. xoscar/backends/indigen/__init__.py +16 -0
  28. xoscar/backends/indigen/backend.py +51 -0
  29. xoscar/backends/indigen/driver.py +26 -0
  30. xoscar/backends/indigen/pool.py +469 -0
  31. xoscar/backends/message.cpython-38-darwin.so +0 -0
  32. xoscar/backends/message.pyx +591 -0
  33. xoscar/backends/pool.py +1593 -0
  34. xoscar/backends/router.py +207 -0
  35. xoscar/backends/test/__init__.py +16 -0
  36. xoscar/backends/test/backend.py +38 -0
  37. xoscar/backends/test/pool.py +208 -0
  38. xoscar/batch.py +256 -0
  39. xoscar/collective/__init__.py +27 -0
  40. xoscar/collective/common.py +102 -0
  41. xoscar/collective/core.py +737 -0
  42. xoscar/collective/process_group.py +687 -0
  43. xoscar/collective/utils.py +41 -0
  44. xoscar/collective/xoscar_pygloo.cpython-38-darwin.so +0 -0
  45. xoscar/constants.py +21 -0
  46. xoscar/context.cpython-38-darwin.so +0 -0
  47. xoscar/context.pxd +21 -0
  48. xoscar/context.pyx +368 -0
  49. xoscar/core.cpython-38-darwin.so +0 -0
  50. xoscar/core.pxd +50 -0
  51. xoscar/core.pyx +658 -0
  52. xoscar/debug.py +188 -0
  53. xoscar/driver.py +42 -0
  54. xoscar/errors.py +63 -0
  55. xoscar/libcpp.pxd +31 -0
  56. xoscar/metrics/__init__.py +21 -0
  57. xoscar/metrics/api.py +288 -0
  58. xoscar/metrics/backends/__init__.py +13 -0
  59. xoscar/metrics/backends/console/__init__.py +13 -0
  60. xoscar/metrics/backends/console/console_metric.py +82 -0
  61. xoscar/metrics/backends/metric.py +149 -0
  62. xoscar/metrics/backends/prometheus/__init__.py +13 -0
  63. xoscar/metrics/backends/prometheus/prometheus_metric.py +70 -0
  64. xoscar/nvutils.py +717 -0
  65. xoscar/profiling.py +260 -0
  66. xoscar/serialization/__init__.py +20 -0
  67. xoscar/serialization/aio.py +138 -0
  68. xoscar/serialization/core.cpython-38-darwin.so +0 -0
  69. xoscar/serialization/core.pxd +28 -0
  70. xoscar/serialization/core.pyx +954 -0
  71. xoscar/serialization/cuda.py +111 -0
  72. xoscar/serialization/exception.py +48 -0
  73. xoscar/serialization/numpy.py +82 -0
  74. xoscar/serialization/pyfury.py +37 -0
  75. xoscar/serialization/scipy.py +72 -0
  76. xoscar/utils.py +502 -0
  77. xoscar-0.3.1.dist-info/METADATA +225 -0
  78. xoscar-0.3.1.dist-info/RECORD +80 -0
  79. xoscar-0.3.1.dist-info/WHEEL +5 -0
  80. xoscar-0.3.1.dist-info/top_level.txt +2 -0
@@ -0,0 +1,242 @@
1
+ # Copyright 2022-2023 XProbe Inc.
2
+ # derived from copyright 1999-2021 Alibaba Group Holding Ltd.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from __future__ import annotations
17
+
18
+ import asyncio
19
+ import concurrent.futures as futures
20
+ import weakref
21
+ from typing import Any, Callable, Coroutine, Dict, Type
22
+ from urllib.parse import urlparse
23
+
24
+ from ...errors import ServerClosed
25
+ from ...utils import abc_type_require_weakref_slot, classproperty, implements
26
+ from .base import Channel, ChannelType, Client, Server
27
+ from .core import register_client, register_server
28
+ from .errors import ChannelClosed
29
+
30
# Address used by ``DummyServer.create`` when its config has no "address".
DEFAULT_DUMMY_ADDRESS: str = "dummy://0"
31
+
32
+
33
class DummyChannel(Channel):
    """
    Channel for communications in same process.

    ``send`` puts the message object directly on ``out_queue`` and ``recv``
    awaits ``in_queue``; nothing is serialized. The ``closed`` event is
    supplied by the creator (``DummyClient.connect`` passes the same event to
    both ends, so closing one side marks both as closed).
    """

    __slots__ = "_in_queue", "_out_queue", "_closed"

    name = "dummy"

    def __init__(
        self,
        in_queue: asyncio.Queue,
        out_queue: asyncio.Queue,
        closed: asyncio.Event,
        local_address: str | None = None,
        dest_address: str | None = None,
        compression: str | None = None,
    ):
        super().__init__(
            local_address=local_address,
            dest_address=dest_address,
            compression=compression,
        )
        self._in_queue = in_queue
        self._out_queue = out_queue
        self._closed = closed

    @property
    @implements(Channel.type)
    def type(self) -> int:
        return ChannelType.local

    @implements(Channel.send)
    async def send(self, message: Any):
        """
        Enqueue ``message`` for the peer channel.

        Raises:
            ChannelClosed: if the channel has already been closed.
        """
        if self._closed.is_set():  # pragma: no cover
            raise ChannelClosed("Channel already closed, cannot send message")
        # put message directly into queue
        self._out_queue.put_nowait(message)

    @implements(Channel.recv)
    async def recv(self):
        """
        Wait for and return the next message sent by the peer channel.

        Returns ``None`` when awaiting the queue raises ``RuntimeError`` after
        the channel has been closed; any other ``RuntimeError`` propagates.

        Raises:
            ChannelClosed: if the channel is already closed before waiting.
        """
        if self._closed.is_set():  # pragma: no cover
            raise ChannelClosed("Channel already closed, cannot receive message")
        try:
            return await self._in_queue.get()
        except RuntimeError:
            # Tolerate the error only if the channel was closed meanwhile
            # (presumably the queue's event loop going away during shutdown —
            # not established here). Otherwise it is a genuine failure and
            # must not be silently turned into a ``None`` message.
            if not self._closed.is_set():
                raise
            return None

    @implements(Channel.close)
    async def close(self):
        """Mark the (shared) closed event; both ends then report closed."""
        self._closed.set()

    @property
    @implements(Channel.closed)
    def closed(self) -> bool:
        return self._closed.is_set()
90
+
91
+
92
@register_server
class DummyServer(Server):
    """
    In-process server that hands ``DummyChannel`` objects to its handler.

    Servers are registered by address in ``_address_to_instances`` so that
    ``DummyClient.connect`` can find them without any networking.
    """

    # The data slots are always required. ``__weakref__`` is added only when
    # ABC-derived types need an explicit slot for weak references; instances
    # are stored in a WeakValueDictionary below. The parentheses matter: a
    # conditional expression binds looser than ``+``, so without them the
    # data slots would be dropped whenever the weakref slot is not needed.
    __slots__ = ("_closed", "_channels", "_tasks") + (
        ("__weakref__",) if abc_type_require_weakref_slot else ()
    )

    _address_to_instances: weakref.WeakValueDictionary[str, "DummyServer"] = (
        weakref.WeakValueDictionary()
    )
    _channels: list[ChannelType]
    _tasks: list[asyncio.Task]
    scheme: str | None = "dummy"

    def __init__(
        self,
        address: str,
        channel_handler: Callable[[Channel], Coroutine] | None = None,
    ):
        super().__init__(address, channel_handler)
        self._closed = asyncio.Event()
        self._channels = []
        self._tasks = []

    @classmethod
    def get_instance(cls, address: str):
        """
        Return the server registered at ``address``.

        Raises:
            KeyError: if no server is registered there (or it was collected).
        """
        return cls._address_to_instances[address]

    @classproperty
    @implements(Server.client_type)
    def client_type(self) -> Type["Client"]:
        return DummyClient

    @property
    @implements(Server.channel_type)
    def channel_type(self) -> int:
        return ChannelType.local

    @staticmethod
    @implements(Server.create)
    async def create(config: Dict) -> "DummyServer":
        """
        Return the running server for the configured address, creating one
        if none is registered or the registered one has been stopped.

        ``config`` accepts ``address`` (default ``DEFAULT_DUMMY_ADDRESS``) and
        the required ``handle_channel``; anything else is rejected.

        Raises:
            ValueError: if the address scheme is not ``dummy``.
            TypeError: if unexpected config keys are given.
        """
        config = config.copy()
        address = config.pop("address", DEFAULT_DUMMY_ADDRESS)
        handle_channel = config.pop("handle_channel")
        if urlparse(address).scheme != DummyServer.scheme:  # pragma: no cover
            raise ValueError(
                f"Address for DummyServer "
                f'should be starts with "dummy://", '
                f"got {address}"
            )
        if config:  # pragma: no cover
            raise TypeError(
                f"Creating DummyServer got unexpected " f'arguments: {",".join(config)}'
            )
        try:
            server = DummyServer.get_instance(address)
            if server.stopped:
                # treat a stopped server like a missing one and replace it
                raise KeyError("server closed")
        except KeyError:
            server = DummyServer(address, handle_channel)
            DummyServer._address_to_instances[address] = server
        return server

    @implements(Server.start)
    async def start(self):
        # nothing needs to do for dummy server
        pass

    @implements(Server.join)
    async def join(self, timeout=None):
        """Wait until the server is stopped, or at most ``timeout`` seconds."""
        wait_coro = self._closed.wait()
        try:
            await asyncio.wait_for(wait_coro, timeout=timeout)
        except (futures.TimeoutError, asyncio.TimeoutError):
            pass

    @implements(Server.on_connected)
    async def on_connected(self, *args, **kwargs):
        """
        Register the ``DummyChannel`` passed as the first positional argument
        and await the channel handler on it.

        Raises:
            ServerClosed: if the server has already been stopped.
            TypeError: if keyword arguments are given.
        """
        if self._closed.is_set():  # pragma: no cover
            raise ServerClosed("Dummy server already closed")

        channel = args[0]
        assert isinstance(channel, DummyChannel)
        if kwargs:  # pragma: no cover
            raise TypeError(
                f"{type(self).__name__} got unexpected "
                f'arguments: {",".join(kwargs)}'
            )
        self._channels.append(channel)
        await self.channel_handler(channel)

    @implements(Server.stop)
    async def stop(self):
        """Mark stopped, cancel connection tasks and close every channel."""
        self._closed.set()
        # tasks are the ``on_connected`` coroutines started by DummyClient
        for task in self._tasks:
            task.cancel()
        await asyncio.gather(*(channel.close() for channel in self._channels))

    @property
    @implements(Server.stopped)
    def stopped(self) -> bool:
        return self._closed.is_set()
194
+
195
+
196
@register_client
class DummyClient(Client):
    """Client side of an in-process connection to a ``DummyServer``."""

    __slots__ = ("_task",)

    scheme: str | None = DummyServer.scheme

    def __init__(
        self, local_address: str | None, dest_address: str | None, channel: Channel
    ):
        super().__init__(local_address, dest_address, channel)
        # Server-side task running ``on_connected``; set by ``connect``.
        # Initialised here so ``close`` works on any instance.
        self._task: asyncio.Task | None = None

    @staticmethod
    @implements(Client.connect)
    async def connect(
        dest_address: str, local_address: str | None = None, **kwargs
    ) -> "Client":
        """
        Connect to the ``DummyServer`` registered at ``dest_address``.

        Two queues are cross-wired between a client channel and a server
        channel that share one closed event. The server side is handled in a
        background task that the server tracks so ``stop`` can cancel it.

        Raises:
            ValueError: if the address scheme is not ``dummy``.
            RuntimeError: if no server has been created for the address.
            ConnectionError: if the server has been stopped.
        """
        if urlparse(dest_address).scheme != DummyServer.scheme:  # pragma: no cover
            raise ValueError(
                f'Destination address should start with "dummy://" '
                f"for DummyClient, got {dest_address}"
            )
        try:
            server = DummyServer.get_instance(dest_address)
        except KeyError:  # pragma: no cover
            # ``get_instance`` raises rather than returning None, so the
            # missing-server case has to be detected here.
            raise RuntimeError(
                f"DummyServer {dest_address} needs to be created first before DummyClient"
            ) from None
        if server.stopped:  # pragma: no cover
            raise ConnectionError(f"Dummy server {dest_address} closed")

        q1: asyncio.Queue = asyncio.Queue()
        q2: asyncio.Queue = asyncio.Queue()
        closed = asyncio.Event()
        client_channel = DummyChannel(q1, q2, closed, local_address=local_address)
        server_channel = DummyChannel(q2, q1, closed, dest_address=local_address)

        conn_coro = server.on_connected(server_channel)
        task = asyncio.create_task(conn_coro)
        client = DummyClient(local_address, dest_address, client_channel)
        client._task = task
        server._tasks.append(task)
        return client

    @implements(Client.close)
    async def close(self):
        """Close the client and cancel its server-side task; safe to repeat."""
        await super().close()
        if self._task is not None:
            self._task.cancel()
            self._task = None
@@ -0,0 +1,20 @@
1
+ # Copyright 2022-2023 XProbe Inc.
2
+ # derived from copyright 1999-2021 Alibaba Group Holding Ltd.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from ...errors import XoscarError
17
+
18
+
19
class ChannelClosed(XoscarError):
    """Error raised when a channel is used after it has been closed."""
@@ -0,0 +1,375 @@
1
+ # Copyright 2022-2023 XProbe Inc.
2
+ # derived from copyright 1999-2021 Alibaba Group Holding Ltd.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from __future__ import annotations
17
+
18
+ import asyncio
19
+ import concurrent.futures as futures
20
+ import os
21
+ import socket
22
+ import sys
23
+ from abc import ABCMeta
24
+ from asyncio import AbstractServer, StreamReader, StreamWriter
25
+ from functools import lru_cache
26
+ from hashlib import md5
27
+ from typing import Any, Callable, Coroutine, Dict, Type
28
+ from urllib.parse import urlparse
29
+
30
+ from ..._utils import to_binary
31
+ from ...constants import XOSCAR_UNIX_SOCKET_DIR
32
+ from ...serialization import AioDeserializer, AioSerializer, deserialize
33
+ from ...utils import classproperty, implements
34
+ from .base import Channel, ChannelType, Client, Server
35
+ from .core import register_client, register_server
36
+ from .utils import read_buffers, write_buffers
37
+
38
# True on Windows; SocketServer.create then sets SO_REUSEADDR on its sockets.
_is_windows: bool = sys.platform[:3] == "win"
39
+
40
+
41
class SocketChannel(Channel):
    """
    Channel over an asyncio ``StreamReader``/``StreamWriter`` pair.

    Outgoing messages are serialized with ``AioSerializer`` and written as
    buffers. Incoming ones are read with ``AioDeserializer``/``read_buffers``
    and rebuilt with ``deserialize``.
    """

    __slots__ = "reader", "writer", "_channel_type", "_send_lock", "_recv_lock"

    name = "socket"

    def __init__(
        self,
        reader: StreamReader,
        writer: StreamWriter,
        local_address: str | None = None,
        dest_address: str | None = None,
        compression: str | None = None,
        channel_type: int | None = None,
    ):
        super().__init__(
            local_address=local_address,
            dest_address=dest_address,
            compression=compression,
        )
        self.reader, self.writer = reader, writer
        self._channel_type = channel_type
        # one lock per direction so concurrent senders/receivers take turns
        self._send_lock, self._recv_lock = asyncio.Lock(), asyncio.Lock()

    @property
    @implements(Channel.type)
    def type(self) -> int:
        return self._channel_type  # type: ignore

    @implements(Channel.send)
    async def send(self, message: Any):
        """Serialize ``message`` and write it to the stream."""
        serializer = AioSerializer(message, compress=self.compression or 0)
        payload = await serializer.run()
        write_buffers(self.writer, payload)
        # Serialize drain(): parallel sends may otherwise hit an assertion
        # error inside the stream writer.
        async with self._send_lock:
            await self.writer.drain()

    @implements(Channel.recv)
    async def recv(self):
        """Read one complete message from the stream and deserialize it."""
        deserializer = AioDeserializer(self.reader)
        async with self._recv_lock:
            header = await deserializer.get_header()
            payload = await read_buffers(header, self.reader)
        return deserialize(header, payload)

    @implements(Channel.close)
    async def close(self):
        """Close the writer and wait for it, ignoring reset/loop errors."""
        self.writer.close()
        try:
            await self.writer.wait_closed()
        except (ConnectionResetError, RuntimeError):  # pragma: no cover
            # TODO: May raise Runtime error: attach to another event loop
            pass

    @property
    @implements(Channel.closed)
    def closed(self):
        return self.writer.is_closing()
107
+
108
+
109
class _BaseSocketServer(Server, metaclass=ABCMeta):
    """
    Shared logic for servers built on an asyncio ``AbstractServer``.

    Subclasses create the underlying server. This class starts, joins and
    stops it, and wraps every accepted connection in a ``SocketChannel``
    that is handed to ``channel_handler``.
    """

    __slots__ = "_aio_server", "_channels"

    _channels: list[ChannelType]

    def __init__(
        self,
        address: str,
        aio_server: AbstractServer,
        channel_handler: Callable[[Channel], Coroutine] | None = None,
    ):
        super().__init__(address, channel_handler)
        # asyncio.Server
        self._aio_server = aio_server
        self._channels = []

    @implements(Server.start)
    async def start(self):
        await self._aio_server.start_serving()

    @implements(Server.join)
    async def join(self, timeout=None):
        """Serve forever, or for at most ``timeout`` seconds when given."""
        if timeout is None:
            await self._aio_server.serve_forever()
        else:
            future = asyncio.create_task(self._aio_server.serve_forever())
            try:
                await asyncio.wait_for(future, timeout=timeout)
            except (futures.TimeoutError, asyncio.TimeoutError):
                future.cancel()

    @implements(Server.on_connected)
    async def on_connected(self, *args, **kwargs):
        """
        Wrap ``(reader, writer)`` in a ``SocketChannel`` and await the handler.

        Only the ``local_address`` and ``dest_address`` keywords are accepted.

        Raises:
            TypeError: if other keyword arguments are given.
        """
        reader, writer = args
        local_address = kwargs.pop("local_address", None)
        dest_address = kwargs.pop("dest_address", None)
        if kwargs:  # pragma: no cover
            raise TypeError(
                f"{type(self).__name__} got unexpected "
                f'arguments: {",".join(kwargs)}'
            )
        channel = SocketChannel(
            reader,
            writer,
            local_address=local_address,
            dest_address=dest_address,
            channel_type=self.channel_type,
        )
        self._channels.append(channel)
        # handle over channel to some handlers
        await self.channel_handler(channel)

    @implements(Server.stop)
    async def stop(self):
        """
        Stop listening, close all channels created so far, then wait for the
        server to finish closing.

        Channels are closed *before* awaiting ``wait_closed``. Since Python
        3.12, ``Server.wait_closed`` also waits for every accepted connection
        to end, so awaiting it while channels are still open can hang forever.
        """
        self._aio_server.close()
        # close all channels
        await asyncio.gather(
            *(channel.close() for channel in self._channels if not channel.closed)
        )
        await self._aio_server.wait_closed()

    @property
    @implements(Server.stopped)
    def stopped(self) -> bool:
        return not self._aio_server.is_serving()
174
+
175
+
176
@register_server
class SocketServer(_BaseSocketServer):
    """TCP server listening on ``host:port``; its channels are ``remote``."""

    __slots__ = "host", "port"

    scheme = None

    def __init__(
        self,
        host: str,
        port: int,
        aio_server: AbstractServer,
        channel_handler: Callable[[Channel], Coroutine] | None = None,
    ):
        address = f"{host}:{port}"
        super().__init__(address, aio_server, channel_handler=channel_handler)
        self.host = host
        self.port = port

    @classproperty
    @implements(Server.client_type)
    def client_type(self) -> Type["Client"]:
        return SocketClient

    @property
    @implements(Server.channel_type)
    def channel_type(self) -> int:
        return ChannelType.remote

    @staticmethod
    @implements(Server.create)
    async def create(config: Dict) -> "Server":
        """
        Create a TCP server; by default it does not accept connections until
        ``start`` is called.

        ``config`` contains either ``address`` (``"host:port"``) or separate
        ``host``/``port`` entries, plus the required ``handle_channel``.
        Remaining entries are forwarded to ``asyncio.start_server``. With port
        0 the actual port is read back from the bound socket.
        """
        config = config.copy()
        if "address" in config:
            address = config.pop("address")
            # Split on the *last* colon so unbracketed IPv6 hosts such as
            # ``::1:8080`` parse correctly (this also round-trips the
            # ``f"{host}:{port}"`` address built in ``__init__``).
            host, port = address.rsplit(":", 1)
            port = int(port)
        else:
            host = config.pop("host")
            port = int(config.pop("port"))
        handle_channel = config.pop("handle_channel")
        if "start_serving" not in config:
            config["start_serving"] = False

        async def handle_connection(reader: StreamReader, writer: StreamWriter):
            # create a channel when client connected; ``server`` is bound
            # below, before serving starts when start_serving is False
            return await server.on_connected(
                reader, writer, local_address=server.address
            )

        port = port if port != 0 else None
        aio_server = await asyncio.start_server(
            handle_connection, host=host, port=port, **config
        )

        # get port of the socket if not specified
        if not port:
            port = aio_server.sockets[0].getsockname()[1]

        if _is_windows:
            for sock in aio_server.sockets:
                sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, True)

        server = SocketServer(host, port, aio_server, channel_handler=handle_channel)
        return server
240
+
241
+
242
@register_client
class SocketClient(Client):
    """Client for ``SocketServer``: a TCP connection to ``host:port``."""

    __slots__ = ()

    scheme = SocketServer.scheme

    @staticmethod
    @implements(Client.connect)
    async def connect(
        dest_address: str, local_address: str | None = None, **kwargs
    ) -> "Client":
        """
        Open a TCP connection to ``dest_address`` (``"host:port"``).

        Extra ``kwargs`` are forwarded to ``asyncio.open_connection``.

        Raises:
            ValueError: if ``dest_address`` has no ``:`` or a non-integer port.
        """
        # Split on the last colon so unbracketed IPv6 hosts (``::1:8080``)
        # are handled; identical to a first-colon split for ``host:port``.
        host, port_str = dest_address.rsplit(":", 1)
        port = int(port_str)
        (reader, writer) = await asyncio.open_connection(host=host, port=port, **kwargs)
        channel = SocketChannel(
            reader, writer, local_address=local_address, dest_address=dest_address
        )
        return SocketClient(local_address, dest_address, channel)
260
+
261
+
262
def _get_or_create_default_unix_socket_dir():
    """Make sure ``XOSCAR_UNIX_SOCKET_DIR`` exists, then return that path."""
    socket_dir = XOSCAR_UNIX_SOCKET_DIR
    os.makedirs(socket_dir, exist_ok=True)
    return socket_dir
265
+
266
+
267
@lru_cache(100)
def _gen_unix_socket_default_path(process_index):
    """
    Return the default unix socket file path for ``process_index``.

    The file name is the md5 hex digest of ``str(process_index)``, placed in
    the default socket directory. Up to 100 results are memoized.
    """
    socket_dir = _get_or_create_default_unix_socket_dir()
    # md5 only derives a file name here; it is not used for security.
    digest = md5(to_binary(str(process_index))).hexdigest()  # nosec
    return f"{socket_dir}/{digest}"
273
+
274
+
275
@register_server
class UnixSocketServer(_BaseSocketServer):
    """
    Server on a unix domain socket file; its channels are ``ipc``.

    The server address has the form ``unixsocket:///<process_index>``.
    """

    __slots__ = "process_index", "path"

    scheme = "unixsocket"

    def __init__(
        self,
        process_index: int,
        aio_server: AbstractServer,
        path: str,
        channel_handler: Callable[[Channel], Coroutine] | None = None,
    ):
        address = f"{self.scheme}:///{process_index}"
        super().__init__(address, aio_server, channel_handler=channel_handler)
        self.process_index = process_index
        self.path = path

    @classproperty
    @implements(Server.client_type)
    def client_type(self) -> Type["Client"]:
        return UnixSocketClient

    @property
    @implements(Server.channel_type)
    def channel_type(self) -> int:
        return ChannelType.ipc

    @staticmethod
    @implements(Server.create)
    async def create(config: Dict) -> "Server":
        """
        Create a unix socket server; by default it does not accept
        connections until ``start`` is called.

        ``config`` contains ``address`` (``unixsocket:///<index>``) or
        ``process_index``, the required ``handle_channel`` and an optional
        ``path``. Without a path, one is derived from the process index.
        Remaining entries are forwarded to ``asyncio.start_unix_server``.
        """
        config = config.copy()
        if "address" in config:
            process_index = int(urlparse(config.pop("address")).path.lstrip("/"))
        else:
            process_index = config.pop("process_index")
        handle_channel = config.pop("handle_channel")
        # Only derive the default path (which also creates its directory)
        # when no explicit path is supplied.
        path = config.pop("path", None)
        if path is None:
            path = _gen_unix_socket_default_path(process_index)

        dirname = os.path.dirname(path)
        # A bare file name has an empty dirname; os.makedirs("") would raise.
        if dirname and not os.path.exists(dirname):
            os.makedirs(dirname, exist_ok=True)

        if "start_serving" not in config:
            config["start_serving"] = False

        async def handle_connection(reader, writer):
            # create a channel when client connected; ``server`` is bound
            # below, before serving starts when start_serving is False
            return await server.on_connected(
                reader, writer, local_address=server.address
            )

        aio_server = await asyncio.start_unix_server(
            handle_connection, path=path, **config
        )

        for sock in aio_server.sockets:
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, True)

        server = UnixSocketServer(
            process_index, aio_server, path, channel_handler=handle_channel
        )
        return server

    @implements(Server.stop)
    async def stop(self):
        """Stop the server, then remove its socket file (best effort)."""
        await super().stop()
        try:
            os.remove(self.path)
        except OSError:  # pragma: no cover
            # the file may already be gone; removal is best effort
            pass
346
+
347
+
348
@register_client
class UnixSocketClient(Client):
    """Client for ``UnixSocketServer`` addresses (``unixsocket:///<index>``)."""

    __slots__ = ()

    scheme = UnixSocketServer.scheme

    @staticmethod
    @lru_cache(100)
    def _get_process_index(addr):
        """Parse the integer process index from a ``unixsocket:///<index>`` address."""
        return int(urlparse(addr).path.lstrip("/"))

    @staticmethod
    @implements(Client.connect)
    async def connect(
        dest_address: str, local_address: str | None = None, **kwargs
    ) -> "Client":
        """
        Connect to the unix socket serving ``dest_address``.

        An explicit ``path`` keyword overrides the default path derived from
        the process index. Other ``kwargs`` are forwarded to
        ``asyncio.open_unix_connection``.

        Raises:
            ConnectionRefusedError: if the socket file does not exist.
        """
        process_index = UnixSocketClient._get_process_index(dest_address)
        # Only derive the default path (which also creates its directory)
        # when the caller did not supply one.
        path = kwargs.pop("path", None)
        if path is None:
            path = _gen_unix_socket_default_path(process_index)
        try:
            (reader, writer) = await asyncio.open_unix_connection(path, **kwargs)
        except FileNotFoundError as err:
            raise ConnectionRefusedError(
                "Cannot connect unix socket due to file not exists"
            ) from err
        channel = SocketChannel(
            reader, writer, local_address=local_address, dest_address=dest_address
        )
        return UnixSocketClient(local_address, dest_address, channel)