xoscar 0.9.0__cp312-cp312-macosx_10_13_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xoscar/__init__.py +61 -0
- xoscar/_utils.cpython-312-darwin.so +0 -0
- xoscar/_utils.pxd +36 -0
- xoscar/_utils.pyx +246 -0
- xoscar/_version.py +693 -0
- xoscar/aio/__init__.py +16 -0
- xoscar/aio/base.py +86 -0
- xoscar/aio/file.py +59 -0
- xoscar/aio/lru.py +228 -0
- xoscar/aio/parallelism.py +39 -0
- xoscar/api.py +527 -0
- xoscar/backend.py +67 -0
- xoscar/backends/__init__.py +14 -0
- xoscar/backends/allocate_strategy.py +160 -0
- xoscar/backends/communication/__init__.py +30 -0
- xoscar/backends/communication/base.py +315 -0
- xoscar/backends/communication/core.py +69 -0
- xoscar/backends/communication/dummy.py +253 -0
- xoscar/backends/communication/errors.py +20 -0
- xoscar/backends/communication/socket.py +444 -0
- xoscar/backends/communication/ucx.py +538 -0
- xoscar/backends/communication/utils.py +97 -0
- xoscar/backends/config.py +157 -0
- xoscar/backends/context.py +437 -0
- xoscar/backends/core.py +352 -0
- xoscar/backends/indigen/__init__.py +16 -0
- xoscar/backends/indigen/__main__.py +19 -0
- xoscar/backends/indigen/backend.py +51 -0
- xoscar/backends/indigen/driver.py +26 -0
- xoscar/backends/indigen/fate_sharing.py +221 -0
- xoscar/backends/indigen/pool.py +515 -0
- xoscar/backends/indigen/shared_memory.py +548 -0
- xoscar/backends/message.cpython-312-darwin.so +0 -0
- xoscar/backends/message.pyi +255 -0
- xoscar/backends/message.pyx +646 -0
- xoscar/backends/pool.py +1630 -0
- xoscar/backends/router.py +285 -0
- xoscar/backends/test/__init__.py +16 -0
- xoscar/backends/test/backend.py +38 -0
- xoscar/backends/test/pool.py +233 -0
- xoscar/batch.py +256 -0
- xoscar/collective/__init__.py +27 -0
- xoscar/collective/backend/__init__.py +13 -0
- xoscar/collective/backend/nccl_backend.py +160 -0
- xoscar/collective/common.py +102 -0
- xoscar/collective/core.py +737 -0
- xoscar/collective/process_group.py +687 -0
- xoscar/collective/utils.py +41 -0
- xoscar/collective/xoscar_pygloo.cpython-312-darwin.so +0 -0
- xoscar/collective/xoscar_pygloo.pyi +239 -0
- xoscar/constants.py +23 -0
- xoscar/context.cpython-312-darwin.so +0 -0
- xoscar/context.pxd +21 -0
- xoscar/context.pyx +368 -0
- xoscar/core.cpython-312-darwin.so +0 -0
- xoscar/core.pxd +51 -0
- xoscar/core.pyx +664 -0
- xoscar/debug.py +188 -0
- xoscar/driver.py +42 -0
- xoscar/errors.py +63 -0
- xoscar/libcpp.pxd +31 -0
- xoscar/metrics/__init__.py +21 -0
- xoscar/metrics/api.py +288 -0
- xoscar/metrics/backends/__init__.py +13 -0
- xoscar/metrics/backends/console/__init__.py +13 -0
- xoscar/metrics/backends/console/console_metric.py +82 -0
- xoscar/metrics/backends/metric.py +149 -0
- xoscar/metrics/backends/prometheus/__init__.py +13 -0
- xoscar/metrics/backends/prometheus/prometheus_metric.py +70 -0
- xoscar/nvutils.py +717 -0
- xoscar/profiling.py +260 -0
- xoscar/serialization/__init__.py +20 -0
- xoscar/serialization/aio.py +141 -0
- xoscar/serialization/core.cpython-312-darwin.so +0 -0
- xoscar/serialization/core.pxd +28 -0
- xoscar/serialization/core.pyi +57 -0
- xoscar/serialization/core.pyx +944 -0
- xoscar/serialization/cuda.py +111 -0
- xoscar/serialization/exception.py +48 -0
- xoscar/serialization/mlx.py +67 -0
- xoscar/serialization/numpy.py +82 -0
- xoscar/serialization/pyfury.py +37 -0
- xoscar/serialization/scipy.py +72 -0
- xoscar/serialization/torch.py +180 -0
- xoscar/utils.py +522 -0
- xoscar/virtualenv/__init__.py +34 -0
- xoscar/virtualenv/core.py +268 -0
- xoscar/virtualenv/platform.py +56 -0
- xoscar/virtualenv/utils.py +100 -0
- xoscar/virtualenv/uv.py +321 -0
- xoscar-0.9.0.dist-info/METADATA +230 -0
- xoscar-0.9.0.dist-info/RECORD +94 -0
- xoscar-0.9.0.dist-info/WHEEL +6 -0
- xoscar-0.9.0.dist-info/top_level.txt +2 -0
xoscar/backends/pool.py
ADDED
|
@@ -0,0 +1,1630 @@
|
|
|
1
|
+
# Copyright 2022-2023 XProbe Inc.
|
|
2
|
+
# derived from copyright 1999-2021 Alibaba Group Holding Ltd.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from __future__ import annotations
|
|
17
|
+
|
|
18
|
+
import asyncio
|
|
19
|
+
import asyncio.subprocess
|
|
20
|
+
import concurrent.futures as futures
|
|
21
|
+
import contextlib
|
|
22
|
+
import itertools
|
|
23
|
+
import logging
|
|
24
|
+
import multiprocessing
|
|
25
|
+
import os
|
|
26
|
+
import threading
|
|
27
|
+
import traceback
|
|
28
|
+
from abc import ABC, ABCMeta, abstractmethod
|
|
29
|
+
from typing import Any, Callable, Coroutine, Optional, Type, TypeVar
|
|
30
|
+
|
|
31
|
+
import psutil
|
|
32
|
+
|
|
33
|
+
from .._utils import TypeDispatcher, create_actor_ref, to_binary
|
|
34
|
+
from ..api import Actor
|
|
35
|
+
from ..core import ActorRef, BufferRef, FileObjectRef, register_local_pool
|
|
36
|
+
from ..debug import debug_async_timeout, record_message_trace
|
|
37
|
+
from ..errors import (
|
|
38
|
+
ActorAlreadyExist,
|
|
39
|
+
ActorNotExist,
|
|
40
|
+
CannotCancelTask,
|
|
41
|
+
SendMessageFailed,
|
|
42
|
+
ServerClosed,
|
|
43
|
+
)
|
|
44
|
+
from ..metrics import init_metrics
|
|
45
|
+
from ..utils import implements, is_zero_ip, register_asyncio_task_timeout_detector
|
|
46
|
+
from .allocate_strategy import AddressSpecified, allocated_type
|
|
47
|
+
from .communication import (
|
|
48
|
+
Channel,
|
|
49
|
+
Server,
|
|
50
|
+
UCXChannel,
|
|
51
|
+
gen_local_address,
|
|
52
|
+
get_server_type,
|
|
53
|
+
)
|
|
54
|
+
from .communication.errors import ChannelClosed
|
|
55
|
+
from .config import ActorPoolConfig
|
|
56
|
+
from .core import ActorCaller, ResultMessageType
|
|
57
|
+
from .message import (
|
|
58
|
+
DEFAULT_PROTOCOL,
|
|
59
|
+
ActorRefMessage,
|
|
60
|
+
CancelMessage,
|
|
61
|
+
ControlMessage,
|
|
62
|
+
ControlMessageType,
|
|
63
|
+
CreateActorMessage,
|
|
64
|
+
DestroyActorMessage,
|
|
65
|
+
ErrorMessage,
|
|
66
|
+
ForwardMessage,
|
|
67
|
+
HasActorMessage,
|
|
68
|
+
MessageType,
|
|
69
|
+
ResultMessage,
|
|
70
|
+
SendMessage,
|
|
71
|
+
TellMessage,
|
|
72
|
+
_MessageBase,
|
|
73
|
+
new_message_id,
|
|
74
|
+
)
|
|
75
|
+
from .router import Router
|
|
76
|
+
|
|
77
|
+
logger = logging.getLogger(__name__)
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
@contextlib.contextmanager
|
|
81
|
+
def _disable_log_temporally():
|
|
82
|
+
if os.getenv("CUDA_VISIBLE_DEVICES") == "-1":
|
|
83
|
+
# disable logging when CUDA_VISIBLE_DEVICES == -1
|
|
84
|
+
# many logging comes from ptxcompiler may distract users
|
|
85
|
+
try:
|
|
86
|
+
logging.disable(level=logging.ERROR)
|
|
87
|
+
yield
|
|
88
|
+
finally:
|
|
89
|
+
logging.disable(level=logging.NOTSET)
|
|
90
|
+
else:
|
|
91
|
+
yield
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
class _ErrorProcessor:
|
|
95
|
+
def __init__(self, address: str, message_id: bytes, protocol):
|
|
96
|
+
self._address = address
|
|
97
|
+
self._message_id = message_id
|
|
98
|
+
self._protocol = protocol
|
|
99
|
+
self.result = None
|
|
100
|
+
|
|
101
|
+
def __enter__(self):
|
|
102
|
+
return self
|
|
103
|
+
|
|
104
|
+
def __exit__(self, exc_type, exc_val, exc_tb):
|
|
105
|
+
if self.result is None:
|
|
106
|
+
self.result = ErrorMessage(
|
|
107
|
+
self._message_id,
|
|
108
|
+
self._address,
|
|
109
|
+
os.getpid(),
|
|
110
|
+
exc_type,
|
|
111
|
+
exc_val,
|
|
112
|
+
exc_tb,
|
|
113
|
+
protocol=self._protocol,
|
|
114
|
+
)
|
|
115
|
+
return True
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
def _register_message_handler(pool_type: Type["AbstractActorPool"]):
    """Class decorator wiring message types to pool handler methods.

    Populates ``pool_type._message_handler`` so ``process_message`` can
    dispatch an incoming message to the matching coroutine by its
    ``message_type``.  Returns ``pool_type`` unchanged so it can be used
    as a decorator.
    """
    handlers: dict = {
        MessageType.create_actor: pool_type.create_actor,
        MessageType.destroy_actor: pool_type.destroy_actor,
        MessageType.has_actor: pool_type.has_actor,
        MessageType.actor_ref: pool_type.actor_ref,
        MessageType.send: pool_type.send,
        MessageType.tell: pool_type.tell,
        MessageType.cancel: pool_type.cancel,
        MessageType.forward: pool_type.forward,
        MessageType.control: pool_type.handle_control_command,
        MessageType.copy_to_buffers: pool_type.handle_copy_to_buffers_message,
        MessageType.copy_to_fileobjs: pool_type.handle_copy_to_fileobjs_message,
    }
    pool_type._message_handler = handlers  # type: ignore
    return pool_type
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
class AbstractActorPool(ABC):
    """Base class of an actor pool.

    A pool owns the communication ``Server`` instances it listens on, the
    actors hosted in this process, and the dispatch of incoming messages to
    handler coroutines.  It also manages the pool life cycle
    (``start`` / ``join`` / ``stop``) and supports async context-manager use.
    """

    __slots__ = (
        "process_index",
        "label",
        "external_address",
        "internal_address",
        "env",
        "_servers",
        "_router",
        "_config",
        "_stopped",
        "_actors",
        "_caller",
        "_process_messages",
        "_asyncio_task_timeout_detector_task",
    )

    # message type -> handler method; filled in by _register_message_handler
    _message_handler: dict[MessageType, Callable]
    # message id -> task currently processing that message; used by `cancel`
    _process_messages: dict[bytes, asyncio.Future | asyncio.Task | None]

    def __init__(
        self,
        process_index: int,
        label: str,
        external_address: str,
        internal_address: str,
        env: dict,
        router: Router,
        config: ActorPoolConfig,
        servers: list[Server],
    ):
        # register local pool for local actor lookup.
        # The pool is weakrefed, so we don't need to unregister it.
        if not is_zero_ip(external_address):
            # Only register_local_pool when we listen on non-zero ip (because all-zero ip is wildcard address),
            # avoid mistaken with another remote service listen on non-zero ip with the same port.
            register_local_pool(external_address, self)
        self.process_index = process_index
        self.label = label
        self.external_address = external_address
        self.internal_address = internal_address
        self.env = env
        self._router = router
        self._config = config
        self._servers = servers

        # set once by stop(); checked by the receive loop and join()
        self._stopped = asyncio.Event()

        # states
        # actor id -> actor
        self._actors: dict[bytes, Actor] = dict()
        # message id -> future
        self._process_messages = dict()

        # manage async actor callers
        self._caller = ActorCaller()
        self._asyncio_task_timeout_detector_task = (
            register_asyncio_task_timeout_detector()
        )
        # init metrics
        metric_configs = self._config.get_metric_configs()
        metric_backend = metric_configs.get("backend")
        init_metrics(metric_backend, config=metric_configs.get(metric_backend))

    @property
    def router(self):
        # Router used to resolve destination addresses for outgoing calls.
        return self._router

    @abstractmethod
    async def create_actor(self, message: CreateActorMessage) -> ResultMessageType:
        """
        Create an actor.

        Parameters
        ----------
        message: CreateActorMessage
            message to create an actor.

        Returns
        -------
        result_message
            result or error message.
        """

    @abstractmethod
    async def has_actor(self, message: HasActorMessage) -> ResultMessage:
        """
        Check if an actor exists or not.

        Parameters
        ----------
        message: HasActorMessage
            message

        Returns
        -------
        result_message
            result message contains if an actor exists or not.
        """

    @abstractmethod
    async def destroy_actor(self, message: DestroyActorMessage) -> ResultMessageType:
        """
        Destroy an actor.

        Parameters
        ----------
        message: DestroyActorMessage
            message to destroy an actor.

        Returns
        -------
        result_message
            result or error message.
        """

    @abstractmethod
    async def actor_ref(self, message: ActorRefMessage) -> ResultMessageType:
        """
        Get an actor's ref.

        Parameters
        ----------
        message: ActorRefMessage
            message to get an actor's ref.

        Returns
        -------
        result_message
            result or error message.
        """

    @abstractmethod
    async def send(self, message: SendMessage) -> ResultMessageType:
        """
        Send a message to some actor.

        Parameters
        ----------
        message: SendMessage
            Message to send.

        Returns
        -------
        result_message
            result or error message.
        """

    @abstractmethod
    async def tell(self, message: TellMessage) -> ResultMessageType:
        """
        Tell message to some actor.

        Parameters
        ----------
        message: TellMessage
            Message to tell.

        Returns
        -------
        result_message
            result or error message.
        """

    @abstractmethod
    async def cancel(self, message: CancelMessage) -> ResultMessageType:
        """
        Cancel message that sent

        Parameters
        ----------
        message: CancelMessage
            Cancel message.

        Returns
        -------
        result_message
            result or error message
        """

    async def forward(self, message: ForwardMessage) -> ResultMessageType:
        """
        Forward message

        Parameters
        ----------
        message: ForwardMessage
            Forward message.

        Returns
        -------
        result_message
            result or error message
        """
        # delegate the wrapped raw message to the destination address
        return await self.call(message.address, message.raw_message)

    def _sync_pool_config(self, actor_pool_config: ActorPoolConfig):
        # Replace this pool's config and refresh the router mapping while
        # keeping the global router consistent (remove -> update -> re-add).
        self._config = actor_pool_config
        # remove router from global one
        global_router = Router.get_instance()
        global_router.remove_router(self._router)  # type: ignore
        # update router
        self._router.set_mapping(actor_pool_config.external_to_internal_address_map)
        # update global router
        global_router.add_router(self._router)  # type: ignore

    async def handle_control_command(
        self, message: ControlMessage
    ) -> ResultMessageType:
        """
        Handle control command.

        Parameters
        ----------
        message: ControlMessage
            Control message.

        Returns
        -------
        result_message
            result or error message.
        """
        with _ErrorProcessor(
            self.external_address, message.message_id, protocol=message.protocol
        ) as processor:
            content: bool | ActorPoolConfig = True
            if message.control_message_type == ControlMessageType.stop:
                await self.stop()
            elif message.control_message_type == ControlMessageType.sync_config:
                self._sync_pool_config(message.content)
            elif message.control_message_type == ControlMessageType.get_config:
                if message.content == "main_pool_address":
                    # first process index is the main pool by convention here
                    main_process_index = self._config.get_process_indexes()[0]
                    content = self._config.get_pool_config(main_process_index)[
                        "external_address"
                    ][0]
                else:
                    content = self._config
            else:  # pragma: no cover
                raise TypeError(
                    f"Unable to handle control message "
                    f"with type {message.control_message_type}"
                )
            processor.result = ResultMessage(
                message.message_id, content, protocol=message.protocol
            )

        return processor.result

    async def _run_coro(self, message_id: bytes, coro: Coroutine):
        # Register the current task under the message id so `cancel` can
        # find and cancel it; always unregister when the coroutine finishes.
        self._process_messages[message_id] = asyncio.tasks.current_task()
        try:
            return await coro
        finally:
            self._process_messages.pop(message_id, None)

    async def _send_channel(
        self, result: _MessageBase, channel: Channel, resend_failure: bool = True
    ):
        # Send `result` over `channel`; on unexpected failure, try once to
        # send back a SendMessageFailed error instead (resend_failure guards
        # against infinite recursion on that second attempt).
        try:
            await channel.send(result)
        except (ChannelClosed, ConnectionResetError):
            # ignore these errors when the pool or channel is already closed
            if not self._stopped.is_set() and not channel.closed:
                raise
        except Exception as ex:
            logger.exception(
                "Error when sending message %s from %s to %s",
                result.message_id.hex(),
                channel.local_address,
                channel.dest_address,
            )
            if not resend_failure:  # pragma: no cover
                raise

            with _ErrorProcessor(
                self.external_address, result.message_id, result.protocol
            ) as processor:
                error_msg = (
                    f"Error when sending message {result.message_id.hex()}. "
                    f"Caused by {ex!r}. "
                )
                if isinstance(result, ErrorMessage):
                    # include the original failure so it is not lost
                    format_tb = "\n".join(traceback.format_tb(result.traceback))
                    error_msg += (
                        f"\nOriginal error: {result.error!r}"
                        f"Traceback: \n{format_tb}"
                    )
                else:
                    error_msg += "See server logs for more details"
                # raised inside _ErrorProcessor: captured as an ErrorMessage
                raise SendMessageFailed(error_msg) from None
            await self._send_channel(processor.result, channel, resend_failure=False)

    async def process_message(self, message: _MessageBase, channel: Channel):
        # Dispatch one incoming message to its registered handler and send
        # the (result or error) message back on the same channel.
        handler = self._message_handler[message.message_type]
        with _ErrorProcessor(
            self.external_address, message.message_id, message.protocol
        ) as processor:
            # use `%.500` to avoid print too long messages
            with debug_async_timeout(
                "process_message_timeout",
                "Process message %.500s of channel %s timeout.",
                message,
                channel,
            ):
                processor.result = await self._run_coro(
                    message.message_id, handler(self, message)
                )

        await self._send_channel(processor.result, channel)

    async def call(self, dest_address: str, message: _MessageBase) -> ResultMessageType:
        # Send a message to another pool through the shared ActorCaller.
        return await self._caller.call(self._router, dest_address, message)  # type: ignore

    @staticmethod
    def _parse_config(config: dict, kw: dict) -> dict:
        """Translate a raw creation `config` dict into constructor kwargs.

        Consumes all recognized keys from `config`; any leftover key is a
        programming error and raises TypeError.
        """
        actor_pool_config: ActorPoolConfig = config.pop("actor_pool_config")
        kw["config"] = actor_pool_config
        kw["process_index"] = process_index = config.pop("process_index")
        curr_pool_config = actor_pool_config.get_pool_config(process_index)
        kw["label"] = curr_pool_config["label"]
        external_addresses = curr_pool_config["external_address"]
        kw["external_address"] = external_addresses[0]
        kw["internal_address"] = curr_pool_config["internal_address"]
        kw["router"] = Router(
            external_addresses,
            gen_local_address(process_index),
            actor_pool_config.external_to_internal_address_map,
            comm_config=actor_pool_config.get_comm_config(),
            proxy_config=actor_pool_config.get_proxy_config(),
        )
        kw["env"] = curr_pool_config["env"]

        if config:  # pragma: no cover
            raise TypeError(
                f"Creating pool got unexpected " f'arguments: {",".join(config)}'
            )

        return kw

    @classmethod
    @abstractmethod
    async def create(cls, config: dict) -> "AbstractActorPool":
        """
        Create an actor pool.

        Parameters
        ----------
        config: dict
            configurations.

        Returns
        -------
        actor_pool:
            Actor pool.
        """

    async def start(self):
        """Start all communication servers; a stopped pool cannot restart."""
        if self._stopped.is_set():
            raise RuntimeError("pool has been stopped, cannot start again")
        start_servers = [server.start() for server in self._servers]
        await asyncio.gather(*start_servers)

    async def join(self, timeout: float | None = None):
        """Block until the pool is stopped, or at most `timeout` seconds."""
        wait_stopped = asyncio.create_task(self._stopped.wait())

        try:
            await asyncio.wait_for(wait_stopped, timeout=timeout)
        except (futures.TimeoutError, asyncio.TimeoutError):  # pragma: no cover
            # timeout is not an error for join; just stop waiting
            wait_stopped.cancel()

    async def stop(self):
        """Stop servers and callers; always marks the pool as stopped."""
        try:
            # clean global router
            router = Router.get_instance()
            if router is not None:
                router.remove_router(self._router)
            stop_tasks = []
            # stop all servers
            stop_tasks.extend([server.stop() for server in self._servers])
            # stop all clients
            stop_tasks.append(self._caller.stop())
            await asyncio.gather(*stop_tasks)

            self._servers = []
            if self._asyncio_task_timeout_detector_task:  # pragma: no cover
                self._asyncio_task_timeout_detector_task.cancel()
        finally:
            # even if cleanup failed, unblock join() and the receive loops
            self._stopped.set()

    async def handle_copy_to_buffers_message(self, message) -> ResultMessage:
        # Write each data chunk into the locally registered buffer at the
        # given offset; replies True on completion.
        for addr, uid, start, _len, data in message.content:
            buffer = BufferRef.get_buffer(BufferRef(addr, uid))
            buffer[start : start + _len] = data
        return ResultMessage(message_id=message.message_id, result=True)

    async def handle_copy_to_fileobjs_message(self, message) -> ResultMessage:
        # Append each data chunk to the locally registered file object.
        for addr, uid, data in message.content:
            file_obj = FileObjectRef.get_local_file_object(FileObjectRef(addr, uid))
            await file_obj.write(data)
        return ResultMessage(message_id=message.message_id, result=True)

    @property
    def stopped(self) -> bool:
        # True once stop() has run (or started to run).
        return self._stopped.is_set()

    async def _handle_ucx_meta_message(
        self, message: _MessageBase, channel: Channel
    ) -> bool:
        # Special-case the UCX "switch_to_copy_to" control message: receive
        # buffers directly over the UCX channel.  Returns True when the
        # message was handled here, False to fall through to normal dispatch.
        if (
            isinstance(message, ControlMessage)
            and message.message_type == MessageType.control
            and message.control_message_type == ControlMessageType.switch_to_copy_to
            and isinstance(channel, UCXChannel)
        ):
            with _ErrorProcessor(
                self.external_address, message.message_id, message.protocol
            ) as processor:
                # use `%.500` to avoid print too long messages
                with debug_async_timeout(
                    "process_message_timeout",
                    "Process message %.500s of channel %s timeout.",
                    message,
                    channel,
                ):
                    buffers = [
                        BufferRef.get_buffer(BufferRef(addr, uid))
                        for addr, uid in message.content
                    ]
                    await channel.recv_buffers(buffers)
                    processor.result = ResultMessage(
                        message_id=message.message_id, result=True
                    )
            # reply asynchronously; do not block the receive loop
            asyncio.create_task(self._send_channel(processor.result, channel))
            return True
        return False

    async def on_new_channel(self, channel: Channel):
        """Receive loop for one channel: read messages until the pool stops
        or the peer disconnects, dispatching each message as its own task."""
        try:
            while not self._stopped.is_set():
                try:
                    message = await channel.recv()
                except (EOFError, ConnectionError, BrokenPipeError) as e:
                    logger.debug(f"pool: close connection due to {e}")
                    # no data to read, check channel
                    try:
                        await channel.close()
                    except (ConnectionError, EOFError):
                        # close failed, ignore
                        pass
                    return
                if await self._handle_ucx_meta_message(message, channel):
                    continue
                # each message is processed concurrently in its own task
                asyncio.create_task(self.process_message(message, channel))
                # delete to release the reference of message
                del message
                await asyncio.sleep(0)
        finally:
            try:
                await channel.close()
            except:  # noqa: E722  # nosec  # pylint: disable=bare-except
                # ignore all error if fail to close at last
                pass

    async def __aenter__(self):
        await self.start()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.stop()
|
|
606
|
+
|
|
607
|
+
|
|
608
|
+
class ActorPoolBase(AbstractActorPool, metaclass=ABCMeta):
|
|
609
|
+
__slots__ = ()
|
|
610
|
+
|
|
611
|
+
    @implements(AbstractActorPool.create_actor)
    async def create_actor(self, message: CreateActorMessage) -> ResultMessageType:
        # Instantiate the requested actor class in this pool; any exception
        # raised in the block is converted into an ErrorMessage result.
        with _ErrorProcessor(
            self.external_address, message.message_id, message.protocol
        ) as processor:
            actor_id = message.actor_id
            if actor_id in self._actors:
                raise ActorAlreadyExist(
                    f"Actor {actor_id!r} already exist, cannot create"
                )

            actor = message.actor_cls(*message.args, **message.kwargs)
            actor.uid = actor_id
            actor.address = address = self.external_address
            # registered before __post_create__ runs
            self._actors[actor_id] = actor
            await self._run_coro(message.message_id, actor.__post_create__())

            proxies = self._router.get_proxies(address)
            result = ActorRef(address, actor_id, proxy_addresses=proxies)
            # ensemble result message
            processor.result = ResultMessage(
                message.message_id, result, protocol=message.protocol
            )
        return processor.result
|
|
635
|
+
|
|
636
|
+
@implements(AbstractActorPool.has_actor)
|
|
637
|
+
async def has_actor(self, message: HasActorMessage) -> ResultMessage:
|
|
638
|
+
result = ResultMessage(
|
|
639
|
+
message.message_id,
|
|
640
|
+
message.actor_ref.uid in self._actors,
|
|
641
|
+
protocol=message.protocol,
|
|
642
|
+
)
|
|
643
|
+
return result
|
|
644
|
+
|
|
645
|
+
    @implements(AbstractActorPool.destroy_actor)
    async def destroy_actor(self, message: DestroyActorMessage) -> ResultMessageType:
        # Run the actor's __pre_destroy__ hook, then unregister it; errors
        # are captured as an ErrorMessage by _ErrorProcessor.
        with _ErrorProcessor(
            self.external_address, message.message_id, message.protocol
        ) as processor:
            actor_id = message.actor_ref.uid
            try:
                actor = self._actors[actor_id]
            except KeyError:
                raise ActorNotExist(f"Actor {actor_id} does not exist")
            # the actor stays registered while __pre_destroy__ runs
            await self._run_coro(message.message_id, actor.__pre_destroy__())
            del self._actors[actor_id]

            processor.result = ResultMessage(
                message.message_id, actor_id, protocol=message.protocol
            )
        return processor.result
|
|
662
|
+
|
|
663
|
+
    @implements(AbstractActorPool.actor_ref)
    async def actor_ref(self, message: ActorRefMessage) -> ResultMessageType:
        # Resolve a uid into an ActorRef bound to this pool's external
        # address, including proxy addresses from the router if configured.
        with _ErrorProcessor(
            self.external_address, message.message_id, message.protocol
        ) as processor:
            actor_id = message.actor_ref.uid
            if actor_id not in self._actors:
                raise ActorNotExist(f"Actor {actor_id} does not exist")
            proxies = self._router.get_proxies(self.external_address)
            result = ResultMessage(
                message.message_id,
                ActorRef(self.external_address, actor_id, proxy_addresses=proxies),
                protocol=message.protocol,
            )
            processor.result = result
        return processor.result
|
|
679
|
+
|
|
680
|
+
    @implements(AbstractActorPool.send)
    async def send(self, message: SendMessage) -> ResultMessageType:
        # Deliver the message content to the local actor's __on_receive__
        # and wait for its result; the coroutine runs through _run_coro so
        # it is registered and can be cancelled by a CancelMessage.
        with _ErrorProcessor(
            self.external_address, message.message_id, message.protocol
        ) as processor, record_message_trace(message):
            actor_id = message.actor_ref.uid
            if actor_id not in self._actors:
                raise ActorNotExist(f"Actor {actor_id} does not exist")
            coro = self._actors[actor_id].__on_receive__(message.content)
            result = await self._run_coro(message.message_id, coro)
            processor.result = ResultMessage(
                message.message_id,
                result,
                protocol=message.protocol,
                profiling_context=message.profiling_context,
            )
        return processor.result
|
|
697
|
+
|
|
698
|
+
    @implements(AbstractActorPool.tell)
    async def tell(self, message: TellMessage) -> ResultMessageType:
        # Fire-and-forget delivery: schedule __on_receive__ as a task and
        # reply None without waiting for the actor to handle the message.
        with _ErrorProcessor(
            self.external_address, message.message_id, message.protocol
        ) as processor:
            actor_id = message.actor_ref.uid
            if actor_id not in self._actors:  # pragma: no cover
                raise ActorNotExist(f"Actor {actor_id} does not exist")
            call = self._actors[actor_id].__on_receive__(message.content)
            # asynchronously run, tell does not care about result
            asyncio.create_task(call)
            # yield control once so the created task gets a chance to start
            await asyncio.sleep(0)
            processor.result = ResultMessage(
                message.message_id,
                None,
                protocol=message.protocol,
                profiling_context=message.profiling_context,
            )
        return processor.result
|
|
717
|
+
|
|
718
|
+
@implements(AbstractActorPool.cancel)
|
|
719
|
+
async def cancel(self, message: CancelMessage) -> ResultMessageType:
|
|
720
|
+
with _ErrorProcessor(
|
|
721
|
+
self.external_address, message.message_id, message.protocol
|
|
722
|
+
) as processor:
|
|
723
|
+
future = self._process_messages.get(message.cancel_message_id)
|
|
724
|
+
if future is None or future.done(): # pragma: no cover
|
|
725
|
+
raise CannotCancelTask(
|
|
726
|
+
"Task not exists, maybe it is done or cancelled already"
|
|
727
|
+
)
|
|
728
|
+
future.cancel()
|
|
729
|
+
processor.result = ResultMessage(
|
|
730
|
+
message.message_id, True, protocol=message.protocol
|
|
731
|
+
)
|
|
732
|
+
return processor.result
|
|
733
|
+
|
|
734
|
+
@staticmethod
|
|
735
|
+
def _set_global_router(router: Router):
|
|
736
|
+
# be cautious about setting global router
|
|
737
|
+
# for instance, multiple main pool may be created in the same process
|
|
738
|
+
|
|
739
|
+
# get default router or create an empty one
|
|
740
|
+
default_router = Router.get_instance_or_empty()
|
|
741
|
+
Router.set_instance(default_router)
|
|
742
|
+
# append this router to global
|
|
743
|
+
default_router.add_router(router)
|
|
744
|
+
|
|
745
|
+
    @staticmethod
    def _update_stored_addresses(
        servers: list[Server],
        raw_addresses: list[str],
        actor_pool_config: ActorPoolConfig,
        kw: dict,
    ):
        """Reconcile configured external addresses with actually-bound ones.

        After servers are created, their bound addresses may differ from the
        raw addresses requested (e.g. when port 0 was used).  This updates the
        pool config and the ``kw`` dict (``servers``, ``internal_address``,
        ``external_address``, ``router``) in place for the pool constructor.
        """
        process_index = kw["process_index"]
        curr_pool_config = actor_pool_config.get_pool_config(process_index)
        external_addresses = curr_pool_config["external_address"]
        external_address_set = set(external_addresses)

        kw["servers"] = servers

        # map each external raw address to the address its server really bound
        new_external_addresses = [
            server.address
            for server, raw_address in zip(servers, raw_addresses)
            if raw_address in external_address_set
        ]

        if external_address_set != set(new_external_addresses):
            external_addresses = new_external_addresses
            actor_pool_config.reset_pool_external_address(
                process_index, external_addresses
            )
            # re-read: reset may normalize the stored addresses
            external_addresses = curr_pool_config["external_address"]

            logger.debug(
                "External address of process index %s updated to %s",
                process_index,
                external_addresses[0],
            )
            if kw["internal_address"] == kw["external_address"]:
                # internal address may be the same as external address in Windows
                kw["internal_address"] = external_addresses[0]
            kw["external_address"] = external_addresses[0]

        kw["router"] = Router(
            external_addresses,
            gen_local_address(process_index),
            actor_pool_config.external_to_internal_address_map,
            comm_config=actor_pool_config.get_comm_config(),
            proxy_config=actor_pool_config.get_proxy_config(),
        )
|
|
789
|
+
|
|
790
|
+
@classmethod
|
|
791
|
+
async def _create_servers(
|
|
792
|
+
cls, addresses: list[str], channel_handler: Callable, config: dict
|
|
793
|
+
):
|
|
794
|
+
assert len(set(addresses)) == len(addresses)
|
|
795
|
+
# create servers
|
|
796
|
+
create_server_tasks = []
|
|
797
|
+
for addr in addresses:
|
|
798
|
+
server_type = get_server_type(addr)
|
|
799
|
+
extra_config = server_type.parse_config(config)
|
|
800
|
+
server_config = dict(address=addr, handle_channel=channel_handler)
|
|
801
|
+
server_config.update(extra_config)
|
|
802
|
+
task = asyncio.create_task(server_type.create(server_config))
|
|
803
|
+
create_server_tasks.append(task)
|
|
804
|
+
|
|
805
|
+
await asyncio.gather(*create_server_tasks)
|
|
806
|
+
return [f.result() for f in create_server_tasks]
|
|
807
|
+
|
|
808
|
+
    @classmethod
    @implements(AbstractActorPool.create)
    async def create(cls, config: dict) -> "ActorPoolType":
        """Create and initialize an actor pool from a config dict.

        Parses the config, imports any predefined modules, creates the
        communication servers, registers the router globally, and finally
        instantiates the pool class with the assembled keyword arguments.
        """
        config = config.copy()
        kw: dict[str, Any] = dict()
        cls._parse_config(config, kw)
        process_index: int = kw["process_index"]
        actor_pool_config = kw["config"]  # type: ActorPoolConfig
        cur_pool_config = actor_pool_config.get_pool_config(process_index)
        external_addresses = cur_pool_config["external_address"]
        internal_address = kw["internal_address"]

        # import predefined modules
        modules = cur_pool_config["modules"] or []
        for mod in modules:
            __import__(mod, globals(), locals(), [])
        # make sure all lazy imports loaded
        with _disable_log_temporally():
            TypeDispatcher.reload_all_lazy_handlers()

        # closure over ``pool``: the servers are created before the pool
        # instance exists, so the handler binds lazily to the name assigned
        # at the bottom of this function
        def handle_channel(channel):
            return pool.on_new_channel(channel)

        # create servers
        server_addresses = list(external_addresses)
        if internal_address:
            server_addresses.append(internal_address)
        server_addresses.append(gen_local_address(process_index))
        # dedupe; sorted() keeps the order deterministic
        server_addresses = sorted(set(server_addresses))
        servers = await cls._create_servers(
            server_addresses, handle_channel, actor_pool_config.get_comm_config()
        )
        cls._update_stored_addresses(servers, server_addresses, actor_pool_config, kw)

        # set default router
        # actor context would be able to use exact client
        cls._set_global_router(kw["router"])

        # create pool
        pool = cls(**kw)
        return pool  # type: ignore
|
|
849
|
+
|
|
850
|
+
|
|
851
|
+
# Type variables for APIs that create/return concrete pool subclasses.
ActorPoolType = TypeVar("ActorPoolType", bound=AbstractActorPool)
MainActorPoolType = TypeVar("MainActorPoolType", bound="MainActorPoolBase")
# Handle type representing a spawned sub pool process.
SubProcessHandle = asyncio.subprocess.Process
|
|
854
|
+
|
|
855
|
+
|
|
856
|
+
class SubActorPoolBase(ActorPoolBase):
    """Actor pool running in a sub process, supervised by a main pool.

    Keeps the main pool informed about actors created/destroyed locally and
    optionally watches the main pool's process, stopping itself if the main
    pool dies.
    """

    __slots__ = ("_main_address", "_watch_main_pool_task")
    _watch_main_pool_task: Optional[asyncio.Task]

    def __init__(
        self,
        process_index: int,
        label: str,
        external_address: str,
        internal_address: str,
        env: dict,
        router: Router,
        config: ActorPoolConfig,
        servers: list[Server],
        main_address: str,
        main_pool_pid: Optional[int],
    ):
        super().__init__(
            process_index,
            label,
            external_address,
            internal_address,
            env,
            router,
            config,
            servers,
        )
        self._main_address = main_address
        if main_pool_pid:
            # fate-sharing: watch the main pool and stop when it disappears
            self._watch_main_pool_task = asyncio.create_task(
                self._watch_main_pool(main_pool_pid)
            )
        else:
            self._watch_main_pool_task = None

    async def _watch_main_pool(self, main_pool_pid: int):
        """Poll the main pool process and stop this pool once it dies."""
        main_process = psutil.Process(main_pool_pid)
        while not self.stopped:
            try:
                # status() raises NoSuchProcess when the process is gone;
                # run in a thread so the poll never blocks the event loop
                await asyncio.to_thread(main_process.status)
                await asyncio.sleep(0.1)
                continue
            except (psutil.NoSuchProcess, ProcessLookupError, asyncio.CancelledError):
                # main pool died
                break

        if not self.stopped:
            await self.stop()

    async def notify_main_pool_to_destroy(
        self, message: DestroyActorMessage
    ):  # pragma: no cover
        """Tell the main pool that a locally hosted actor was destroyed."""
        await self.call(self._main_address, message)

    async def notify_main_pool_to_create(self, message: CreateActorMessage):
        """Register a locally created actor with the main pool's bookkeeping."""
        reg_message = ControlMessage(
            new_message_id(),
            self.external_address,
            ControlMessageType.add_sub_pool_actor,
            (self.external_address, message.allocate_strategy, message),
        )
        await self.call(self._main_address, reg_message)

    @implements(AbstractActorPool.create_actor)
    async def create_actor(self, message: CreateActorMessage) -> ResultMessageType:
        """Create a local actor; sync with the main pool for direct requests."""
        result = await super().create_actor(message)
        if not message.from_main:
            await self.notify_main_pool_to_create(message)
        return result

    @implements(AbstractActorPool.actor_ref)
    async def actor_ref(self, message: ActorRefMessage) -> ResultMessageType:
        """Resolve an actor ref locally, falling back to the main pool."""
        result = await super().actor_ref(message)
        if isinstance(result, ErrorMessage):
            # need a new message id to call main actor
            main_message = ActorRefMessage(
                new_message_id(),
                create_actor_ref(self._main_address, message.actor_ref.uid),
            )
            result = await self.call(self._main_address, main_message)
            # rewrite to message_id of the original request
            result.message_id = message.message_id
        return result

    @implements(AbstractActorPool.destroy_actor)
    async def destroy_actor(self, message: DestroyActorMessage) -> ResultMessageType:
        """Destroy a local actor; sync with the main pool for direct requests."""
        result = await super().destroy_actor(message)
        if isinstance(result, ResultMessage) and not message.from_main:
            # sync back to main actor pool
            await self.notify_main_pool_to_destroy(message)
        return result

    @implements(AbstractActorPool.handle_control_command)
    async def handle_control_command(
        self, message: ControlMessage
    ) -> ResultMessageType:
        """Handle a control command; sync_config also refreshes the main address."""
        if message.control_message_type == ControlMessageType.sync_config:
            # the sender of a sync_config is the (possibly relocated) main pool
            self._main_address = message.address
        return await super().handle_control_command(message)

    @staticmethod
    def _parse_config(config: dict, kw: dict) -> dict:
        """Extract sub-pool-specific options, then delegate to the base parser."""
        main_pool_pid = config.pop("main_pool_pid", None)
        kw = AbstractActorPool._parse_config(config, kw)
        pool_config: ActorPoolConfig = kw["config"]
        # by convention the first process index belongs to the main pool
        main_process_index = pool_config.get_process_indexes()[0]
        kw["main_address"] = pool_config.get_pool_config(main_process_index)[
            "external_address"
        ][0]
        kw["main_pool_pid"] = main_pool_pid
        return kw

    async def stop(self):
        """Stop the pool and tear down the main-pool watcher task, if any."""
        await super().stop()
        if self._watch_main_pool_task:
            self._watch_main_pool_task.cancel()
            # the watcher catches CancelledError internally and exits cleanly
            await self._watch_main_pool_task
|
|
973
|
+
|
|
974
|
+
|
|
975
|
+
class MainActorPoolBase(ActorPoolBase):
    """Main actor pool: allocates actors across sub pools and supervises them.

    Tracks per-address actor allocations, forwards messages to the owning
    sub pool, monitors sub pool processes, and optionally auto-recovers
    crashed processes/actors.
    """

    __slots__ = (
        "_allocated_actors",
        "sub_actor_pool_manager",
        "_auto_recover",
        "_monitor_task",
        "_on_process_down",
        "_on_process_recover",
        "_recover_events",
        "_allocation_lock",
        "sub_processes",
    )

    def __init__(
        self,
        process_index: int,
        label: str,
        external_address: str,
        internal_address: str,
        env: dict,
        router: Router,
        config: ActorPoolConfig,
        servers: list[Server],
        auto_recover: str | bool = "actor",
        on_process_down: Callable[[MainActorPoolType, str], None] | None = None,
        on_process_recover: Callable[[MainActorPoolType, str], None] | None = None,
    ):
        super().__init__(
            process_index,
            label,
            external_address,
            internal_address,
            env,
            router,
            config,
            servers,
        )

        # auto recovering
        self._auto_recover = auto_recover
        self._monitor_task: Optional[asyncio.Task] = None
        self._on_process_down = on_process_down
        self._on_process_recover = on_process_recover
        # per-address events signalled when a pool has been recovered
        self._recover_events: dict[str, asyncio.Event] = dict()

        # states: external address -> {actor_ref: (allocate_strategy, message)}
        self._allocated_actors: allocated_type = {
            addr: dict() for addr in self._config.get_external_addresses()
        }
        # threading.Lock (not asyncio): guards the synchronous
        # allocate-and-place critical section in create_actor
        self._allocation_lock = threading.Lock()

        # external address -> process handle of the hosting sub pool
        self.sub_processes: dict[str, SubProcessHandle] = dict()
|
|
1027
|
+
|
|
1028
|
+
_process_index_gen = itertools.count()
|
|
1029
|
+
|
|
1030
|
+
@classmethod
|
|
1031
|
+
def process_index_gen(cls, address):
|
|
1032
|
+
# make sure different processes does not share process indexes
|
|
1033
|
+
pid = os.getpid()
|
|
1034
|
+
for idx in cls._process_index_gen:
|
|
1035
|
+
yield pid << 16 + idx
|
|
1036
|
+
|
|
1037
|
+
    @property
    def _sub_processes(self):
        # read-only alias of ``sub_processes``, kept for internal callers
        return self.sub_processes
|
|
1040
|
+
|
|
1041
|
+
    @implements(AbstractActorPool.create_actor)
    async def create_actor(self, message: CreateActorMessage) -> ResultMessageType:
        """Allocate an address for the new actor and create it there.

        The allocation strategy picks a pool address; creation then happens
        either locally (main pool) or via a rewritten message sent to the
        chosen sub pool.  A ``None`` placeholder entry reserves the slot
        during the async creation and is always removed afterwards.
        """
        with _ErrorProcessor(
            address=self.external_address,
            message_id=message.message_id,
            protocol=message.protocol,
        ) as processor:
            allocate_strategy = message.allocate_strategy
            with self._allocation_lock:
                # get allocated address according to corresponding strategy
                address = allocate_strategy.get_allocated_address(
                    self._config, self._allocated_actors
                )
                # set placeholder to make sure this label is occupied
                self._allocated_actors[address][None] = (allocate_strategy, message)
            if address == self.external_address:
                # creating actor on main actor pool
                result = await super().create_actor(message)
                if isinstance(result, ResultMessage):
                    self._allocated_actors[self.external_address][result.result] = (
                        allocate_strategy,
                        message,
                    )
                processor.result = result
            else:
                # creating actor on sub actor pool
                # rewrite allocate strategy to AddressSpecified
                new_allocate_strategy = AddressSpecified(address)
                new_create_actor_message = CreateActorMessage(
                    message.message_id,
                    message.actor_cls,
                    message.actor_id,
                    message.args,
                    message.kwargs,
                    allocate_strategy=new_allocate_strategy,
                    from_main=True,
                    protocol=message.protocol,
                    message_trace=message.message_trace,
                )
                result = await self.call(address, new_create_actor_message)
                if isinstance(result, ResultMessage):
                    self._allocated_actors[address][result.result] = (
                        allocate_strategy,
                        new_create_actor_message,
                    )
                processor.result = result

            # revert placeholder
            self._allocated_actors[address].pop(None, None)

        return processor.result
|
|
1092
|
+
|
|
1093
|
+
@implements(AbstractActorPool.has_actor)
|
|
1094
|
+
async def has_actor(self, message: HasActorMessage) -> ResultMessage:
|
|
1095
|
+
actor_ref = message.actor_ref
|
|
1096
|
+
# lookup allocated
|
|
1097
|
+
for address, item in self._allocated_actors.items():
|
|
1098
|
+
ref = create_actor_ref(address, to_binary(actor_ref.uid))
|
|
1099
|
+
if ref in item:
|
|
1100
|
+
return ResultMessage(
|
|
1101
|
+
message.message_id, True, protocol=message.protocol
|
|
1102
|
+
)
|
|
1103
|
+
|
|
1104
|
+
return ResultMessage(message.message_id, False, protocol=message.protocol)
|
|
1105
|
+
|
|
1106
|
+
    @implements(AbstractActorPool.destroy_actor)
    async def destroy_actor(self, message: DestroyActorMessage) -> ResultMessageType:
        """Destroy an actor wherever it lives, updating the allocation table.

        Resolves the actor's real location first; destroys locally when it
        is hosted by the main pool, otherwise forwards a ``from_main``
        destroy message to the owning sub pool.
        """
        actor_ref_message = ActorRefMessage(
            message.message_id, message.actor_ref, protocol=message.protocol
        )
        result = await self.actor_ref(actor_ref_message)
        if not isinstance(result, ResultMessage):
            # lookup failed: propagate error message unchanged
            return result
        real_actor_ref = result.result
        if real_actor_ref.address == self.external_address:
            result = await super().destroy_actor(message)
            if result.message_type == MessageType.error:
                return result
            del self._allocated_actors[self.external_address][real_actor_ref]
            return ResultMessage(
                message.message_id, real_actor_ref.uid, protocol=message.protocol
            )
        # remove allocated actor ref
        self._allocated_actors[real_actor_ref.address].pop(real_actor_ref, None)
        new_destroy_message = DestroyActorMessage(
            message.message_id,
            real_actor_ref,
            from_main=True,
            protocol=message.protocol,
        )
        return await self.call(real_actor_ref.address, new_destroy_message)
|
|
1132
|
+
|
|
1133
|
+
    @implements(AbstractActorPool.send)
    async def send(self, message: SendMessage) -> ResultMessageType:
        """Send a message to an actor, forwarding to its owning pool if remote."""
        if message.actor_ref.uid in self._actors:
            # hosted locally on the main pool
            return await super().send(message)
        # resolve the actor's real location
        actor_ref_message = ActorRefMessage(
            message.message_id, message.actor_ref, protocol=message.protocol
        )
        result = await self.actor_ref(actor_ref_message)
        if not isinstance(result, ResultMessage):
            # lookup failed: propagate error message unchanged
            return result
        actor_ref = result.result
        # rebuild the message against the resolved ref and forward it
        new_send_message = SendMessage(
            message.message_id,
            actor_ref,
            message.content,
            protocol=message.protocol,
            message_trace=message.message_trace,
        )
        return await self.call(actor_ref.address, new_send_message)
|
|
1152
|
+
|
|
1153
|
+
@implements(AbstractActorPool.tell)
|
|
1154
|
+
async def tell(self, message: TellMessage) -> ResultMessageType:
|
|
1155
|
+
if message.actor_ref.uid in self._actors:
|
|
1156
|
+
return await super().tell(message)
|
|
1157
|
+
actor_ref_message = ActorRefMessage(
|
|
1158
|
+
message.message_id, message.actor_ref, protocol=message.protocol
|
|
1159
|
+
)
|
|
1160
|
+
result = await self.actor_ref(actor_ref_message)
|
|
1161
|
+
if not isinstance(result, ResultMessage):
|
|
1162
|
+
return result
|
|
1163
|
+
actor_ref = result.result
|
|
1164
|
+
new_tell_message = TellMessage(
|
|
1165
|
+
message.message_id,
|
|
1166
|
+
actor_ref,
|
|
1167
|
+
message.content,
|
|
1168
|
+
protocol=message.protocol,
|
|
1169
|
+
message_trace=message.message_trace,
|
|
1170
|
+
)
|
|
1171
|
+
return await self.call(actor_ref.address, new_tell_message)
|
|
1172
|
+
|
|
1173
|
+
    @implements(AbstractActorPool.actor_ref)
    async def actor_ref(self, message: ActorRefMessage) -> ResultMessageType:
        """Resolve an actor ref to its actual hosting address.

        Checks local actors first, then the cluster-wide allocation table;
        returns an ``ActorNotExist`` error message (via ``_ErrorProcessor``)
        when the actor is unknown.
        """
        actor_ref = message.actor_ref
        # uids are normalized to bytes for lookups
        actor_ref.uid = to_binary(actor_ref.uid)
        if actor_ref.address == self.external_address and actor_ref.uid in self._actors:
            actor_ref.proxy_addresses = self._router.get_proxies(actor_ref.address)
            return ResultMessage(
                message.message_id, actor_ref, protocol=message.protocol
            )

        # lookup allocated
        for address, item in self._allocated_actors.items():
            ref = create_actor_ref(address, actor_ref.uid)
            if ref in item:
                ref.proxy_addresses = self._router.get_proxies(ref.address)
                return ResultMessage(message.message_id, ref, protocol=message.protocol)

        with _ErrorProcessor(
            self.external_address, message.message_id, protocol=message.protocol
        ) as processor:
            raise ActorNotExist(
                f"Actor {actor_ref.uid} does not exist in {actor_ref.address}"
            )

        return processor.result
|
|
1198
|
+
|
|
1199
|
+
@implements(AbstractActorPool.cancel)
|
|
1200
|
+
async def cancel(self, message: CancelMessage) -> ResultMessageType:
|
|
1201
|
+
if message.address == self.external_address:
|
|
1202
|
+
# local message
|
|
1203
|
+
return await super().cancel(message)
|
|
1204
|
+
# redirect to sub pool
|
|
1205
|
+
return await self.call(message.address, message)
|
|
1206
|
+
|
|
1207
|
+
    @implements(AbstractActorPool.handle_control_command)
    async def handle_control_command(
        self, message: ControlMessage
    ) -> ResultMessageType:
        """Dispatch control commands addressed to this pool or a sub pool.

        Handles: sync_config (broadcast to all sub pools plus self), stop
        (shut down one sub pool), wait_pool_recovered (block until recovery
        completes), add_sub_pool_actor (bookkeeping for actors created
        directly on a sub pool); anything else is forwarded to its address.
        """
        with _ErrorProcessor(
            self.external_address, message.message_id, message.protocol
        ) as processor:
            if message.address == self.external_address:
                if message.control_message_type == ControlMessageType.sync_config:
                    # sync config, need to notify all sub pools
                    tasks = []
                    for addr in self.sub_processes:
                        control_message = ControlMessage(
                            new_message_id(),
                            message.address,
                            message.control_message_type,
                            message.content,
                            protocol=message.protocol,
                            message_trace=message.message_trace,
                        )
                        tasks.append(
                            asyncio.create_task(self.call(addr, control_message))
                        )
                    # call super
                    task = asyncio.create_task(super().handle_control_command(message))
                    tasks.append(task)
                    await asyncio.gather(*tasks)
                    # the caller gets the local (super) result, not the broadcasts
                    processor.result = await task
                else:
                    processor.result = await super().handle_control_command(message)
            elif message.control_message_type == ControlMessageType.stop:
                timeout, force = (
                    message.content if message.content is not None else (None, False)
                )
                await self.stop_sub_pool(
                    message.address,
                    self.sub_processes[message.address],
                    timeout=timeout,
                    force=force,
                )
                processor.result = ResultMessage(
                    message.message_id, True, protocol=message.protocol
                )
            elif message.control_message_type == ControlMessageType.wait_pool_recovered:
                if self._auto_recover and message.address not in self._recover_events:
                    self._recover_events[message.address] = asyncio.Event()

                event = self._recover_events.get(message.address, None)
                if event is not None:
                    # set by monitor_sub_pools once the sub pool is recovered
                    await event.wait()
                processor.result = ResultMessage(
                    message.message_id, True, protocol=message.protocol
                )
            elif message.control_message_type == ControlMessageType.add_sub_pool_actor:
                # a sub pool created an actor directly; record the allocation
                address, allocate_strategy, create_message = message.content
                create_message.from_main = True
                ref = create_actor_ref(address, to_binary(create_message.actor_id))
                self._allocated_actors[address][ref] = (
                    allocate_strategy,
                    create_message,
                )
                processor.result = ResultMessage(
                    message.message_id, True, protocol=message.protocol
                )
            else:
                # not for us: forward to the addressed pool
                processor.result = await self.call(message.address, message)
        return processor.result
|
|
1274
|
+
|
|
1275
|
+
@staticmethod
|
|
1276
|
+
def _parse_config(config: dict, kw: dict) -> dict:
|
|
1277
|
+
kw["auto_recover"] = config.pop("auto_recover", "actor")
|
|
1278
|
+
kw["on_process_down"] = config.pop("on_process_down", None)
|
|
1279
|
+
kw["on_process_recover"] = config.pop("on_process_recover", None)
|
|
1280
|
+
kw = AbstractActorPool._parse_config(config, kw)
|
|
1281
|
+
return kw
|
|
1282
|
+
|
|
1283
|
+
    @classmethod
    @implements(AbstractActorPool.create)
    async def create(cls, config: dict) -> MainActorPoolType:
        """Create the main pool plus all configured sub pools.

        Starts sub pools concurrently, waits for them to be ready, updates
        the config with their actually-bound addresses, creates the main
        pool, attaches sub process handles, and broadcasts a sync_config
        when any address changed.
        """
        config = config.copy()
        actor_pool_config: ActorPoolConfig = config.get("actor_pool_config")  # type: ignore
        if "process_index" not in config:
            # by convention the first index belongs to the main pool
            config["process_index"] = actor_pool_config.get_process_indexes()[0]
        curr_process_index = config.get("process_index")
        old_config_addresses = set(actor_pool_config.get_external_addresses())

        tasks = []
        subpool_process_idxes = []
        # create sub actor pools
        n_sub_pool = actor_pool_config.n_pool - 1
        if n_sub_pool > 0:
            process_indexes = actor_pool_config.get_process_indexes()
            for process_index in process_indexes:
                if process_index == curr_process_index:
                    continue
                create_pool_task = asyncio.create_task(
                    cls.start_sub_pool(actor_pool_config, process_index)
                )
                # yield so the task actually starts before the next is queued
                await asyncio.sleep(0)
                # await create_pool_task
                tasks.append(create_pool_task)
                subpool_process_idxes.append(process_index)

        processes, ext_addresses = await cls.wait_sub_pools_ready(tasks)
        if ext_addresses:
            # record the addresses the sub pools actually bound to
            for process_index, ext_address in zip(subpool_process_idxes, ext_addresses):
                actor_pool_config.reset_pool_external_address(
                    process_index, ext_address
                )

        # create main actor pool
        pool: MainActorPoolType = await super().create(config)
        addresses = actor_pool_config.get_external_addresses()[1:]

        assert len(addresses) == len(
            processes
        ), f"addresses {addresses}, processes {processes}"
        for addr, proc in zip(addresses, processes):
            pool.attach_sub_process(addr, proc)

        new_config_addresses = set(actor_pool_config.get_external_addresses())
        if old_config_addresses != new_config_addresses:
            # addresses changed: push the updated config to every pool
            control_message = ControlMessage(
                message_id=new_message_id(),
                address=pool.external_address,
                control_message_type=ControlMessageType.sync_config,
                content=actor_pool_config,
            )
            await pool.handle_control_command(control_message)

        return pool
|
|
1338
|
+
|
|
1339
|
+
    async def start_monitor(self):
        """Start (idempotently) the background task watching sub pool health.

        Returns the monitor task, or ``None`` when there is nothing to watch.
        """
        # Only start monitor if there are sub processes to monitor
        # This prevents hanging when n_process=0
        if self._monitor_task is None and self.sub_processes:
            self._monitor_task = asyncio.create_task(self.monitor_sub_pools())
        return self._monitor_task
|
|
1345
|
+
|
|
1346
|
+
    @implements(AbstractActorPool.stop)
    async def stop(self):
        """Stop the main pool: detach router, halt monitoring, stop sub pools."""
        global_router = Router.get_instance()
        if global_router is not None:
            global_router.remove_router(self._router)

        # turn off auto recover to avoid errors
        self._auto_recover = False
        self._stopped.set()
        if self._monitor_task and not self._monitor_task.done():
            # Cancel the monitor task to ensure it exits immediately
            self._monitor_task.cancel()
            try:
                await self._monitor_task
            except asyncio.CancelledError:
                pass  # Expected when cancelling the task
            self._monitor_task = None
        await self.stop_sub_pools()
        await super().stop()
|
|
1365
|
+
|
|
1366
|
+
    @classmethod
    @abstractmethod
    async def start_sub_pool(
        cls,
        actor_pool_config: ActorPoolConfig,
        process_index: int,
        start_python: str | None = None,
    ):
        """Start a sub actor pool for ``process_index``.

        ``start_python`` presumably overrides the Python executable used to
        spawn the sub process — confirm against concrete backends.
        """
|
|
1375
|
+
|
|
1376
|
+
    @classmethod
    @abstractmethod
    async def wait_sub_pools_ready(cls, create_pool_tasks: list[asyncio.Task]):
        """Wait all sub pools ready.

        Callers unpack the return as ``processes, ext_addresses`` (see
        ``create``), so implementations return process handles and the
        sub pools' actually-bound external addresses.
        """
|
|
1380
|
+
|
|
1381
|
+
    def attach_sub_process(self, external_address: str, process: SubProcessHandle):
        """Register the handle of a spawned sub pool under its external address."""
        self.sub_processes[external_address] = process
|
|
1383
|
+
|
|
1384
|
+
async def stop_sub_pools(self):
|
|
1385
|
+
to_stop_processes: dict[str, SubProcessHandle] = dict() # type: ignore
|
|
1386
|
+
for address, process in self.sub_processes.items():
|
|
1387
|
+
if not await self.is_sub_pool_alive(process):
|
|
1388
|
+
continue
|
|
1389
|
+
to_stop_processes[address] = process
|
|
1390
|
+
|
|
1391
|
+
tasks = []
|
|
1392
|
+
for address, process in to_stop_processes.items():
|
|
1393
|
+
tasks.append(self.stop_sub_pool(address, process))
|
|
1394
|
+
await asyncio.gather(*tasks)
|
|
1395
|
+
|
|
1396
|
+
    async def stop_sub_pool(
        self,
        address: str,
        process: SubProcessHandle,
        timeout: float | None = None,
        force: bool = False,
    ):
        """Stop one sub pool, gracefully first, then by killing the process.

        Sends a stop control message and waits up to ``timeout`` seconds
        (default 2s); on timeout or connection failure the process is killed
        forcibly.  ``force=True`` skips the graceful path entirely.
        """
        if force:
            await self.kill_sub_pool(process, force=True)
            return

        stop_message = ControlMessage(
            new_message_id(),
            address,
            ControlMessageType.stop,
            None,
            protocol=DEFAULT_PROTOCOL,
        )
        try:
            if timeout is None:
                # Use a short timeout for graceful shutdown to avoid hanging
                timeout = 2.0

            call = asyncio.create_task(self.call(address, stop_message))
            try:
                await asyncio.wait_for(call, timeout)
            except (futures.TimeoutError, asyncio.TimeoutError):
                # graceful shutdown did not finish in time
                force = True
        except (ConnectionError, ServerClosed):
            # process dead maybe, ignore it
            force = True
        # kill process
        await self.kill_sub_pool(process, force=force)
|
|
1429
|
+
|
|
1430
|
+
    @abstractmethod
    async def kill_sub_pool(self, process: SubProcessHandle, force: bool = False):
        """Kill a sub actor pool process; ``force`` skips any graceful step."""
|
|
1433
|
+
|
|
1434
|
+
    @abstractmethod
    async def is_sub_pool_alive(self, process: SubProcessHandle):
        """
        Check whether sub pool process is alive

        Parameters
        ----------
        process : SubProcessHandle
            sub pool process handle

        Returns
        -------
        bool
            ``True`` when the process is still running
        """
|
|
1446
|
+
|
|
1447
|
+
    @abstractmethod
    def recover_sub_pool(self, address):
        """Recover a sub actor pool at ``address`` after its process died."""
|
|
1450
|
+
|
|
1451
|
+
    def process_sub_pool_lost(self, address: str):
        """Adjust allocation bookkeeping when a sub pool process dies."""
        if self._auto_recover in (False, "process"):
            # process down, when not auto_recover
            # or only recover process, remove all created actors
            self._allocated_actors[address] = dict()
|
|
1456
|
+
|
|
1457
|
+
    async def monitor_sub_pools(self):
        """Background loop watching sub pool health every half second.

        For each dead sub pool: fires ``_on_process_down``, adjusts
        allocations, optionally auto-recovers (then fires
        ``_on_process_recover``), and signals any waiting recover events.
        Exits cleanly on cancellation or when the pool is stopped.
        """
        try:
            while not self._stopped.is_set():
                # Copy sub_processes to avoid changes during recover.
                for address, process in list(self.sub_processes.items()):
                    try:
                        # capture before the liveness check so an event
                        # registered mid-iteration is not set prematurely
                        recover_events_discovered = address in self._recover_events
                        if not await self.is_sub_pool_alive(
                            process
                        ):  # pragma: no cover
                            if self._on_process_down is not None:
                                self._on_process_down(self, address)
                            self.process_sub_pool_lost(address)
                            if self._auto_recover:
                                await self.recover_sub_pool(address)
                                if self._on_process_recover is not None:
                                    self._on_process_recover(self, address)
                        if recover_events_discovered:
                            event = self._recover_events.pop(address)
                            event.set()
                    except asyncio.CancelledError:
                        raise
                    except RuntimeError as ex:  # pragma: no cover
                        if "cannot schedule new futures" not in str(ex):
                            # to silence log when process exit, otherwise it
                            # will raise "RuntimeError: cannot schedule new futures
                            # after interpreter shutdown".
                            logger.exception("Monitor sub pool %s failed", address)
                    except Exception:
                        # log the exception instead of stop monitoring the
                        # sub pool silently.
                        logger.exception("Monitor sub pool %s failed", address)

                # check every half second
                await asyncio.sleep(0.5)
        except asyncio.CancelledError:  # pragma: no cover
            # cancelled
            return
|
|
1495
|
+
|
|
1496
|
+
@classmethod
@abstractmethod
def get_external_addresses(
    cls,
    address: str,
    n_process: int | None = None,
    ports: list[int] | None = None,
    schemes: list[Optional[str]] | None = None,
):
    """Returns external addresses for n pool processes.

    Derives from the main pool's base ``address`` one external address for
    the main pool plus one per sub pool (``n_process + 1`` in total, index 0
    being the main pool — this is how ``create_actor_pool`` consumes the
    result).

    Parameters
    ----------
    address : str
        Base external address of the main pool.
    n_process : int | None
        Number of sub pool processes.
    ports : list[int] | None
        Optional ports to use for the generated addresses.
    schemes : list[Optional[str]] | None
        Optional per-address schemes.  NOTE(review): exact semantics of
        ``ports``/``schemes`` are backend-specific — confirm in concrete
        subclasses.
    """
|
|
1506
|
+
|
|
1507
|
+
@classmethod
@abstractmethod
def gen_internal_address(
    cls, process_index: int, external_address: str | None = None
) -> str | None:
    """Returns internal address for pool of specified process index.

    Parameters
    ----------
    process_index : int
        Index of the pool process the address is generated for.
    external_address : str | None
        External address of the same pool, which implementations may use to
        derive the internal one.

    Returns
    -------
    str | None
        The internal address, or ``None`` when the backend provides no
        separate internal channel.
    """
|
|
1513
|
+
|
|
1514
|
+
|
|
1515
|
+
async def create_actor_pool(
    address: str,
    *,
    pool_cls: Type[MainActorPoolType] | None = None,
    n_process: int | None = None,
    labels: list[str] | None = None,
    ports: list[int] | None = None,
    envs: list[dict] | None = None,
    external_address_schemes: list[Optional[str]] | None = None,
    enable_internal_addresses: list[bool] | None = None,
    auto_recover: str | bool = "actor",
    modules: list[str] | None = None,
    suspend_sigint: bool | None = None,
    use_uvloop: str | bool = "auto",
    logging_conf: dict | None = None,
    proxy_conf: dict | None = None,
    on_process_down: Callable[[MainActorPoolType, str], None] | None = None,
    on_process_recover: Callable[[MainActorPoolType, str], None] | None = None,
    extra_conf: dict | None = None,
    **kwargs,
) -> MainActorPoolType:
    """Create and start a main actor pool with ``n_process`` sub pools.

    Parameters
    ----------
    address : str
        Base external address the main pool binds to.
    pool_cls : Type[MainActorPoolType] | None
        Concrete main-pool class used to generate addresses and create the
        pool.  Required despite the ``None`` default.
    n_process : int | None
        Number of sub pools; defaults to ``multiprocessing.cpu_count()``.
    labels : list[str] | None
        Optional labels, one for the main pool plus one per sub pool
        (``n_process + 1`` entries).
    ports : list[int] | None
        Optional ports forwarded to ``pool_cls.get_external_addresses``.
    envs : list[dict] | None
        Optional per-sub-pool environment dicts (``n_process`` entries;
        sub pools only).
    external_address_schemes : list[Optional[str]] | None
        Optional address schemes, ``n_process + 1`` entries.
    enable_internal_addresses : list[bool] | None
        Optional flags (``n_process + 1`` entries) controlling whether an
        internal address is generated per pool; defaults to all ``True``.
    auto_recover : str | bool
        One of ``"actor"``, ``"process"``, ``True`` (alias of ``"actor"``)
        or ``False``.
    modules, suspend_sigint, use_uvloop, logging_conf
        Forwarded into each pool's configuration; ``use_uvloop="auto"``
        enables uvloop only when it is importable.
    proxy_conf, extra_conf : dict | None
        Proxy and communication configuration, stored on the pool config.
    on_process_down, on_process_recover
        Callbacks invoked by the pool on sub-process death / recovery.
    kwargs
        Extra options kept in the pool config (e.g. ``metrics``).

    Returns
    -------
    MainActorPoolType
        The created and started main actor pool.

    Raises
    ------
    ValueError
        If a sized argument has the wrong length, ``auto_recover`` is
        invalid, or ``pool_cls`` is missing.
    """
    if n_process is None:
        n_process = multiprocessing.cpu_count()
    # "Per pool" lists cover the main pool plus every sub pool, hence the
    # `n_process + 1` sizes below; `envs` applies to sub pools only.
    if labels and len(labels) != n_process + 1:
        raise ValueError(
            f"`labels` should be of size {n_process + 1}, got {len(labels)}"
        )
    if envs and len(envs) != n_process:
        raise ValueError(f"`envs` should be of size {n_process}, got {len(envs)}")
    if external_address_schemes and len(external_address_schemes) != n_process + 1:
        raise ValueError(
            f"`external_address_schemes` should be of size {n_process + 1}, "
            f"got {len(external_address_schemes)}"
        )
    if enable_internal_addresses and len(enable_internal_addresses) != n_process + 1:
        raise ValueError(
            f"`enable_internal_addresses` should be of size {n_process + 1}, "
            f"got {len(enable_internal_addresses)}"
        )
    elif not enable_internal_addresses:
        enable_internal_addresses = [True] * (n_process + 1)
    if auto_recover is True:
        auto_recover = "actor"
    if auto_recover not in ("actor", "process", False):
        raise ValueError(
            f'`auto_recover` should be one of "actor", "process", '
            f"True or False, got {auto_recover}"
        )
    if use_uvloop == "auto":
        try:
            import uvloop  # noqa: F401 # pylint: disable=unused-variable

            use_uvloop = True
        except ImportError:
            use_uvloop = False

    # Explicit check instead of `assert pool_cls is not None`: asserts are
    # stripped under `python -O`, which would turn a missing `pool_cls`
    # into an opaque AttributeError on the next line.
    if pool_cls is None:
        raise ValueError("`pool_cls` is required to create an actor pool")
    external_addresses = pool_cls.get_external_addresses(
        address, n_process=n_process, ports=ports, schemes=external_address_schemes
    )
    actor_pool_config = ActorPoolConfig()
    actor_pool_config.add_metric_configs(kwargs.get("metrics", {}))
    # add proxy config
    actor_pool_config.add_proxy_config(proxy_conf)
    # add main config; index 0 of every per-pool list is the main pool
    process_index_gen = pool_cls.process_index_gen(address)
    main_process_index = next(process_index_gen)
    main_internal_address = (
        pool_cls.gen_internal_address(main_process_index, external_addresses[0])
        if enable_internal_addresses[0]
        else None
    )
    actor_pool_config.add_pool_conf(
        main_process_index,
        labels[0] if labels else None,
        main_internal_address,
        external_addresses[0],
        modules=modules,
        suspend_sigint=suspend_sigint,
        use_uvloop=use_uvloop,  # type: ignore
        logging_conf=logging_conf,
        kwargs=kwargs,
    )
    # add sub configs
    for i in range(n_process):
        sub_process_index = next(process_index_gen)
        internal_address = (
            pool_cls.gen_internal_address(sub_process_index, external_addresses[i + 1])
            if enable_internal_addresses[i + 1]
            else None
        )
        actor_pool_config.add_pool_conf(
            sub_process_index,
            labels[i + 1] if labels else None,
            internal_address,
            external_addresses[i + 1],
            env=envs[i] if envs else None,
            modules=modules,
            suspend_sigint=suspend_sigint,
            use_uvloop=use_uvloop,  # type: ignore
            logging_conf=logging_conf,
            kwargs=kwargs,
        )
    actor_pool_config.add_comm_config(extra_conf)

    pool: MainActorPoolType = await pool_cls.create(
        {
            "actor_pool_config": actor_pool_config,
            "process_index": main_process_index,
            "auto_recover": auto_recover,
            "on_process_down": on_process_down,
            "on_process_recover": on_process_recover,
        }
    )
    await pool.start()
    return pool
|