xoscar 0.6.2__cp310-cp310-macosx_11_0_arm64.whl → 0.7.0rc1__cp310-cp310-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of xoscar might be problematic.
@@ -16,26 +16,26 @@
  from __future__ import annotations

  import asyncio
- import atexit
- import concurrent.futures as futures
+ import asyncio.subprocess
  import configparser
- import contextlib
  import itertools
  import logging.config
- import multiprocessing
  import os
+ import pickle
  import random
  import signal
+ import struct
  import sys
  import threading
+ import time
  import uuid
- from dataclasses import dataclass
- from multiprocessing import util
- from types import TracebackType
+ from enum import IntEnum
  from typing import List, Optional

+ import psutil
+
  from ..._utils import reset_id_random_seed
- from ...utils import dataslots, ensure_coverage
+ from ...utils import ensure_coverage
  from ..config import ActorPoolConfig
  from ..message import (
      ControlMessage,
@@ -44,119 +44,40 @@ from ..message import (
      new_message_id,
  )
  from ..pool import MainActorPoolBase, SubActorPoolBase, _register_message_handler
+ from . import shared_memory
+ from .fate_sharing import create_subprocess_exec

+ _SUBPROCESS_SHM_SIZE = 10240
  _is_windows: bool = sys.platform.startswith("win")

- if sys.version_info[:2] == (3, 9):
-     # fix for Python 3.9, see https://bugs.python.org/issue43517
-     if sys.platform == "win32":
-         from multiprocessing import popen_spawn_win32 as popen_spawn
-
-         popen_forkserver = popen_fork = synchronize = None
-     else:
-         from multiprocessing import popen_fork, popen_forkserver
-         from multiprocessing import popen_spawn_posix as popen_spawn
-         from multiprocessing import synchronize
-     _ = popen_spawn, popen_forkserver, popen_fork, synchronize
-     del _
- elif sys.version_info[:2] == (3, 6):  # pragma: no cover
-     from multiprocessing.process import BaseProcess
-
-     # define kill method for multiprocessing
-     def _mp_kill(self):
-         if not _is_windows:
-             try:
-                 os.kill(self.pid, signal.SIGKILL)
-             except ProcessLookupError:
-                 pass
-             except OSError:
-                 if self.wait(timeout=0.1) is None:
-                     raise
-         else:
-             self.terminate()
-
-     BaseProcess.kill = _mp_kill
-
  logger = logging.getLogger(__name__)
- _init_main_suspended_local = threading.local()
-
-
- def _terminate_children():
-     for c in multiprocessing.active_children():
-         try:
-             c.terminate()
-         except Exception:
-             pass
-
-
- if util:
-     # Import multiprocessing.util to register _exit_function at exit.
-     atexit.register(_terminate_children)
-
-
- def _patch_spawn_get_preparation_data():
-     try:
-         from multiprocessing import spawn as mp_spawn
-
-         _raw_get_preparation_data = mp_spawn.get_preparation_data
-
-         def _patched_get_preparation_data(*args, **kw):
-             ret = _raw_get_preparation_data(*args, **kw)
-             if getattr(_init_main_suspended_local, "value", False):
-                 # make sure user module is not imported when start cluster
-                 ret.pop("init_main_from_name", None)
-                 ret.pop("init_main_from_path", None)
-             return ret
-
-         _patched_get_preparation_data._indigen_patched = True
-         if not getattr(mp_spawn.get_preparation_data, "_indigen_patched", False):
-             mp_spawn.get_preparation_data = _patched_get_preparation_data
-     except (ImportError, AttributeError):  # pragma: no cover
-         pass
-

- @contextlib.contextmanager
- def _suspend_init_main():
-     try:
-         _init_main_suspended_local.value = True
-         yield
-     finally:
-         _init_main_suspended_local.value = False

+ class _ShmSeq(IntEnum):
+     INIT_PARAMS = 1
+     INIT_RESULT = 2

- @dataslots
- @dataclass
- class SubpoolStatus:
-     # for status, 0 is succeeded, 1 is failed
-     status: int | None = None
-     external_addresses: List[str] | None = None
-     error: BaseException | None = None
-     traceback: TracebackType | None = None

+ def _shm_put_object(seq: _ShmSeq, shm: shared_memory.SharedMemory, o: object):
+     serialized = pickle.dumps(o)
+     assert (
+         len(serialized) < _SUBPROCESS_SHM_SIZE - 8
+     ), f"Serialized object {o} is too long."
+     shm.buf[4:12] = struct.pack("<II", sys.hexversion, len(serialized))
+     shm.buf[12 : 12 + len(serialized)] = serialized
+     shm.buf[:4] = struct.pack("<I", seq)

- _PRE_SET_ENV_LOCK = asyncio.Lock()

-
- @contextlib.asynccontextmanager
- async def _pre_set_env_in_main(env: dict[str, str]):
-     # Normally, `env` is set in sub pool,
-     # but something may have happened during initialization,
-     # e.g. CUDA_VISIBLE_DEVICES is too late to set when some actions may have polluted cuda
-     # we have to set environ before new process started
-     # enable this only when XOSCAR_PRE_SET_ENV=1
-     enable_pre_set_env = bool(int(os.getenv("XOSCAR_PRE_SET_ENV", 0)))
-     if not enable_pre_set_env or not env:
-         yield
+ def _shm_get_object(seq: _ShmSeq, shm: shared_memory.SharedMemory):
+     recv_seq = struct.unpack("<I", shm.buf[:4])[0]
+     if recv_seq != seq:
          return
-
-     global_environ = os.environ.copy()
-     async with _PRE_SET_ENV_LOCK:
-         try:
-             logger.debug("Updating environment variables in main: %s", env)
-             os.environ.update(env)
-             yield
-         finally:
-             os.environ = global_environ  # type: ignore
+     python_version_hex, size = struct.unpack("<II", shm.buf[4:12])
+     if python_version_hex != sys.hexversion:
+         raise RuntimeError(
+             f"Python version mismatch, sender: {python_version_hex}, receiver: {sys.hexversion}"
+         )
+     return pickle.loads(shm.buf[12 : 12 + size])


  @_register_message_handler
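
The `_shm_put_object`/`_shm_get_object` pair added above implements a one-slot mailbox over shared memory: bytes 0-3 hold a little-endian sequence tag, bytes 4-7 the writer's `sys.hexversion`, bytes 8-11 the payload length, and the pickled payload starts at offset 12. The tag is written last, so a reader polling the first four bytes never observes a half-written record. (The header occupies 12 bytes, so the `- 8` bound in the assert is slightly loose.) Below is a minimal sketch of the same layout against the stdlib `multiprocessing.shared_memory` rather than xoscar's bundled `shared_memory` module; the `put`/`get` names and the 12-byte bound are illustrative, not xoscar's:

import pickle
import struct
import sys
from multiprocessing import shared_memory

SHM_SIZE = 10240  # mirrors _SUBPROCESS_SHM_SIZE

def put(shm: shared_memory.SharedMemory, seq: int, obj: object) -> None:
    payload = pickle.dumps(obj)
    assert len(payload) <= SHM_SIZE - 12, "payload does not fit the slot"
    # header fields and payload first ...
    shm.buf[4:12] = struct.pack("<II", sys.hexversion, len(payload))
    shm.buf[12 : 12 + len(payload)] = payload
    # ... the sequence tag last, which publishes the record to pollers
    shm.buf[:4] = struct.pack("<I", seq)

def get(shm: shared_memory.SharedMemory, seq: int):
    if struct.unpack("<I", shm.buf[:4])[0] != seq:
        return None  # no record with this tag published yet
    version, size = struct.unpack("<II", shm.buf[4:12])
    if version != sys.hexversion:
        raise RuntimeError("writer ran a different Python version")
    return pickle.loads(shm.buf[12 : 12 + size])

shm = shared_memory.SharedMemory(create=True, size=SHM_SIZE)
try:
    put(shm, 1, {"process_index": 3})
    assert get(shm, 2) is None  # wrong tag: record stays invisible
    assert get(shm, 1) == {"process_index": 3}
finally:
    shm.close()
    shm.unlink()

Pickles are not guaranteed to round-trip across interpreter versions, hence the cheap `sys.hexversion` equality check on the read side instead of any compatibility negotiation.
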
@@ -222,52 +143,21 @@ class MainActorPool(MainActorPoolBase):
          cls,
          actor_pool_config: ActorPoolConfig,
          process_index: int,
-         start_method: str | None = None,
+         start_python: str | None = None,
      ):
-         def start_pool_in_process():
-             ctx = multiprocessing.get_context(method=start_method)
-             status_queue = ctx.Queue()
-             main_pool_pid = os.getpid()
-
-             with _suspend_init_main():
-                 process = ctx.Process(
-                     target=cls._start_sub_pool,
-                     args=(
-                         actor_pool_config,
-                         process_index,
-                         status_queue,
-                         main_pool_pid,
-                     ),
-                     name=f"IndigenActorPool{process_index}",
-                 )
-                 process.start()
-
-             # wait for sub actor pool to finish starting
-             process_status = status_queue.get()
-             return process, process_status
-
-         _patch_spawn_get_preparation_data()
-         loop = asyncio.get_running_loop()
-         async with _pre_set_env_in_main(
-             actor_pool_config.get_pool_config(process_index)["env"]
-         ):
-             with futures.ThreadPoolExecutor(1) as executor:
-                 create_pool_task = loop.run_in_executor(executor, start_pool_in_process)
-                 return await create_pool_task
+         return await cls._create_sub_pool_from_parent(
+             actor_pool_config, process_index, start_python
+         )

      @classmethod
      async def wait_sub_pools_ready(cls, create_pool_tasks: List[asyncio.Task]):
-         processes: list[multiprocessing.Process] = []
+         processes: list[asyncio.subprocess.Process] = []
          ext_addresses = []
          error = None
          for task in create_pool_tasks:
-             process, status = await task
+             process, address = await task
              processes.append(process)
-             if status.status == 1:
-                 # start sub pool failed
-                 error = status.error.with_traceback(status.traceback)
-             else:
-                 ext_addresses.append(status.external_addresses)
+             ext_addresses.append(address)
          if error:
              for p in processes:
                  # error happens, kill all subprocesses
@@ -276,84 +166,158 @@ class MainActorPool(MainActorPoolBase):
          return processes, ext_addresses

      @classmethod
-     def _start_sub_pool(
+     def _start_sub_pool_in_child(
          cls,
-         actor_config: ActorPoolConfig,
-         process_index: int,
-         status_queue: multiprocessing.Queue,
-         main_pool_pid: int,
+         shm_name: str,
      ):
          ensure_coverage()

-         # make sure enough randomness for every sub pool
-         random.seed(uuid.uuid1().bytes)
-         reset_id_random_seed()
-
-         conf = actor_config.get_pool_config(process_index)
-         suspend_sigint = conf["suspend_sigint"]
-         if suspend_sigint:
-             signal.signal(signal.SIGINT, lambda *_: None)
-
-         logging_conf = conf["logging_conf"] or {}
-         if isinstance(logging_conf, configparser.RawConfigParser):
-             logging.config.fileConfig(logging_conf)
-         elif logging_conf.get("dict"):
-             logging.config.dictConfig(logging_conf["dict"])
-         elif logging_conf.get("file"):
-             logging.config.fileConfig(logging_conf["file"])
-         elif logging_conf.get("level"):
-             logging.getLogger("__main__").setLevel(logging_conf["level"])
-             logging.getLogger("xoscar").setLevel(logging_conf["level"])
-         if logging_conf.get("format"):
-             logging.basicConfig(format=logging_conf["format"])
-
-         use_uvloop = conf["use_uvloop"]
-         if use_uvloop:
-             import uvloop
-
-             asyncio.set_event_loop(uvloop.new_event_loop())
-         else:
-             asyncio.set_event_loop(asyncio.new_event_loop())
+         shm = shared_memory.SharedMemory(shm_name, track=False)
+         try:
+             config = _shm_get_object(_ShmSeq.INIT_PARAMS, shm)
+             actor_config = config["actor_pool_config"]
+             process_index = config["process_index"]
+             main_pool_pid = config["main_pool_pid"]
+
+             def _check_ppid():
+                 while True:
+                     try:
+                         # We can't simply check if the os.getppid() equals with main_pool_pid,
+                         # as the double fork may result in a new process as the parent.
+                         psutil.Process(main_pool_pid)
+                     except psutil.NoSuchProcess:
+                         logger.info("Exit due to main pool %s exit.", main_pool_pid)
+                         os._exit(0)
+                     except Exception as e:
+                         logger.exception("Check ppid failed: %s", e)
+                     time.sleep(10)
+
+             t = threading.Thread(target=_check_ppid, daemon=True)
+             t.start()
+
+             # make sure enough randomness for every sub pool
+             random.seed(uuid.uuid1().bytes)
+             reset_id_random_seed()
+
+             conf = actor_config.get_pool_config(process_index)
+             suspend_sigint = conf["suspend_sigint"]
+             if suspend_sigint:
+                 signal.signal(signal.SIGINT, lambda *_: None)
+
+             logging_conf = conf["logging_conf"] or {}
+             if isinstance(logging_conf, configparser.RawConfigParser):
+                 logging.config.fileConfig(logging_conf)
+             elif logging_conf.get("dict"):
+                 logging.config.dictConfig(logging_conf["dict"])
+             elif logging_conf.get("file"):
+                 logging.config.fileConfig(logging_conf["file"])
+             elif logging_conf.get("level"):
+                 logging.getLogger("__main__").setLevel(logging_conf["level"])
+                 logging.getLogger("xoscar").setLevel(logging_conf["level"])
+             if logging_conf.get("format"):
+                 logging.basicConfig(format=logging_conf["format"])
+
+             use_uvloop = conf["use_uvloop"]
+             if use_uvloop:
+                 import uvloop
+
+                 asyncio.set_event_loop(uvloop.new_event_loop())
+             else:
+                 asyncio.set_event_loop(asyncio.new_event_loop())

-         coro = cls._create_sub_pool(
-             actor_config, process_index, status_queue, main_pool_pid
-         )
-         asyncio.run(coro)
+             coro = cls._create_sub_pool(actor_config, process_index, main_pool_pid, shm)
+             asyncio.run(coro)
+         finally:
+             shm.close()

      @classmethod
      async def _create_sub_pool(
          cls,
          actor_config: ActorPoolConfig,
          process_index: int,
-         status_queue: multiprocessing.Queue,
          main_pool_pid: int,
+         shm: shared_memory.SharedMemory,
+     ):
+         cur_pool_config = actor_config.get_pool_config(process_index)
+         env = cur_pool_config["env"]
+         if env:
+             os.environ.update(env)
+         pool = await SubActorPool.create(
+             {
+                 "actor_pool_config": actor_config,
+                 "process_index": process_index,
+                 "main_pool_pid": main_pool_pid,
+             }
+         )
+         await pool.start()
+         _shm_put_object(_ShmSeq.INIT_RESULT, shm, cur_pool_config["external_address"])
+         await pool.join()
+
+     @staticmethod
+     async def _create_sub_pool_from_parent(
+         actor_pool_config: ActorPoolConfig,
+         process_index: int,
+         start_python: str | None = None,
      ):
-         process_status = None
+         # We check the Python version in _shm_get_object to make it faster,
+         # as in most cases the Python versions are the same.
+         if start_python is None:
+             start_python = sys.executable
+
+         external_addresses: List | None = None
+         shm = shared_memory.SharedMemory(
+             create=True, size=_SUBPROCESS_SHM_SIZE, track=False
+         )
          try:
-             cur_pool_config = actor_config.get_pool_config(process_index)
-             env = cur_pool_config["env"]
-             if env:
-                 os.environ.update(env)
-             pool = await SubActorPool.create(
+             _shm_put_object(
+                 _ShmSeq.INIT_PARAMS,
+                 shm,
                  {
-                     "actor_pool_config": actor_config,
+                     "actor_pool_config": actor_pool_config,
                      "process_index": process_index,
-                     "main_pool_pid": main_pool_pid,
-                 }
+                     "main_pool_pid": os.getpid(),
+                 },
+             )
+             process = await create_subprocess_exec(
+                 start_python,
+                 "-m",
+                 "xoscar.backends.indigen",
+                 "start_sub_pool",
+                 "-sn",
+                 shm.name,
              )
-             external_addresses = cur_pool_config["external_address"]
-             process_status = SubpoolStatus(
-                 status=0, external_addresses=external_addresses
+
+             def _get_external_addresses():
+                 try:
+                     nonlocal external_addresses
+                     while (
+                         shm
+                         and shm.buf is not None
+                         and not (
+                             external_addresses := _shm_get_object(
+                                 _ShmSeq.INIT_RESULT, shm
+                             )
+                         )
+                     ):
+                         time.sleep(0.1)
+                 except asyncio.CancelledError:
+                     pass
+
+             _, unfinished = await asyncio.wait(
+                 [
+                     asyncio.create_task(process.wait()),
+                     asyncio.create_task(asyncio.to_thread(_get_external_addresses)),
+                 ],
+                 return_when=asyncio.FIRST_COMPLETED,
              )
-             await pool.start()
-         except:  # noqa: E722  # nosec  # pylint: disable=bare-except
-             _, error, tb = sys.exc_info()
-             process_status = SubpoolStatus(status=1, error=error, traceback=tb)
-             raise
+             for t in unfinished:
+                 t.cancel()
          finally:
-             status_queue.put(process_status)
-             status_queue.cancel_join_thread()
-         await pool.join()
+             shm.close()
+             shm.unlink()
+         if external_addresses is None:
+             raise OSError("Start sub pool failed.")
+         return process, external_addresses

      async def append_sub_pool(
          self,
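
`_create_sub_pool_from_parent` above races two awaitables via `asyncio.wait(..., return_when=asyncio.FIRST_COMPLETED)`: the child's `process.wait()` and a worker thread that polls the mailbox for `INIT_RESULT`. If the address arrives first, startup succeeded; if the child exits first, `external_addresses` stays `None` and the caller gets `OSError("Start sub pool failed.")`. Here is the pattern in isolation, a sketch in which `spawn_child` and `read_mailbox` are hypothetical stand-ins for the subprocess launch and `_shm_get_object`:

import asyncio
import time

async def wait_until_ready(spawn_child, read_mailbox):
    # spawn_child: coroutine returning an asyncio.subprocess.Process
    # read_mailbox: blocking callable returning the payload, or None if absent
    child = await spawn_child()

    def poll():
        # blocking loop, pushed onto a worker thread by asyncio.to_thread
        while (result := read_mailbox()) is None:
            time.sleep(0.1)
        return result

    exit_task = asyncio.create_task(child.wait())
    ready_task = asyncio.create_task(asyncio.to_thread(poll))
    done, pending = await asyncio.wait(
        {exit_task, ready_task}, return_when=asyncio.FIRST_COMPLETED
    )
    for task in pending:
        task.cancel()
    if ready_task in done:
        return child, ready_task.result()
    raise OSError(f"child exited with {child.returncode} before reporting ready")

Cancelling an `asyncio.to_thread` task does not interrupt a thread that is already running, which is why `_get_external_addresses` also checks `shm.buf is not None`: it is `shm.close()` in the `finally` block that actually stops a leftover poller.
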
@@ -365,7 +329,7 @@ class MainActorPool(MainActorPoolBase):
          suspend_sigint: bool | None = None,
          use_uvloop: bool | None = None,
          logging_conf: dict | None = None,
-         start_method: str | None = None,
+         start_python: str | None = None,
          kwargs: dict | None = None,
      ):
          # external_address has port 0, subprocess will bind random port.
@@ -404,33 +368,12 @@ class MainActorPool(MainActorPoolBase):
              kwargs,
          )

-         def start_pool_in_process():
-             ctx = multiprocessing.get_context(method=start_method)
-             status_queue = ctx.Queue()
-             main_pool_pid = os.getpid()
-
-             with _suspend_init_main():
-                 process = ctx.Process(
-                     target=self._start_sub_pool,
-                     args=(self._config, process_index, status_queue, main_pool_pid),
-                     name=f"IndigenActorPool{process_index}",
-                 )
-                 process.start()
-
-             # wait for sub actor pool to finish starting
-             process_status = status_queue.get()
-             return process, process_status
-
-         loop = asyncio.get_running_loop()
-         async with _pre_set_env_in_main(env):  # type: ignore
-             with futures.ThreadPoolExecutor(1) as executor:
-                 create_pool_task = loop.run_in_executor(executor, start_pool_in_process)
-                 process, process_status = await create_pool_task
-
-         self._config.reset_pool_external_address(
-             process_index, process_status.external_addresses[0]
+         process, external_addresses = await self._create_sub_pool_from_parent(
+             self._config, process_index, start_python
          )
-         self.attach_sub_process(process_status.external_addresses[0], process)
+
+         self._config.reset_pool_external_address(process_index, external_addresses[0])
+         self.attach_sub_process(external_addresses[0], process)

          control_message = ControlMessage(
              message_id=new_message_id(),
@@ -440,16 +383,16 @@ class MainActorPool(MainActorPoolBase):
          )
          await self.handle_control_command(control_message)
          # The actual port will return in process_status.
-         return process_status.external_addresses[0]
+         return external_addresses[0]

      async def remove_sub_pool(
          self, external_address: str, timeout: float | None = None, force: bool = False
      ):
          process = self.sub_processes[external_address]
          process_index = self._config.get_process_index(external_address)
-         await self.stop_sub_pool(external_address, process, timeout, force)
          del self.sub_processes[external_address]
          self._config.remove_pool_config(process_index)
+         await self.stop_sub_pool(external_address, process, timeout, force)

          control_message = ControlMessage(
              message_id=new_message_id(),
@@ -460,42 +403,35 @@ class MainActorPool(MainActorPoolBase):
          await self.handle_control_command(control_message)

      async def kill_sub_pool(
-         self, process: multiprocessing.Process, force: bool = False
+         self, process: asyncio.subprocess.Process, force: bool = False
      ):
+         try:
+             p = psutil.Process(process.pid)
+         except psutil.NoSuchProcess:
+             return
+
          if not force:  # pragma: no cover
-             # must shutdown gracefully, or subprocess created by model will not exit
-             if not _is_windows:
-                 try:
-                     os.kill(process.pid, signal.SIGINT)  # type: ignore
-                 except OSError:  # pragma: no cover
-                     pass
-             process.terminate()  # SIGTERM
-             wait_pool = futures.ThreadPoolExecutor(1)
+             p.terminate()
              try:
-                 loop = asyncio.get_running_loop()
-                 await loop.run_in_executor(wait_pool, process.join, 3)
-             finally:
-                 wait_pool.shutdown(False)
-         process.kill()  # SIGKILL
-         await asyncio.to_thread(process.join, 5)
-
-     async def is_sub_pool_alive(self, process: multiprocessing.Process):
-         try:
-             return await asyncio.to_thread(process.is_alive)
-         except RuntimeError as ex:  # pragma: no cover
-             if "cannot schedule new futures" not in str(ex):
-                 # when atexit is triggered, the default pool might be shutdown
-                 # and to_thread will fail
-                 raise
-             return process.is_alive()
+                 p.wait(5)
+             except psutil.TimeoutExpired:
+                 pass
+
+         while p.is_running():
+             p.kill()
+             if not p.is_running():
+                 return
+             logger.info("Sub pool can't be killed: %s", p)
+             time.sleep(0.1)
+
+     async def is_sub_pool_alive(self, process: asyncio.subprocess.Process):
+         return process.returncode is None

      async def recover_sub_pool(self, address: str):
          process_index = self._config.get_process_index(address)
          # process dead, restart it
          # remember always use spawn to recover sub pool
-         task = asyncio.create_task(
-             self.start_sub_pool(self._config, process_index, "spawn")
-         )
+         task = asyncio.create_task(self.start_sub_pool(self._config, process_index))
          self.sub_processes[address] = (await self.wait_sub_pools_ready([task]))[0][0]

          if self._auto_recover == "actor":
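
`kill_sub_pool` above escalates through psutil: when `force` is false it sends SIGTERM via `terminate()` and grants a five-second grace period, then falls through to a `kill()` loop, because even SIGKILL takes effect asynchronously and the first `is_running()` check may still see the process. `is_sub_pool_alive` shrinks to a synchronous look at `asyncio.subprocess.Process.returncode`, which asyncio records once the child is reaped, so the old `to_thread(process.is_alive)` workaround is gone. A condensed sketch of the escalation; the `stop_process` name and `grace` parameter are illustrative, not from xoscar:

import time

import psutil

def stop_process(pid: int, grace: float = 5.0) -> None:
    try:
        p = psutil.Process(pid)
    except psutil.NoSuchProcess:
        return  # already gone

    p.terminate()  # SIGTERM: let the child shut down cleanly
    try:
        p.wait(grace)
        return
    except psutil.TimeoutExpired:
        pass

    # SIGKILL, retried: delivery is asynchronous, so poll until it lands
    while p.is_running():
        p.kill()
        time.sleep(0.1)
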