xoscar 0.9.0__cp312-cp312-macosx_10_13_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (94) hide show
  1. xoscar/__init__.py +61 -0
  2. xoscar/_utils.cpython-312-darwin.so +0 -0
  3. xoscar/_utils.pxd +36 -0
  4. xoscar/_utils.pyx +246 -0
  5. xoscar/_version.py +693 -0
  6. xoscar/aio/__init__.py +16 -0
  7. xoscar/aio/base.py +86 -0
  8. xoscar/aio/file.py +59 -0
  9. xoscar/aio/lru.py +228 -0
  10. xoscar/aio/parallelism.py +39 -0
  11. xoscar/api.py +527 -0
  12. xoscar/backend.py +67 -0
  13. xoscar/backends/__init__.py +14 -0
  14. xoscar/backends/allocate_strategy.py +160 -0
  15. xoscar/backends/communication/__init__.py +30 -0
  16. xoscar/backends/communication/base.py +315 -0
  17. xoscar/backends/communication/core.py +69 -0
  18. xoscar/backends/communication/dummy.py +253 -0
  19. xoscar/backends/communication/errors.py +20 -0
  20. xoscar/backends/communication/socket.py +444 -0
  21. xoscar/backends/communication/ucx.py +538 -0
  22. xoscar/backends/communication/utils.py +97 -0
  23. xoscar/backends/config.py +157 -0
  24. xoscar/backends/context.py +437 -0
  25. xoscar/backends/core.py +352 -0
  26. xoscar/backends/indigen/__init__.py +16 -0
  27. xoscar/backends/indigen/__main__.py +19 -0
  28. xoscar/backends/indigen/backend.py +51 -0
  29. xoscar/backends/indigen/driver.py +26 -0
  30. xoscar/backends/indigen/fate_sharing.py +221 -0
  31. xoscar/backends/indigen/pool.py +515 -0
  32. xoscar/backends/indigen/shared_memory.py +548 -0
  33. xoscar/backends/message.cpython-312-darwin.so +0 -0
  34. xoscar/backends/message.pyi +255 -0
  35. xoscar/backends/message.pyx +646 -0
  36. xoscar/backends/pool.py +1630 -0
  37. xoscar/backends/router.py +285 -0
  38. xoscar/backends/test/__init__.py +16 -0
  39. xoscar/backends/test/backend.py +38 -0
  40. xoscar/backends/test/pool.py +233 -0
  41. xoscar/batch.py +256 -0
  42. xoscar/collective/__init__.py +27 -0
  43. xoscar/collective/backend/__init__.py +13 -0
  44. xoscar/collective/backend/nccl_backend.py +160 -0
  45. xoscar/collective/common.py +102 -0
  46. xoscar/collective/core.py +737 -0
  47. xoscar/collective/process_group.py +687 -0
  48. xoscar/collective/utils.py +41 -0
  49. xoscar/collective/xoscar_pygloo.cpython-312-darwin.so +0 -0
  50. xoscar/collective/xoscar_pygloo.pyi +239 -0
  51. xoscar/constants.py +23 -0
  52. xoscar/context.cpython-312-darwin.so +0 -0
  53. xoscar/context.pxd +21 -0
  54. xoscar/context.pyx +368 -0
  55. xoscar/core.cpython-312-darwin.so +0 -0
  56. xoscar/core.pxd +51 -0
  57. xoscar/core.pyx +664 -0
  58. xoscar/debug.py +188 -0
  59. xoscar/driver.py +42 -0
  60. xoscar/errors.py +63 -0
  61. xoscar/libcpp.pxd +31 -0
  62. xoscar/metrics/__init__.py +21 -0
  63. xoscar/metrics/api.py +288 -0
  64. xoscar/metrics/backends/__init__.py +13 -0
  65. xoscar/metrics/backends/console/__init__.py +13 -0
  66. xoscar/metrics/backends/console/console_metric.py +82 -0
  67. xoscar/metrics/backends/metric.py +149 -0
  68. xoscar/metrics/backends/prometheus/__init__.py +13 -0
  69. xoscar/metrics/backends/prometheus/prometheus_metric.py +70 -0
  70. xoscar/nvutils.py +717 -0
  71. xoscar/profiling.py +260 -0
  72. xoscar/serialization/__init__.py +20 -0
  73. xoscar/serialization/aio.py +141 -0
  74. xoscar/serialization/core.cpython-312-darwin.so +0 -0
  75. xoscar/serialization/core.pxd +28 -0
  76. xoscar/serialization/core.pyi +57 -0
  77. xoscar/serialization/core.pyx +944 -0
  78. xoscar/serialization/cuda.py +111 -0
  79. xoscar/serialization/exception.py +48 -0
  80. xoscar/serialization/mlx.py +67 -0
  81. xoscar/serialization/numpy.py +82 -0
  82. xoscar/serialization/pyfury.py +37 -0
  83. xoscar/serialization/scipy.py +72 -0
  84. xoscar/serialization/torch.py +180 -0
  85. xoscar/utils.py +522 -0
  86. xoscar/virtualenv/__init__.py +34 -0
  87. xoscar/virtualenv/core.py +268 -0
  88. xoscar/virtualenv/platform.py +56 -0
  89. xoscar/virtualenv/utils.py +100 -0
  90. xoscar/virtualenv/uv.py +321 -0
  91. xoscar-0.9.0.dist-info/METADATA +230 -0
  92. xoscar-0.9.0.dist-info/RECORD +94 -0
  93. xoscar-0.9.0.dist-info/WHEEL +6 -0
  94. xoscar-0.9.0.dist-info/top_level.txt +2 -0
@@ -0,0 +1,352 @@
1
+ # Copyright 2022-2023 XProbe Inc.
2
+ # derived from copyright 1999-2021 Alibaba Group Holding Ltd.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from __future__ import annotations
17
+
18
+ import asyncio
19
+ import atexit
20
+ import copy
21
+ import logging
22
+ import threading
23
+ import weakref
24
+ from typing import Type, Union
25
+
26
+ from .._utils import Timer
27
+ from ..errors import ServerClosed
28
+ from ..profiling import get_profiling_data
29
+ from .communication import ChannelType, Client, UCXClient
30
+ from .message import (
31
+ DeserializeMessageFailed,
32
+ ErrorMessage,
33
+ ForwardMessage,
34
+ MessageType,
35
+ ResultMessage,
36
+ _MessageBase,
37
+ )
38
+ from .router import Router
39
+
40
+ ResultMessageType = Union[ResultMessage, ErrorMessage]
41
+ logger = logging.getLogger(__name__)
42
+
43
+
44
class ActorCallerThreadLocal:
    """Per-thread actor-call dispatcher.

    Owns a set of communication ``Client``s, one background listener task
    per client, and a per-client map of ``message_id -> Future`` used to
    match incoming responses to pending calls.  Instances are not
    thread-safe on their own; ``ActorCaller`` creates one per thread.
    """

    __slots__ = ("_client_to_message_futures", "_clients", "_profiling_data")

    # Per client: message_id of an in-flight call -> future awaiting its reply.
    _client_to_message_futures: dict[Client, dict[bytes, asyncio.Future]]
    # Client -> the asyncio task running self._listen(client).
    _clients: dict[Client, asyncio.Task]

    def __init__(self):
        self._client_to_message_futures = dict()
        self._clients = dict()
        self._profiling_data = get_profiling_data()

    def _listen_client(self, client: Client):
        """Start a listener task for *client* if one is not already running."""
        if client not in self._clients:
            self._clients[client] = asyncio.create_task(self._listen(client))
            self._client_to_message_futures[client] = dict()
            client_count = len(self._clients)
            # An ever-growing client count usually means the global Router was
            # never configured, so connections are not being reused.
            if client_count >= 100:  # pragma: no cover
                if (client_count - 100) % 10 == 0:  # pragma: no cover
                    logger.warning(
                        "Actor caller has created too many clients (%s >= 100), "
                        "the global router may not be set.",
                        client_count,
                    )

    async def get_client_via_type(
        self, router: Router, dest_address: str, client_type: Type[Client]
    ) -> Client:
        """Get a client of *client_type* for *dest_address* and listen on it."""
        client = await router.get_client_via_type(
            dest_address, client_type, from_who=self
        )
        self._listen_client(client)
        return client

    async def get_client(
        self,
        router: Router,
        dest_address: str,
        proxy_addresses: list[str] | None = None,
    ) -> Client:
        """Get a client for *dest_address* from *router* and listen on it."""
        client = await router.get_client(
            dest_address, from_who=self, proxy_addresses=proxy_addresses
        )
        self._listen_client(client)
        return client

    async def _listen(self, client: Client):
        """Receive messages from *client* and resolve the matching futures.

        Runs until the client closes.  On a connection error every pending
        future for this client is failed with ``ServerClosed``; on any other
        error all pending futures are failed with a copy of that error.  The
        client is always closed on the way out.
        """
        try:
            while not client.closed:
                try:
                    try:
                        message: _MessageBase = await client.recv()
                    except (EOFError, ConnectionError, BrokenPipeError) as e:
                        # AssertionError is from get_header
                        # remote server closed, close client and raise ServerClosed
                        logger.debug(f"{client.dest_address} close due to {e}")
                        try:
                            await client.close()
                        except (ConnectionError, BrokenPipeError):
                            # close failed, ignore it
                            pass
                        raise ServerClosed(
                            f"Remote server {client.dest_address} closed: {e}"
                        ) from None
                    future = self._client_to_message_futures[client].pop(
                        message.message_id
                    )
                    if not future.done():
                        future.set_result(message)
                except DeserializeMessageFailed as e:
                    # The failure carries the offending message id, so only
                    # that one caller is failed with the underlying cause.
                    message_id = e.message_id
                    future = self._client_to_message_futures[client].pop(message_id)
                    future.set_exception(e.__cause__)  # type: ignore
                except Exception as e:  # noqa: E722  # pylint: disable=bare-except
                    # Unknown failure: fail every pending call on this client.
                    message_futures = self._client_to_message_futures[client]
                    self._client_to_message_futures[client] = dict()
                    for future in message_futures.values():
                        future.set_exception(copy.copy(e))
                finally:
                    # message may have Ray ObjectRef, delete it early in case next loop doesn't run
                    # as soon as expected.
                    try:
                        del message
                    except NameError:
                        pass
                    try:
                        del future
                    except NameError:
                        pass
                    await asyncio.sleep(0)

            # Client closed normally: fail whatever calls are still pending.
            message_futures = self._client_to_message_futures[client]
            self._client_to_message_futures[client] = dict()
            error = ServerClosed(f"Remote server {client.dest_address} closed")
            for future in message_futures.values():
                future.set_exception(copy.copy(error))
        finally:
            try:
                await client.close()
            except:  # noqa: E722  # nosec  # pylint: disable=bare-except
                # ignore all error if fail to close at last
                pass

    async def call_with_client(
        self, client: Client, message: _MessageBase, wait: bool = True
    ) -> ResultMessage | ErrorMessage | asyncio.Future:
        """Send *message* on *client* and return the response.

        If *wait* is False, return the pending future instead of awaiting it.
        Raises ``ServerClosed`` if the connection drops during send.
        """
        loop = asyncio.get_running_loop()
        wait_response = loop.create_future()
        self._client_to_message_futures[client][message.message_id] = wait_response

        with Timer() as timer:
            try:
                await client.send(message)
            except ConnectionError:
                try:
                    await client.close()
                except ConnectionError:
                    # close failed, ignore it
                    pass
                raise ServerClosed(f"Remote server {client.dest_address} closed")

            if not wait:
                r = wait_response
            else:
                r = await wait_response

        self._profiling_data.collect_actor_call(message, timer.duration)
        return r

    async def call_send_buffers(
        self,
        client: UCXClient,
        local_buffers: list,
        meta_message: _MessageBase,
        wait: bool = True,
    ) -> ResultMessage | ErrorMessage | asyncio.Future:
        """Send *local_buffers* with *meta_message* over a UCX client.

        Mirrors :meth:`call_with_client` but uses the UCX buffer-transfer
        path.  If *wait* is False, return the pending future.
        """
        loop = asyncio.get_running_loop()
        wait_response = loop.create_future()
        self._client_to_message_futures[client][meta_message.message_id] = wait_response

        with Timer() as timer:
            try:
                await client.send_buffers(local_buffers, meta_message)
            except ConnectionError:  # pragma: no cover
                try:
                    await client.close()
                except ConnectionError:
                    # close failed, ignore it
                    pass
                raise ServerClosed(f"Remote server {client.dest_address} closed")

            if not wait:  # pragma: no cover
                r = wait_response
            else:
                r = await wait_response

        self._profiling_data.collect_actor_call(meta_message, timer.duration)
        return r

    async def call(
        self,
        router: Router,
        dest_address: str,
        message: _MessageBase,
        wait: bool = True,
        proxy_addresses: list[str] | None = None,
    ) -> ResultMessage | ErrorMessage | asyncio.Future:
        """Send *message* to *dest_address*, resolving the client via *router*.

        When the resolved client actually connects to a different remote
        address (i.e. a proxy), non-control messages are wrapped in a
        ``ForwardMessage`` so the proxy can relay them.
        """
        client = await self.get_client(
            router, dest_address, proxy_addresses=proxy_addresses
        )
        if (
            client.channel_type == ChannelType.remote
            and client.dest_address != dest_address
            and message.message_type != MessageType.control
        ):
            # wrap message with forward message
            message = ForwardMessage(
                message_id=message.message_id, address=dest_address, raw_message=message
            )
        return await self.call_with_client(client, message, wait)

    async def stop(self):
        """Best-effort shutdown: close all clients, then cancel listeners."""
        try:
            await asyncio.gather(*[client.close() for client in self._clients])
        except (ConnectionError, ServerClosed):
            pass
        try:
            self.cancel_tasks()
        except:
            pass

    def cancel_tasks(self):
        # cancel listening for all clients
        _ = [task.cancel() for task in self._clients.values()]
237
+
238
+
239
+ def _cancel_all_tasks(loop):
240
+ to_cancel = asyncio.all_tasks(loop)
241
+ if not to_cancel:
242
+ return
243
+
244
+ for task in to_cancel:
245
+ task.cancel()
246
+
247
+ # In Python 3.13+, we need to use a different approach to avoid deadlocks
248
+ # when shutting down event loops in threads
249
+ if hasattr(asyncio, "run"):
250
+ # For Python 3.13+, use a more robust approach
251
+ async def _gather_cancelled():
252
+ await asyncio.gather(*to_cancel, return_exceptions=True)
253
+
254
+ try:
255
+ # Try to run the gather in the current loop context
256
+ if loop.is_running():
257
+ # If loop is running, schedule the gather
258
+ asyncio.run_coroutine_threadsafe(_gather_cancelled(), loop)
259
+ else:
260
+ # If loop is not running, we can run it directly
261
+ loop.run_until_complete(_gather_cancelled())
262
+ except RuntimeError:
263
+ # If we can't run the gather, just log and continue
264
+ logger.debug("Could not gather cancelled tasks during shutdown")
265
+ else:
266
+ # For older Python versions, use the original approach
267
+ loop.run_until_complete(asyncio.gather(*to_cancel, return_exceptions=True))
268
+
269
+ for task in to_cancel:
270
+ if task.cancelled():
271
+ continue
272
+ if task.exception() is not None:
273
+ loop.call_exception_handler(
274
+ {
275
+ "message": "unhandled exception during asyncio.run() shutdown",
276
+ "exception": task.exception(),
277
+ "task": task,
278
+ }
279
+ )
280
+
281
+
282
def _safe_run_forever(loop):
    """Run *loop* until stopped, then best-effort cancel tasks and stop it.

    Failures while cancelling outstanding tasks or stopping the loop are
    logged at debug level and swallowed, so the hosting thread always
    exits cleanly.
    """
    try:
        loop.run_forever()
    finally:
        try:
            _cancel_all_tasks(loop)
        except Exception as exc:
            logger.debug("Error during task cancellation: %s", exc)
        finally:
            # Always attempt to stop the loop, even if cancellation blew up.
            try:
                loop.stop()
            except Exception as exc:
                logger.debug("Error stopping loop: %s", exc)
295
+
296
+
297
class ActorCaller:
    """Thread-aware facade over :class:`ActorCallerThreadLocal`.

    Attribute access is delegated to a per-thread ``ActorCallerThreadLocal``
    created lazily on first use.  A shared background event loop/thread is
    used to run each thread-local caller's ``stop()`` when its owning
    thread exits (detected via a weakref finalizer on a per-thread holder
    object).
    """

    __slots__ = "_thread_local"

    class _RefHolder:
        # Dummy per-thread object; its garbage collection (when the thread's
        # locals are torn down) triggers cleanup of that thread's caller.
        pass

    # Class-level singletons for the shared cleanup loop/thread.
    _close_loop = None
    _close_thread = None
    _initialized = False

    @classmethod
    def _ensure_initialized(cls):
        """Lazily start the shared cleanup event loop in a daemon thread."""
        if not cls._initialized:
            cls._close_loop = asyncio.new_event_loop()
            cls._close_thread = threading.Thread(
                target=_safe_run_forever, args=(cls._close_loop,), daemon=True
            )
            cls._close_thread.start()
            atexit.register(cls._cleanup)
            cls._initialized = True

    @classmethod
    def _cleanup(cls):
        """atexit hook: stop the shared loop and briefly join its thread."""
        if cls._close_loop and cls._close_loop.is_running():
            try:
                cls._close_loop.call_soon_threadsafe(cls._close_loop.stop)
                # Give the loop a moment to stop
                if cls._close_thread:
                    cls._close_thread.join(timeout=0.5)  # Shorter timeout for tests
            except Exception as e:
                logger.debug("Error during cleanup: %s", e)

    def __init__(self):
        self._thread_local = threading.local()

    def __getattr__(self, item):
        """Delegate *item* to this thread's caller, creating it on first use."""
        try:
            actor_caller = self._thread_local.actor_caller
        except AttributeError:
            thread_info = str(threading.current_thread())
            logger.debug("Creating a new actor caller for thread: %s", thread_info)
            actor_caller = self._thread_local.actor_caller = ActorCallerThreadLocal()
            ref = self._thread_local.ref = ActorCaller._RefHolder()
            # If the thread exit, we clean the related actor callers and channels.

            def _cleanup():
                self._ensure_initialized()
                # Use the background thread for cleanup
                asyncio.run_coroutine_threadsafe(actor_caller.stop(), self._close_loop)
                logger.debug(
                    "Clean up the actor caller due to thread exit: %s", thread_info
                )

            weakref.finalize(ref, _cleanup)

        return getattr(actor_caller, item)
@@ -0,0 +1,16 @@
1
+ # Copyright 2022-2023 XProbe Inc.
2
+ # derived from copyright 1999-2021 Alibaba Group Holding Ltd.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from .backend import IndigenActorBackend
@@ -0,0 +1,19 @@
1
# CLI entry point for ``python -m xoscar.backends.indigen``.  Used internally
# to bootstrap sub-pool processes in a freshly spawned child interpreter.
if __name__ == "__main__":
    import click

    @click.group(
        invoke_without_command=True,
        name="xoscar",
        help="Xoscar command-line interface.",
    )
    def main():
        pass

    @main.command("start_sub_pool", help="Start a sub pool.")
    @click.option("shm_name", "-sn", type=str, help="Shared memory name.")
    def start_sub_pool(shm_name):
        # Imported lazily so the CLI group is usable without pool machinery.
        from xoscar.backends.indigen.pool import MainActorPool

        # shm_name identifies the shared-memory segment the parent uses to
        # pass startup configuration to this child process.
        MainActorPool._start_sub_pool_in_child(shm_name)

    main()
@@ -0,0 +1,51 @@
1
+ # Copyright 2022-2023 XProbe Inc.
2
+ # derived from copyright 1999-2021 Alibaba Group Holding Ltd.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from __future__ import annotations
17
+
18
+ from ...backend import BaseActorBackend, register_backend
19
+ from ..context import IndigenActorContext
20
+ from .driver import IndigenActorDriver
21
+ from .pool import MainActorPool
22
+
23
+ __all__ = ["IndigenActorBackend"]
24
+
25
+
26
+ @register_backend
27
+ class IndigenActorBackend(BaseActorBackend):
28
+ @staticmethod
29
+ def name():
30
+ # None means Indigen is default scheme
31
+ # ucx can be recognized as Indigen backend as well
32
+ return [None, "ucx"]
33
+
34
+ @staticmethod
35
+ def get_context_cls():
36
+ return IndigenActorContext
37
+
38
+ @staticmethod
39
+ def get_driver_cls():
40
+ return IndigenActorDriver
41
+
42
+ @classmethod
43
+ async def create_actor_pool(
44
+ cls, address: str, n_process: int | None = None, **kwargs
45
+ ):
46
+ from ..pool import create_actor_pool
47
+
48
+ assert n_process is not None
49
+ return await create_actor_pool(
50
+ address, pool_cls=MainActorPool, n_process=n_process, **kwargs
51
+ )
@@ -0,0 +1,26 @@
1
+ # Copyright 2022-2023 XProbe Inc.
2
+ # derived from copyright 1999-2021 Alibaba Group Holding Ltd.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from numbers import Number
17
+ from typing import Dict
18
+
19
+ from ...driver import BaseActorDriver
20
+
21
+
22
class IndigenActorDriver(BaseActorDriver):
    """Driver for the Indigen backend; requires no cluster bootstrap."""

    @classmethod
    def setup_cluster(cls, address_to_resources: Dict[str, Dict[str, Number]]):
        # nothing need to be done in driver of Indigen backend
        pass
@@ -0,0 +1,221 @@
1
+ # Copyright 2017 The Ray Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http:#www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import asyncio
16
+ import logging
17
+ import subprocess
18
+ import sys
19
+
20
+ import psutil
21
+
22
+ # Linux can bind child processes' lifetimes to that of their parents via prctl.
23
+ # prctl support is detected dynamically once, and assumed thereafter.
24
+ linux_prctl = None
25
+
26
+ # Windows can bind processes' lifetimes to that of kernel-level "job objects".
27
+ # We keep a global job object to tie its lifetime to that of our own process.
28
+ win32_job = None
29
+ win32_AssignProcessToJobObject = None
30
+
31
+ logger = logging.getLogger(__name__)
32
+
33
+
34
+ def detect_fate_sharing_support_win32():
35
+ global win32_job, win32_AssignProcessToJobObject
36
+ if win32_job is None and sys.platform == "win32":
37
+ import ctypes
38
+
39
+ try:
40
+ from ctypes.wintypes import BOOL, DWORD, HANDLE, LPCWSTR, LPVOID
41
+
42
+ kernel32 = ctypes.WinDLL("kernel32")
43
+ kernel32.CreateJobObjectW.argtypes = (LPVOID, LPCWSTR)
44
+ kernel32.CreateJobObjectW.restype = HANDLE
45
+ sijo_argtypes = (HANDLE, ctypes.c_int, LPVOID, DWORD)
46
+ kernel32.SetInformationJobObject.argtypes = sijo_argtypes
47
+ kernel32.SetInformationJobObject.restype = BOOL
48
+ kernel32.AssignProcessToJobObject.argtypes = (HANDLE, HANDLE)
49
+ kernel32.AssignProcessToJobObject.restype = BOOL
50
+ kernel32.IsDebuggerPresent.argtypes = ()
51
+ kernel32.IsDebuggerPresent.restype = BOOL
52
+ except (AttributeError, TypeError, ImportError):
53
+ kernel32 = None
54
+ job = kernel32.CreateJobObjectW(None, None) if kernel32 else None
55
+ job = subprocess.Handle(job) if job else job
56
+ if job:
57
+ from ctypes.wintypes import DWORD, LARGE_INTEGER, ULARGE_INTEGER
58
+
59
+ class JOBOBJECT_BASIC_LIMIT_INFORMATION(ctypes.Structure):
60
+ _fields_ = [
61
+ ("PerProcessUserTimeLimit", LARGE_INTEGER),
62
+ ("PerJobUserTimeLimit", LARGE_INTEGER),
63
+ ("LimitFlags", DWORD),
64
+ ("MinimumWorkingSetSize", ctypes.c_size_t),
65
+ ("MaximumWorkingSetSize", ctypes.c_size_t),
66
+ ("ActiveProcessLimit", DWORD),
67
+ ("Affinity", ctypes.c_size_t),
68
+ ("PriorityClass", DWORD),
69
+ ("SchedulingClass", DWORD),
70
+ ]
71
+
72
+ class IO_COUNTERS(ctypes.Structure):
73
+ _fields_ = [
74
+ ("ReadOperationCount", ULARGE_INTEGER),
75
+ ("WriteOperationCount", ULARGE_INTEGER),
76
+ ("OtherOperationCount", ULARGE_INTEGER),
77
+ ("ReadTransferCount", ULARGE_INTEGER),
78
+ ("WriteTransferCount", ULARGE_INTEGER),
79
+ ("OtherTransferCount", ULARGE_INTEGER),
80
+ ]
81
+
82
+ class JOBOBJECT_EXTENDED_LIMIT_INFORMATION(ctypes.Structure):
83
+ _fields_ = [
84
+ ("BasicLimitInformation", JOBOBJECT_BASIC_LIMIT_INFORMATION),
85
+ ("IoInfo", IO_COUNTERS),
86
+ ("ProcessMemoryLimit", ctypes.c_size_t),
87
+ ("JobMemoryLimit", ctypes.c_size_t),
88
+ ("PeakProcessMemoryUsed", ctypes.c_size_t),
89
+ ("PeakJobMemoryUsed", ctypes.c_size_t),
90
+ ]
91
+
92
+ debug = kernel32.IsDebuggerPresent()
93
+
94
+ # Defined in <WinNT.h>; also available here:
95
+ # https://docs.microsoft.com/en-us/windows/win32/api/jobapi2/nf-jobapi2-setinformationjobobject
96
+ JobObjectExtendedLimitInformation = 9
97
+ JOB_OBJECT_LIMIT_BREAKAWAY_OK = 0x00000800
98
+ JOB_OBJECT_LIMIT_DIE_ON_UNHANDLED_EXCEPTION = 0x00000400
99
+ JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE = 0x00002000
100
+ buf = JOBOBJECT_EXTENDED_LIMIT_INFORMATION()
101
+ buf.BasicLimitInformation.LimitFlags = (
102
+ (0 if debug else JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE)
103
+ | JOB_OBJECT_LIMIT_DIE_ON_UNHANDLED_EXCEPTION
104
+ | JOB_OBJECT_LIMIT_BREAKAWAY_OK
105
+ )
106
+ infoclass = JobObjectExtendedLimitInformation
107
+ if not kernel32.SetInformationJobObject(
108
+ job, infoclass, ctypes.byref(buf), ctypes.sizeof(buf)
109
+ ):
110
+ job = None
111
+ win32_AssignProcessToJobObject = (
112
+ kernel32.AssignProcessToJobObject if kernel32 is not None else False
113
+ )
114
+ win32_job = job if job else False
115
+ return bool(win32_job)
116
+
117
+
118
+ def detect_fate_sharing_support_linux():
119
+ global linux_prctl
120
+ if linux_prctl is None and sys.platform.startswith("linux"):
121
+ try:
122
+ from ctypes import CDLL, c_int, c_ulong
123
+
124
+ prctl = CDLL(None).prctl
125
+ prctl.restype = c_int
126
+ prctl.argtypes = [c_int, c_ulong, c_ulong, c_ulong, c_ulong]
127
+ except (AttributeError, TypeError):
128
+ prctl = None
129
+ linux_prctl = prctl if prctl else False
130
+ return bool(linux_prctl)
131
+
132
+
133
+ def detect_fate_sharing_support():
134
+ result = None
135
+ if sys.platform == "win32":
136
+ result = detect_fate_sharing_support_win32()
137
+ elif sys.platform.startswith("linux"):
138
+ result = detect_fate_sharing_support_linux()
139
+ return result
140
+
141
+
142
+ if detect_fate_sharing_support():
143
+ logger.info("Using kernel-level fate-sharing.")
144
+
145
+
146
+ def set_kill_on_parent_death_linux():
147
+ """Ensures this process dies if its parent dies (fate-sharing).
148
+
149
+ Linux-only. Must be called in preexec_fn (i.e. by the child).
150
+ """
151
+ if detect_fate_sharing_support_linux():
152
+ import signal
153
+
154
+ PR_SET_PDEATHSIG = 1
155
+ if linux_prctl(PR_SET_PDEATHSIG, signal.SIGKILL, 0, 0, 0) != 0:
156
+ import ctypes
157
+
158
+ raise OSError(ctypes.get_errno(), "prctl(PR_SET_PDEATHSIG) failed")
159
+ else:
160
+ assert False, "PR_SET_PDEATHSIG used despite being unavailable"
161
+
162
+
163
+ def set_kill_child_on_death_win32(child_proc):
164
+ """Ensures the child process dies if this process dies (fate-sharing).
165
+
166
+ Windows-only. Must be called by the parent, after spawning the child.
167
+
168
+ Args:
169
+ child_proc: The subprocess.Popen or subprocess.Handle object.
170
+ """
171
+
172
+ if isinstance(child_proc, subprocess.Popen):
173
+ child_proc = child_proc._handle
174
+ assert isinstance(child_proc, subprocess.Handle)
175
+
176
+ if detect_fate_sharing_support_win32():
177
+ if not win32_AssignProcessToJobObject(win32_job, int(child_proc)):
178
+ import ctypes
179
+
180
+ raise OSError(ctypes.get_last_error(), "AssignProcessToJobObject() failed")
181
+ else:
182
+ assert False, "AssignProcessToJobObject used despite being unavailable"
183
+
184
+
185
+ def preexec_fn():
186
+ import signal
187
+
188
+ signal.pthread_sigmask(signal.SIG_BLOCK, {signal.SIGINT})
189
+ if sys.platform.startswith("linux"):
190
+ set_kill_on_parent_death_linux()
191
+
192
+
193
+ async def create_subprocess_exec(*args, **kwargs):
194
+ win32 = sys.platform == "win32"
195
+ # With Windows fate-sharing, we need special care:
196
+ # The process must be added to the job before it is allowed to execute.
197
+ # Otherwise, there's a race condition: the process might spawn children
198
+ # before the process itself is assigned to the job.
199
+ # After that point, its children will not be added to the job anymore.
200
+ CREATE_SUSPENDED = 0x00000004 # from Windows headers
201
+ creationflags = CREATE_SUSPENDED if win32 else 0
202
+ if win32 and detect_fate_sharing_support_win32():
203
+ creationflags |= subprocess.CREATE_NEW_PROCESS_GROUP
204
+ # CREATE_NEW_PROCESS_GROUP is used to send Ctrl+C on Windows:
205
+ # https://docs.python.org/3/library/subprocess.html#subprocess.Popen.send_signal
206
+ process: asyncio.subprocess.Process = await asyncio.create_subprocess_exec(
207
+ *args,
208
+ **kwargs,
209
+ preexec_fn=preexec_fn if not win32 else None,
210
+ creationflags=creationflags,
211
+ )
212
+ if win32:
213
+ proc = process._transport._proc
214
+ try:
215
+ set_kill_child_on_death_win32(proc)
216
+ psutil.Process(process.pid).resume()
217
+ except (psutil.Error, OSError):
218
+ logger.exception("Resume process failed, kill %s.", process.pid)
219
+ process.kill()
220
+ raise
221
+ return process