xoscar 0.6.1__cp310-cp310-win_amd64.whl → 0.7.0__cp310-cp310-win_amd64.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of xoscar might be problematic. Click here for more details.

@@ -16,26 +16,26 @@
16
16
  from __future__ import annotations
17
17
 
18
18
  import asyncio
19
- import atexit
20
- import concurrent.futures as futures
19
+ import asyncio.subprocess
21
20
  import configparser
22
- import contextlib
23
21
  import itertools
24
22
  import logging.config
25
- import multiprocessing
26
23
  import os
24
+ import pickle
27
25
  import random
28
26
  import signal
27
+ import struct
29
28
  import sys
30
29
  import threading
30
+ import time
31
31
  import uuid
32
- from dataclasses import dataclass
33
- from multiprocessing import util
34
- from types import TracebackType
32
+ from enum import IntEnum
35
33
  from typing import List, Optional
36
34
 
35
+ import psutil
36
+
37
37
  from ..._utils import reset_id_random_seed
38
- from ...utils import dataslots, ensure_coverage
38
+ from ...utils import ensure_coverage
39
39
  from ..config import ActorPoolConfig
40
40
  from ..message import (
41
41
  ControlMessage,
@@ -44,94 +44,40 @@ from ..message import (
44
44
  new_message_id,
45
45
  )
46
46
  from ..pool import MainActorPoolBase, SubActorPoolBase, _register_message_handler
47
+ from . import shared_memory
48
+ from .fate_sharing import create_subprocess_exec
47
49
 
50
+ _SUBPROCESS_SHM_SIZE = 10240
48
51
  _is_windows: bool = sys.platform.startswith("win")
49
52
 
50
- if sys.version_info[:2] == (3, 9):
51
- # fix for Python 3.9, see https://bugs.python.org/issue43517
52
- if sys.platform == "win32":
53
- from multiprocessing import popen_spawn_win32 as popen_spawn
54
-
55
- popen_forkserver = popen_fork = synchronize = None
56
- else:
57
- from multiprocessing import popen_fork, popen_forkserver
58
- from multiprocessing import popen_spawn_posix as popen_spawn
59
- from multiprocessing import synchronize
60
- _ = popen_spawn, popen_forkserver, popen_fork, synchronize
61
- del _
62
- elif sys.version_info[:2] == (3, 6): # pragma: no cover
63
- from multiprocessing.process import BaseProcess
64
-
65
- # define kill method for multiprocessing
66
- def _mp_kill(self):
67
- if not _is_windows:
68
- try:
69
- os.kill(self.pid, signal.SIGKILL)
70
- except ProcessLookupError:
71
- pass
72
- except OSError:
73
- if self.wait(timeout=0.1) is None:
74
- raise
75
- else:
76
- self.terminate()
77
-
78
- BaseProcess.kill = _mp_kill
79
-
80
53
  logger = logging.getLogger(__name__)
81
- _init_main_suspended_local = threading.local()
82
54
 
83
55
 
84
- def _terminate_children():
85
- for c in multiprocessing.active_children():
86
- try:
87
- c.terminate()
88
- except Exception:
89
- pass
90
-
56
class _ShmSeq(IntEnum):
    """Sequence tags identifying which message currently sits in the shared memory.

    The tag is stored as a little-endian uint32 in the first four bytes of the
    buffer by ``_shm_put_object`` and compared by ``_shm_get_object``.
    """

    # Written by the parent: the pickled start-up parameters for the sub pool.
    INIT_PARAMS = 1
    # Written by the child: the external address(es) once the sub pool started.
    INIT_RESULT = 2
91
59
 
92
- if util:
93
- # Import multiprocessing.util to register _exit_function at exit.
94
- atexit.register(_terminate_children)
95
60
 
61
def _shm_put_object(seq: _ShmSeq, shm: shared_memory.SharedMemory, o: object):
    """Pickle ``o`` into ``shm`` and tag the buffer with ``seq``.

    Buffer layout (all integers little-endian uint32):

    * bytes 0-3: sequence tag (``seq``)
    * bytes 4-7: ``sys.hexversion`` of the writer
    * bytes 8-11: length of the pickled payload
    * bytes 12-...: the pickled payload

    Args:
        seq: Tag identifying the kind of message being stored.
        shm: Shared memory block whose ``buf`` receives the message.
        o: Object to serialize with ``pickle``.

    Raises:
        ValueError: If the pickled payload does not fit after the 12-byte
            header within ``_SUBPROCESS_SHM_SIZE`` bytes.
    """
    header_size = 12  # 4-byte tag + 4-byte hexversion + 4-byte payload length
    serialized = pickle.dumps(o)
    # The payload starts at offset 12, so the usable room is the buffer size
    # minus the whole header. Checked explicitly (not via ``assert``) so the
    # guard still runs under ``python -O``.
    if len(serialized) > _SUBPROCESS_SHM_SIZE - header_size:
        raise ValueError(f"Serialized object {o} is too long.")
    struct.pack_into("<II", shm.buf, 4, sys.hexversion, len(serialized))
    shm.buf[header_size : header_size + len(serialized)] = serialized
    # The tag is written last; presumably so that a reader polling the tag
    # does not pick up a partially written header/payload.
    struct.pack_into("<I", shm.buf, 0, seq)
96
69
 
97
- def _patch_spawn_get_preparation_data():
98
- try:
99
- from multiprocessing import spawn as mp_spawn
100
70
 
101
- _raw_get_preparation_data = mp_spawn.get_preparation_data
102
-
103
- def _patched_get_preparation_data(*args, **kw):
104
- ret = _raw_get_preparation_data(*args, **kw)
105
- if getattr(_init_main_suspended_local, "value", False):
106
- # make sure user module is not imported when start cluster
107
- ret.pop("init_main_from_name", None)
108
- ret.pop("init_main_from_path", None)
109
- return ret
110
-
111
- _patched_get_preparation_data._indigen_patched = True
112
- if not getattr(mp_spawn.get_preparation_data, "_indigen_patched", False):
113
- mp_spawn.get_preparation_data = _patched_get_preparation_data
114
- except (ImportError, AttributeError): # pragma: no cover
115
- pass
116
-
117
-
118
- @contextlib.contextmanager
119
- def _suspend_init_main():
120
- try:
121
- _init_main_suspended_local.value = True
122
- yield
123
- finally:
124
- _init_main_suspended_local.value = False
125
-
126
-
127
- @dataslots
128
- @dataclass
129
- class SubpoolStatus:
130
- # for status, 0 is succeeded, 1 is failed
131
- status: int | None = None
132
- external_addresses: List[str] | None = None
133
- error: BaseException | None = None
134
- traceback: TracebackType | None = None
71
def _shm_get_object(seq: _ShmSeq, shm: shared_memory.SharedMemory):
    """Read the object tagged ``seq`` from ``shm``, if it is there.

    Args:
        seq: The sequence tag the caller is waiting for.
        shm: Shared memory block whose ``buf`` holds the message.

    Returns:
        The unpickled object, or ``None`` when the tag stored in the first
        four bytes of the buffer is not ``seq``.

    Raises:
        RuntimeError: If the ``sys.hexversion`` recorded by the writer differs
            from the one of the current interpreter.
    """
    (stored_tag,) = struct.unpack_from("<I", shm.buf, 0)
    if stored_tag != seq:
        # A different (or no) message is currently stored.
        return None

    python_version_hex, payload_len = struct.unpack_from("<II", shm.buf, 4)
    if python_version_hex != sys.hexversion:
        raise RuntimeError(
            f"Python version mismatch, sender: {python_version_hex}, receiver: {sys.hexversion}"
        )

    payload_start = 12
    payload = shm.buf[payload_start : payload_start + payload_len]
    return pickle.loads(payload)
135
81
 
136
82
 
137
83
  @_register_message_handler
@@ -197,49 +143,21 @@ class MainActorPool(MainActorPoolBase):
197
143
  cls,
198
144
  actor_pool_config: ActorPoolConfig,
199
145
  process_index: int,
200
- start_method: str | None = None,
146
+ start_python: str | None = None,
201
147
  ):
202
- def start_pool_in_process():
203
- ctx = multiprocessing.get_context(method=start_method)
204
- status_queue = ctx.Queue()
205
- main_pool_pid = os.getpid()
206
-
207
- with _suspend_init_main():
208
- process = ctx.Process(
209
- target=cls._start_sub_pool,
210
- args=(
211
- actor_pool_config,
212
- process_index,
213
- status_queue,
214
- main_pool_pid,
215
- ),
216
- name=f"IndigenActorPool{process_index}",
217
- )
218
- process.start()
219
-
220
- # wait for sub actor pool to finish starting
221
- process_status = status_queue.get()
222
- return process, process_status
223
-
224
- _patch_spawn_get_preparation_data()
225
- loop = asyncio.get_running_loop()
226
- with futures.ThreadPoolExecutor(1) as executor:
227
- create_pool_task = loop.run_in_executor(executor, start_pool_in_process)
228
- return await create_pool_task
148
+ return await cls._create_sub_pool_from_parent(
149
+ actor_pool_config, process_index, start_python
150
+ )
229
151
 
230
152
  @classmethod
231
153
  async def wait_sub_pools_ready(cls, create_pool_tasks: List[asyncio.Task]):
232
- processes: list[multiprocessing.Process] = []
154
+ processes: list[asyncio.subprocess.Process] = []
233
155
  ext_addresses = []
234
156
  error = None
235
157
  for task in create_pool_tasks:
236
- process, status = await task
158
+ process, address = await task
237
159
  processes.append(process)
238
- if status.status == 1:
239
- # start sub pool failed
240
- error = status.error.with_traceback(status.traceback)
241
- else:
242
- ext_addresses.append(status.external_addresses)
160
+ ext_addresses.append(address)
243
161
  if error:
244
162
  for p in processes:
245
163
  # error happens, kill all subprocesses
@@ -248,84 +166,158 @@ class MainActorPool(MainActorPoolBase):
248
166
  return processes, ext_addresses
249
167
 
250
168
  @classmethod
251
- def _start_sub_pool(
169
+ def _start_sub_pool_in_child(
252
170
  cls,
253
- actor_config: ActorPoolConfig,
254
- process_index: int,
255
- status_queue: multiprocessing.Queue,
256
- main_pool_pid: int,
171
+ shm_name: str,
257
172
  ):
258
173
  ensure_coverage()
259
174
 
260
- # make sure enough randomness for every sub pool
261
- random.seed(uuid.uuid1().bytes)
262
- reset_id_random_seed()
263
-
264
- conf = actor_config.get_pool_config(process_index)
265
- suspend_sigint = conf["suspend_sigint"]
266
- if suspend_sigint:
267
- signal.signal(signal.SIGINT, lambda *_: None)
268
-
269
- logging_conf = conf["logging_conf"] or {}
270
- if isinstance(logging_conf, configparser.RawConfigParser):
271
- logging.config.fileConfig(logging_conf)
272
- elif logging_conf.get("dict"):
273
- logging.config.dictConfig(logging_conf["dict"])
274
- elif logging_conf.get("file"):
275
- logging.config.fileConfig(logging_conf["file"])
276
- elif logging_conf.get("level"):
277
- logging.getLogger("__main__").setLevel(logging_conf["level"])
278
- logging.getLogger("xoscar").setLevel(logging_conf["level"])
279
- if logging_conf.get("format"):
280
- logging.basicConfig(format=logging_conf["format"])
281
-
282
- use_uvloop = conf["use_uvloop"]
283
- if use_uvloop:
284
- import uvloop
285
-
286
- asyncio.set_event_loop(uvloop.new_event_loop())
287
- else:
288
- asyncio.set_event_loop(asyncio.new_event_loop())
175
+ shm = shared_memory.SharedMemory(shm_name, track=False)
176
+ try:
177
+ config = _shm_get_object(_ShmSeq.INIT_PARAMS, shm)
178
+ actor_config = config["actor_pool_config"]
179
+ process_index = config["process_index"]
180
+ main_pool_pid = config["main_pool_pid"]
181
+
182
+ def _check_ppid():
183
+ while True:
184
+ try:
185
+ # We can't simply check if the os.getppid() equals with main_pool_pid,
186
+ # as the double fork may result in a new process as the parent.
187
+ psutil.Process(main_pool_pid)
188
+ except psutil.NoSuchProcess:
189
+ logger.info("Exit due to main pool %s exit.", main_pool_pid)
190
+ os._exit(0)
191
+ except Exception as e:
192
+ logger.exception("Check ppid failed: %s", e)
193
+ time.sleep(10)
194
+
195
+ t = threading.Thread(target=_check_ppid, daemon=True)
196
+ t.start()
197
+
198
+ # make sure enough randomness for every sub pool
199
+ random.seed(uuid.uuid1().bytes)
200
+ reset_id_random_seed()
201
+
202
+ conf = actor_config.get_pool_config(process_index)
203
+ suspend_sigint = conf["suspend_sigint"]
204
+ if suspend_sigint:
205
+ signal.signal(signal.SIGINT, lambda *_: None)
206
+
207
+ logging_conf = conf["logging_conf"] or {}
208
+ if isinstance(logging_conf, configparser.RawConfigParser):
209
+ logging.config.fileConfig(logging_conf)
210
+ elif logging_conf.get("dict"):
211
+ logging.config.dictConfig(logging_conf["dict"])
212
+ elif logging_conf.get("file"):
213
+ logging.config.fileConfig(logging_conf["file"])
214
+ elif logging_conf.get("level"):
215
+ logging.getLogger("__main__").setLevel(logging_conf["level"])
216
+ logging.getLogger("xoscar").setLevel(logging_conf["level"])
217
+ if logging_conf.get("format"):
218
+ logging.basicConfig(format=logging_conf["format"])
219
+
220
+ use_uvloop = conf["use_uvloop"]
221
+ if use_uvloop:
222
+ import uvloop
223
+
224
+ asyncio.set_event_loop(uvloop.new_event_loop())
225
+ else:
226
+ asyncio.set_event_loop(asyncio.new_event_loop())
289
227
 
290
- coro = cls._create_sub_pool(
291
- actor_config, process_index, status_queue, main_pool_pid
292
- )
293
- asyncio.run(coro)
228
+ coro = cls._create_sub_pool(actor_config, process_index, main_pool_pid, shm)
229
+ asyncio.run(coro)
230
+ finally:
231
+ shm.close()
294
232
 
295
233
  @classmethod
296
234
  async def _create_sub_pool(
297
235
  cls,
298
236
  actor_config: ActorPoolConfig,
299
237
  process_index: int,
300
- status_queue: multiprocessing.Queue,
301
238
  main_pool_pid: int,
239
+ shm: shared_memory.SharedMemory,
240
+ ):
241
+ cur_pool_config = actor_config.get_pool_config(process_index)
242
+ env = cur_pool_config["env"]
243
+ if env:
244
+ os.environ.update(env)
245
+ pool = await SubActorPool.create(
246
+ {
247
+ "actor_pool_config": actor_config,
248
+ "process_index": process_index,
249
+ "main_pool_pid": main_pool_pid,
250
+ }
251
+ )
252
+ await pool.start()
253
+ _shm_put_object(_ShmSeq.INIT_RESULT, shm, cur_pool_config["external_address"])
254
+ await pool.join()
255
+
256
+ @staticmethod
257
+ async def _create_sub_pool_from_parent(
258
+ actor_pool_config: ActorPoolConfig,
259
+ process_index: int,
260
+ start_python: str | None = None,
302
261
  ):
303
- process_status = None
262
+ # We check the Python version in _shm_get_object to make it faster,
263
+ # as in most cases the Python versions are the same.
264
+ if start_python is None:
265
+ start_python = sys.executable
266
+
267
+ external_addresses: List | None = None
268
+ shm = shared_memory.SharedMemory(
269
+ create=True, size=_SUBPROCESS_SHM_SIZE, track=False
270
+ )
304
271
  try:
305
- cur_pool_config = actor_config.get_pool_config(process_index)
306
- env = cur_pool_config["env"]
307
- if env:
308
- os.environ.update(env)
309
- pool = await SubActorPool.create(
272
+ _shm_put_object(
273
+ _ShmSeq.INIT_PARAMS,
274
+ shm,
310
275
  {
311
- "actor_pool_config": actor_config,
276
+ "actor_pool_config": actor_pool_config,
312
277
  "process_index": process_index,
313
- "main_pool_pid": main_pool_pid,
314
- }
278
+ "main_pool_pid": os.getpid(),
279
+ },
315
280
  )
316
- external_addresses = cur_pool_config["external_address"]
317
- process_status = SubpoolStatus(
318
- status=0, external_addresses=external_addresses
281
+ process = await create_subprocess_exec(
282
+ start_python,
283
+ "-m",
284
+ "xoscar.backends.indigen",
285
+ "start_sub_pool",
286
+ "-sn",
287
+ shm.name,
319
288
  )
320
- await pool.start()
321
- except: # noqa: E722 # nosec # pylint: disable=bare-except
322
- _, error, tb = sys.exc_info()
323
- process_status = SubpoolStatus(status=1, error=error, traceback=tb)
324
- raise
289
+
290
+ def _get_external_addresses():
291
+ try:
292
+ nonlocal external_addresses
293
+ while (
294
+ shm
295
+ and shm.buf is not None
296
+ and not (
297
+ external_addresses := _shm_get_object(
298
+ _ShmSeq.INIT_RESULT, shm
299
+ )
300
+ )
301
+ ):
302
+ time.sleep(0.1)
303
+ except asyncio.CancelledError:
304
+ pass
305
+
306
+ _, unfinished = await asyncio.wait(
307
+ [
308
+ asyncio.create_task(process.wait()),
309
+ asyncio.create_task(asyncio.to_thread(_get_external_addresses)),
310
+ ],
311
+ return_when=asyncio.FIRST_COMPLETED,
312
+ )
313
+ for t in unfinished:
314
+ t.cancel()
325
315
  finally:
326
- status_queue.put(process_status)
327
- status_queue.cancel_join_thread()
328
- await pool.join()
316
+ shm.close()
317
+ shm.unlink()
318
+ if external_addresses is None:
319
+ raise OSError("Start sub pool failed.")
320
+ return process, external_addresses
329
321
 
330
322
  async def append_sub_pool(
331
323
  self,
@@ -337,7 +329,7 @@ class MainActorPool(MainActorPoolBase):
337
329
  suspend_sigint: bool | None = None,
338
330
  use_uvloop: bool | None = None,
339
331
  logging_conf: dict | None = None,
340
- start_method: str | None = None,
332
+ start_python: str | None = None,
341
333
  kwargs: dict | None = None,
342
334
  ):
343
335
  # external_address has port 0, subprocess will bind random port.
@@ -376,32 +368,12 @@ class MainActorPool(MainActorPoolBase):
376
368
  kwargs,
377
369
  )
378
370
 
379
- def start_pool_in_process():
380
- ctx = multiprocessing.get_context(method=start_method)
381
- status_queue = ctx.Queue()
382
- main_pool_pid = os.getpid()
383
-
384
- with _suspend_init_main():
385
- process = ctx.Process(
386
- target=self._start_sub_pool,
387
- args=(self._config, process_index, status_queue, main_pool_pid),
388
- name=f"IndigenActorPool{process_index}",
389
- )
390
- process.start()
391
-
392
- # wait for sub actor pool to finish starting
393
- process_status = status_queue.get()
394
- return process, process_status
395
-
396
- loop = asyncio.get_running_loop()
397
- with futures.ThreadPoolExecutor(1) as executor:
398
- create_pool_task = loop.run_in_executor(executor, start_pool_in_process)
399
- process, process_status = await create_pool_task
400
-
401
- self._config.reset_pool_external_address(
402
- process_index, process_status.external_addresses[0]
371
+ process, external_addresses = await self._create_sub_pool_from_parent(
372
+ self._config, process_index, start_python
403
373
  )
404
- self.attach_sub_process(process_status.external_addresses[0], process)
374
+
375
+ self._config.reset_pool_external_address(process_index, external_addresses[0])
376
+ self.attach_sub_process(external_addresses[0], process)
405
377
 
406
378
  control_message = ControlMessage(
407
379
  message_id=new_message_id(),
@@ -411,16 +383,16 @@ class MainActorPool(MainActorPoolBase):
411
383
  )
412
384
  await self.handle_control_command(control_message)
413
385
  # The actual port will return in process_status.
414
- return process_status.external_addresses[0]
386
+ return external_addresses[0]
415
387
 
416
388
  async def remove_sub_pool(
417
389
  self, external_address: str, timeout: float | None = None, force: bool = False
418
390
  ):
419
391
  process = self.sub_processes[external_address]
420
392
  process_index = self._config.get_process_index(external_address)
421
- await self.stop_sub_pool(external_address, process, timeout, force)
422
393
  del self.sub_processes[external_address]
423
394
  self._config.remove_pool_config(process_index)
395
+ await self.stop_sub_pool(external_address, process, timeout, force)
424
396
 
425
397
  control_message = ControlMessage(
426
398
  message_id=new_message_id(),
@@ -431,42 +403,35 @@ class MainActorPool(MainActorPoolBase):
431
403
  await self.handle_control_command(control_message)
432
404
 
433
405
  async def kill_sub_pool(
434
- self, process: multiprocessing.Process, force: bool = False
406
+ self, process: asyncio.subprocess.Process, force: bool = False
435
407
  ):
408
+ try:
409
+ p = psutil.Process(process.pid)
410
+ except psutil.NoSuchProcess:
411
+ return
412
+
436
413
  if not force: # pragma: no cover
437
- # must shutdown gracefully, or subprocess created by model will not exit
438
- if not _is_windows:
439
- try:
440
- os.kill(process.pid, signal.SIGINT) # type: ignore
441
- except OSError: # pragma: no cover
442
- pass
443
- process.terminate() # SIGTERM
444
- wait_pool = futures.ThreadPoolExecutor(1)
414
+ p.terminate()
445
415
  try:
446
- loop = asyncio.get_running_loop()
447
- await loop.run_in_executor(wait_pool, process.join, 3)
448
- finally:
449
- wait_pool.shutdown(False)
450
- process.kill() # SIGKILL
451
- await asyncio.to_thread(process.join, 5)
452
-
453
- async def is_sub_pool_alive(self, process: multiprocessing.Process):
454
- try:
455
- return await asyncio.to_thread(process.is_alive)
456
- except RuntimeError as ex: # pragma: no cover
457
- if "cannot schedule new futures" not in str(ex):
458
- # when atexit is triggered, the default pool might be shutdown
459
- # and to_thread will fail
460
- raise
461
- return process.is_alive()
416
+ p.wait(5)
417
+ except psutil.TimeoutExpired:
418
+ pass
419
+
420
+ while p.is_running():
421
+ p.kill()
422
+ if not p.is_running():
423
+ return
424
+ logger.info("Sub pool can't be killed: %s", p)
425
+ time.sleep(0.1)
426
+
427
+ async def is_sub_pool_alive(self, process: asyncio.subprocess.Process):
428
+ return process.returncode is None
462
429
 
463
430
  async def recover_sub_pool(self, address: str):
464
431
  process_index = self._config.get_process_index(address)
465
432
  # process dead, restart it
466
433
  # remember always use spawn to recover sub pool
467
- task = asyncio.create_task(
468
- self.start_sub_pool(self._config, process_index, "spawn")
469
- )
434
+ task = asyncio.create_task(self.start_sub_pool(self._config, process_index))
470
435
  self.sub_processes[address] = (await self.wait_sub_pools_ready([task]))[0][0]
471
436
 
472
437
  if self._auto_recover == "actor":