xoscar-0.4.0-cp312-cp312-macosx_10_9_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of xoscar might be problematic.

Files changed (82)
  1. xoscar/__init__.py +60 -0
  2. xoscar/_utils.cpython-312-darwin.so +0 -0
  3. xoscar/_utils.pxd +36 -0
  4. xoscar/_utils.pyx +241 -0
  5. xoscar/_version.py +693 -0
  6. xoscar/aio/__init__.py +16 -0
  7. xoscar/aio/base.py +86 -0
  8. xoscar/aio/file.py +59 -0
  9. xoscar/aio/lru.py +228 -0
  10. xoscar/aio/parallelism.py +39 -0
  11. xoscar/api.py +493 -0
  12. xoscar/backend.py +67 -0
  13. xoscar/backends/__init__.py +14 -0
  14. xoscar/backends/allocate_strategy.py +160 -0
  15. xoscar/backends/communication/__init__.py +30 -0
  16. xoscar/backends/communication/base.py +315 -0
  17. xoscar/backends/communication/core.py +69 -0
  18. xoscar/backends/communication/dummy.py +242 -0
  19. xoscar/backends/communication/errors.py +20 -0
  20. xoscar/backends/communication/socket.py +414 -0
  21. xoscar/backends/communication/ucx.py +531 -0
  22. xoscar/backends/communication/utils.py +97 -0
  23. xoscar/backends/config.py +145 -0
  24. xoscar/backends/context.py +404 -0
  25. xoscar/backends/core.py +193 -0
  26. xoscar/backends/indigen/__init__.py +16 -0
  27. xoscar/backends/indigen/backend.py +51 -0
  28. xoscar/backends/indigen/driver.py +26 -0
  29. xoscar/backends/indigen/pool.py +469 -0
  30. xoscar/backends/message.cpython-312-darwin.so +0 -0
  31. xoscar/backends/message.pyi +239 -0
  32. xoscar/backends/message.pyx +599 -0
  33. xoscar/backends/pool.py +1596 -0
  34. xoscar/backends/router.py +207 -0
  35. xoscar/backends/test/__init__.py +16 -0
  36. xoscar/backends/test/backend.py +38 -0
  37. xoscar/backends/test/pool.py +208 -0
  38. xoscar/batch.py +256 -0
  39. xoscar/collective/__init__.py +27 -0
  40. xoscar/collective/common.py +102 -0
  41. xoscar/collective/core.py +737 -0
  42. xoscar/collective/process_group.py +687 -0
  43. xoscar/collective/utils.py +41 -0
  44. xoscar/collective/xoscar_pygloo.cpython-312-darwin.so +0 -0
  45. xoscar/collective/xoscar_pygloo.pyi +239 -0
  46. xoscar/constants.py +21 -0
  47. xoscar/context.cpython-312-darwin.so +0 -0
  48. xoscar/context.pxd +21 -0
  49. xoscar/context.pyx +368 -0
  50. xoscar/core.cpython-312-darwin.so +0 -0
  51. xoscar/core.pxd +50 -0
  52. xoscar/core.pyx +658 -0
  53. xoscar/debug.py +188 -0
  54. xoscar/driver.py +42 -0
  55. xoscar/errors.py +63 -0
  56. xoscar/libcpp.pxd +31 -0
  57. xoscar/metrics/__init__.py +21 -0
  58. xoscar/metrics/api.py +288 -0
  59. xoscar/metrics/backends/__init__.py +13 -0
  60. xoscar/metrics/backends/console/__init__.py +13 -0
  61. xoscar/metrics/backends/console/console_metric.py +82 -0
  62. xoscar/metrics/backends/metric.py +149 -0
  63. xoscar/metrics/backends/prometheus/__init__.py +13 -0
  64. xoscar/metrics/backends/prometheus/prometheus_metric.py +70 -0
  65. xoscar/nvutils.py +717 -0
  66. xoscar/profiling.py +260 -0
  67. xoscar/serialization/__init__.py +20 -0
  68. xoscar/serialization/aio.py +138 -0
  69. xoscar/serialization/core.cpython-312-darwin.so +0 -0
  70. xoscar/serialization/core.pxd +28 -0
  71. xoscar/serialization/core.pyi +57 -0
  72. xoscar/serialization/core.pyx +944 -0
  73. xoscar/serialization/cuda.py +111 -0
  74. xoscar/serialization/exception.py +48 -0
  75. xoscar/serialization/numpy.py +82 -0
  76. xoscar/serialization/pyfury.py +37 -0
  77. xoscar/serialization/scipy.py +72 -0
  78. xoscar/utils.py +517 -0
  79. xoscar-0.4.0.dist-info/METADATA +223 -0
  80. xoscar-0.4.0.dist-info/RECORD +82 -0
  81. xoscar-0.4.0.dist-info/WHEEL +5 -0
  82. xoscar-0.4.0.dist-info/top_level.txt +2 -0
@@ -0,0 +1,1596 @@
1
+ # Copyright 2022-2023 XProbe Inc.
2
+ # derived from copyright 1999-2021 Alibaba Group Holding Ltd.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from __future__ import annotations
17
+
18
+ import asyncio
19
+ import concurrent.futures as futures
20
+ import contextlib
21
+ import itertools
22
+ import logging
23
+ import multiprocessing
24
+ import os
25
+ import threading
26
+ import traceback
27
+ from abc import ABC, ABCMeta, abstractmethod
28
+ from typing import Any, Callable, Coroutine, Optional, Type, TypeVar
29
+
30
+ import psutil
31
+
32
+ from .._utils import TypeDispatcher, create_actor_ref, to_binary
33
+ from ..api import Actor
34
+ from ..core import ActorRef, BufferRef, FileObjectRef, register_local_pool
35
+ from ..debug import debug_async_timeout, record_message_trace
36
+ from ..errors import (
37
+ ActorAlreadyExist,
38
+ ActorNotExist,
39
+ CannotCancelTask,
40
+ SendMessageFailed,
41
+ ServerClosed,
42
+ )
43
+ from ..metrics import init_metrics
44
+ from ..utils import implements, is_zero_ip, register_asyncio_task_timeout_detector
45
+ from .allocate_strategy import AddressSpecified, allocated_type
46
+ from .communication import (
47
+ Channel,
48
+ Server,
49
+ UCXChannel,
50
+ gen_local_address,
51
+ get_server_type,
52
+ )
53
+ from .communication.errors import ChannelClosed
54
+ from .config import ActorPoolConfig
55
+ from .core import ActorCaller, ResultMessageType
56
+ from .message import (
57
+ DEFAULT_PROTOCOL,
58
+ ActorRefMessage,
59
+ CancelMessage,
60
+ ControlMessage,
61
+ ControlMessageType,
62
+ CreateActorMessage,
63
+ DestroyActorMessage,
64
+ ErrorMessage,
65
+ HasActorMessage,
66
+ MessageType,
67
+ ResultMessage,
68
+ SendMessage,
69
+ TellMessage,
70
+ _MessageBase,
71
+ new_message_id,
72
+ )
73
+ from .router import Router
74
+
75
+ logger = logging.getLogger(__name__)
76
+
77
+
78
+ @contextlib.contextmanager
79
+ def _disable_log_temporally():
80
+ if os.getenv("CUDA_VISIBLE_DEVICES") == "-1":
81
+ # disable logging when CUDA_VISIBLE_DEVICES == -1
82
+ # many log messages coming from ptxcompiler may distract users
83
+ try:
84
+ logging.disable(level=logging.ERROR)
85
+ yield
86
+ finally:
87
+ logging.disable(level=logging.NOTSET)
88
+ else:
89
+ yield
90
+
91
+
92
+ class _ErrorProcessor:
93
+ def __init__(self, address: str, message_id: bytes, protocol):
94
+ self._address = address
95
+ self._message_id = message_id
96
+ self._protocol = protocol
97
+ self.result = None
98
+
99
+ def __enter__(self):
100
+ return self
101
+
102
+ def __exit__(self, exc_type, exc_val, exc_tb):
103
+ if self.result is None:
104
+ self.result = ErrorMessage(
105
+ self._message_id,
106
+ self._address,
107
+ os.getpid(),
108
+ exc_type,
109
+ exc_val,
110
+ exc_tb,
111
+ protocol=self._protocol,
112
+ )
113
+ return True
114
+
115
+
116
+ def _register_message_handler(pool_type: Type["AbstractActorPool"]):
117
+ pool_type._message_handler = dict()
118
+ for message_type, handler in [
119
+ (MessageType.create_actor, pool_type.create_actor),
120
+ (MessageType.destroy_actor, pool_type.destroy_actor),
121
+ (MessageType.has_actor, pool_type.has_actor),
122
+ (MessageType.actor_ref, pool_type.actor_ref),
123
+ (MessageType.send, pool_type.send),
124
+ (MessageType.tell, pool_type.tell),
125
+ (MessageType.cancel, pool_type.cancel),
126
+ (MessageType.control, pool_type.handle_control_command),
127
+ (MessageType.copy_to_buffers, pool_type.handle_copy_to_buffers_message),
128
+ (MessageType.copy_to_fileobjs, pool_type.handle_copy_to_fileobjs_message),
129
+ ]:
130
+ pool_type._message_handler[message_type] = handler # type: ignore
131
+ return pool_type
132
+
133
+
134
+ class AbstractActorPool(ABC):
135
+ __slots__ = (
136
+ "process_index",
137
+ "label",
138
+ "external_address",
139
+ "internal_address",
140
+ "env",
141
+ "_servers",
142
+ "_router",
143
+ "_config",
144
+ "_stopped",
145
+ "_actors",
146
+ "_caller",
147
+ "_process_messages",
148
+ "_asyncio_task_timeout_detector_task",
149
+ )
150
+
151
+ _message_handler: dict[MessageType, Callable]
152
+ _process_messages: dict[bytes, asyncio.Future | asyncio.Task | None]
153
+
154
+ def __init__(
155
+ self,
156
+ process_index: int,
157
+ label: str,
158
+ external_address: str,
159
+ internal_address: str,
160
+ env: dict,
161
+ router: Router,
162
+ config: ActorPoolConfig,
163
+ servers: list[Server],
164
+ ):
165
+ # register local pool for local actor lookup.
166
+ # The pool is weakrefed, so we don't need to unregister it.
167
+ if not is_zero_ip(external_address):
168
+ # Only call register_local_pool when we listen on a non-zero ip (an all-zero ip is a wildcard address),
169
+ # to avoid being mistaken for another remote service listening on the same port.
170
+ register_local_pool(external_address, self)
171
+ self.process_index = process_index
172
+ self.label = label
173
+ self.external_address = external_address
174
+ self.internal_address = internal_address
175
+ self.env = env
176
+ self._router = router
177
+ self._config = config
178
+ self._servers = servers
179
+
180
+ self._stopped = asyncio.Event()
181
+
182
+ # states
183
+ # actor id -> actor
184
+ self._actors: dict[bytes, Actor] = dict()
185
+ # message id -> future
186
+ self._process_messages = dict()
187
+
188
+ # manage async actor callers
189
+ self._caller = ActorCaller()
190
+ self._asyncio_task_timeout_detector_task = (
191
+ register_asyncio_task_timeout_detector()
192
+ )
193
+ # init metrics
194
+ metric_configs = self._config.get_metric_configs()
195
+ metric_backend = metric_configs.get("backend")
196
+ init_metrics(metric_backend, config=metric_configs.get(metric_backend))
197
+
198
+ @property
199
+ def router(self):
200
+ return self._router
201
+
202
+ @abstractmethod
203
+ async def create_actor(self, message: CreateActorMessage) -> ResultMessageType:
204
+ """
205
+ Create an actor.
206
+
207
+ Parameters
208
+ ----------
209
+ message: CreateActorMessage
210
+ message to create an actor.
211
+
212
+ Returns
213
+ -------
214
+ result_message
215
+ result or error message.
216
+ """
217
+
218
+ @abstractmethod
219
+ async def has_actor(self, message: HasActorMessage) -> ResultMessage:
220
+ """
221
+ Check if an actor exists or not.
222
+
223
+ Parameters
224
+ ----------
225
+ message: HasActorMessage
226
+ message
227
+
228
+ Returns
229
+ -------
230
+ result_message
231
+ result message contains if an actor exists or not.
232
+ """
233
+
234
+ @abstractmethod
235
+ async def destroy_actor(self, message: DestroyActorMessage) -> ResultMessageType:
236
+ """
237
+ Destroy an actor.
238
+
239
+ Parameters
240
+ ----------
241
+ message: DestroyActorMessage
242
+ message to destroy an actor.
243
+
244
+ Returns
245
+ -------
246
+ result_message
247
+ result or error message.
248
+ """
249
+
250
+ @abstractmethod
251
+ async def actor_ref(self, message: ActorRefMessage) -> ResultMessageType:
252
+ """
253
+ Get an actor's ref.
254
+
255
+ Parameters
256
+ ----------
257
+ message: ActorRefMessage
258
+ message to get an actor's ref.
259
+
260
+ Returns
261
+ -------
262
+ result_message
263
+ result or error message.
264
+ """
265
+
266
+ @abstractmethod
267
+ async def send(self, message: SendMessage) -> ResultMessageType:
268
+ """
269
+ Send a message to some actor.
270
+
271
+ Parameters
272
+ ----------
273
+ message: SendMessage
274
+ Message to send.
275
+
276
+ Returns
277
+ -------
278
+ result_message
279
+ result or error message.
280
+ """
281
+
282
+ @abstractmethod
283
+ async def tell(self, message: TellMessage) -> ResultMessageType:
284
+ """
285
+ Tell message to some actor.
286
+
287
+ Parameters
288
+ ----------
289
+ message: TellMessage
290
+ Message to tell.
291
+
292
+ Returns
293
+ -------
294
+ result_message
295
+ result or error message.
296
+ """
297
+
298
+ @abstractmethod
299
+ async def cancel(self, message: CancelMessage) -> ResultMessageType:
300
+ """
301
+ Cancel a previously sent message.
302
+
303
+ Parameters
304
+ ----------
305
+ message: CancelMessage
306
+ Cancel message.
307
+
308
+ Returns
309
+ -------
310
+ result_message
311
+ result or error message
312
+ """
313
+
314
+ def _sync_pool_config(self, actor_pool_config: ActorPoolConfig):
315
+ self._config = actor_pool_config
316
+ # remove router from global one
317
+ global_router = Router.get_instance()
318
+ global_router.remove_router(self._router) # type: ignore
319
+ # update router
320
+ self._router.set_mapping(actor_pool_config.external_to_internal_address_map)
321
+ # update global router
322
+ global_router.add_router(self._router) # type: ignore
323
+
324
+ async def handle_control_command(
325
+ self, message: ControlMessage
326
+ ) -> ResultMessageType:
327
+ """
328
+ Handle control command.
329
+
330
+ Parameters
331
+ ----------
332
+ message: ControlMessage
333
+ Control message.
334
+
335
+ Returns
336
+ -------
337
+ result_message
338
+ result or error message.
339
+ """
340
+ with _ErrorProcessor(
341
+ self.external_address, message.message_id, protocol=message.protocol
342
+ ) as processor:
343
+ content: bool | ActorPoolConfig = True
344
+ if message.control_message_type == ControlMessageType.stop:
345
+ await self.stop()
346
+ elif message.control_message_type == ControlMessageType.sync_config:
347
+ self._sync_pool_config(message.content)
348
+ elif message.control_message_type == ControlMessageType.get_config:
349
+ if message.content == "main_pool_address":
350
+ main_process_index = self._config.get_process_indexes()[0]
351
+ content = self._config.get_pool_config(main_process_index)[
352
+ "external_address"
353
+ ][0]
354
+ else:
355
+ content = self._config
356
+ else: # pragma: no cover
357
+ raise TypeError(
358
+ f"Unable to handle control message "
359
+ f"with type {message.control_message_type}"
360
+ )
361
+ processor.result = ResultMessage(
362
+ message.message_id, content, protocol=message.protocol
363
+ )
364
+
365
+ return processor.result
366
+
367
+ async def _run_coro(self, message_id: bytes, coro: Coroutine):
368
+ self._process_messages[message_id] = asyncio.tasks.current_task()
369
+ try:
370
+ return await coro
371
+ finally:
372
+ self._process_messages.pop(message_id, None)
373
+
374
+ async def _send_channel(
375
+ self, result: _MessageBase, channel: Channel, resend_failure: bool = True
376
+ ):
377
+ try:
378
+ await channel.send(result)
379
+ except (ChannelClosed, ConnectionResetError):
380
+ if not self._stopped.is_set():
381
+ raise
382
+ except Exception as ex:
383
+ logger.exception(
384
+ "Error when sending message %s from %s to %s",
385
+ result.message_id.hex(),
386
+ channel.local_address,
387
+ channel.dest_address,
388
+ )
389
+ if not resend_failure: # pragma: no cover
390
+ raise
391
+
392
+ with _ErrorProcessor(
393
+ self.external_address, result.message_id, result.protocol
394
+ ) as processor:
395
+ error_msg = (
396
+ f"Error when sending message {result.message_id.hex()}. "
397
+ f"Caused by {ex!r}. "
398
+ )
399
+ if isinstance(result, ErrorMessage):
400
+ format_tb = "\n".join(traceback.format_tb(result.traceback))
401
+ error_msg += (
402
+ f"\nOriginal error: {result.error!r}"
403
+ f"Traceback: \n{format_tb}"
404
+ )
405
+ else:
406
+ error_msg += "See server logs for more details"
407
+ raise SendMessageFailed(error_msg) from None
408
+ await self._send_channel(processor.result, channel, resend_failure=False)
409
+
410
+ async def process_message(self, message: _MessageBase, channel: Channel):
411
+ handler = self._message_handler[message.message_type]
412
+ with _ErrorProcessor(
413
+ self.external_address, message.message_id, message.protocol
414
+ ) as processor:
415
+ # use `%.500` to avoid printing overly long messages
416
+ with debug_async_timeout(
417
+ "process_message_timeout",
418
+ "Process message %.500s of channel %s timeout.",
419
+ message,
420
+ channel,
421
+ ):
422
+ processor.result = await self._run_coro(
423
+ message.message_id, handler(self, message)
424
+ )
425
+
426
+ await self._send_channel(processor.result, channel)
427
+
428
+ async def call(self, dest_address: str, message: _MessageBase) -> ResultMessageType:
429
+ return await self._caller.call(self._router, dest_address, message) # type: ignore
430
+
431
+ @staticmethod
432
+ def _parse_config(config: dict, kw: dict) -> dict:
433
+ actor_pool_config: ActorPoolConfig = config.pop("actor_pool_config")
434
+ kw["config"] = actor_pool_config
435
+ kw["process_index"] = process_index = config.pop("process_index")
436
+ curr_pool_config = actor_pool_config.get_pool_config(process_index)
437
+ kw["label"] = curr_pool_config["label"]
438
+ external_addresses = curr_pool_config["external_address"]
439
+ kw["external_address"] = external_addresses[0]
440
+ kw["internal_address"] = curr_pool_config["internal_address"]
441
+ kw["router"] = Router(
442
+ external_addresses,
443
+ gen_local_address(process_index),
444
+ actor_pool_config.external_to_internal_address_map,
445
+ comm_config=actor_pool_config.get_comm_config(),
446
+ )
447
+ kw["env"] = curr_pool_config["env"]
448
+
449
+ if config: # pragma: no cover
450
+ raise TypeError(
451
+ f"Creating pool got unexpected " f'arguments: {",".join(config)}'
452
+ )
453
+
454
+ return kw
455
+
456
+ @classmethod
457
+ @abstractmethod
458
+ async def create(cls, config: dict) -> "AbstractActorPool":
459
+ """
460
+ Create an actor pool.
461
+
462
+ Parameters
463
+ ----------
464
+ config: dict
465
+ configurations.
466
+
467
+ Returns
468
+ -------
469
+ actor_pool:
470
+ Actor pool.
471
+ """
472
+
473
+ async def start(self):
474
+ if self._stopped.is_set():
475
+ raise RuntimeError("pool has been stopped, cannot start again")
476
+ start_servers = [server.start() for server in self._servers]
477
+ await asyncio.gather(*start_servers)
478
+
479
+ async def join(self, timeout: float | None = None):
480
+ wait_stopped = asyncio.create_task(self._stopped.wait())
481
+
482
+ try:
483
+ await asyncio.wait_for(wait_stopped, timeout=timeout)
484
+ except (futures.TimeoutError, asyncio.TimeoutError): # pragma: no cover
485
+ wait_stopped.cancel()
486
+
487
+ async def stop(self):
488
+ try:
489
+ # clean global router
490
+ router = Router.get_instance()
491
+ if router is not None:
492
+ router.remove_router(self._router)
493
+ stop_tasks = []
494
+ # stop all servers
495
+ stop_tasks.extend([server.stop() for server in self._servers])
496
+ # stop all clients
497
+ stop_tasks.append(self._caller.stop())
498
+ await asyncio.gather(*stop_tasks)
499
+
500
+ self._servers = []
501
+ if self._asyncio_task_timeout_detector_task: # pragma: no cover
502
+ self._asyncio_task_timeout_detector_task.cancel()
503
+ finally:
504
+ self._stopped.set()
505
+
506
+ async def handle_copy_to_buffers_message(self, message) -> ResultMessage:
507
+ for addr, uid, start, _len, data in message.content:
508
+ buffer = BufferRef.get_buffer(BufferRef(addr, uid))
509
+ buffer[start : start + _len] = data
510
+ return ResultMessage(message_id=message.message_id, result=True)
511
+
512
+ async def handle_copy_to_fileobjs_message(self, message) -> ResultMessage:
513
+ for addr, uid, data in message.content:
514
+ file_obj = FileObjectRef.get_local_file_object(FileObjectRef(addr, uid))
515
+ await file_obj.write(data)
516
+ return ResultMessage(message_id=message.message_id, result=True)
517
+
518
+ @property
519
+ def stopped(self) -> bool:
520
+ return self._stopped.is_set()
521
+
522
+ async def _handle_ucx_meta_message(
523
+ self, message: _MessageBase, channel: Channel
524
+ ) -> bool:
525
+ if (
526
+ isinstance(message, ControlMessage)
527
+ and message.message_type == MessageType.control
528
+ and message.control_message_type == ControlMessageType.switch_to_copy_to
529
+ and isinstance(channel, UCXChannel)
530
+ ):
531
+ with _ErrorProcessor(
532
+ self.external_address, message.message_id, message.protocol
533
+ ) as processor:
534
+ # use `%.500` to avoid printing overly long messages
535
+ with debug_async_timeout(
536
+ "process_message_timeout",
537
+ "Process message %.500s of channel %s timeout.",
538
+ message,
539
+ channel,
540
+ ):
541
+ buffers = [
542
+ BufferRef.get_buffer(BufferRef(addr, uid))
543
+ for addr, uid in message.content
544
+ ]
545
+ await channel.recv_buffers(buffers)
546
+ processor.result = ResultMessage(
547
+ message_id=message.message_id, result=True
548
+ )
549
+ asyncio.create_task(self._send_channel(processor.result, channel))
550
+ return True
551
+ return False
552
+
553
+ async def on_new_channel(self, channel: Channel):
554
+ while not self._stopped.is_set():
555
+ try:
556
+ message = await channel.recv()
557
+ except EOFError:
558
+ # no data to read, check channel
559
+ try:
560
+ await channel.close()
561
+ except (ConnectionError, EOFError):
562
+ # close failed, ignore
563
+ pass
564
+ return
565
+ if await self._handle_ucx_meta_message(message, channel):
566
+ continue
567
+ asyncio.create_task(self.process_message(message, channel))
568
+ # delete to release the reference of message
569
+ del message
570
+ await asyncio.sleep(0)
571
+
572
+ async def __aenter__(self):
573
+ await self.start()
574
+ return self
575
+
576
+ async def __aexit__(self, exc_type, exc_val, exc_tb):
577
+ await self.stop()
578
+
579
+
580
+ class ActorPoolBase(AbstractActorPool, metaclass=ABCMeta):
581
+ __slots__ = ()
582
+
583
+ @implements(AbstractActorPool.create_actor)
584
+ async def create_actor(self, message: CreateActorMessage) -> ResultMessageType:
585
+ with _ErrorProcessor(
586
+ self.external_address, message.message_id, message.protocol
587
+ ) as processor:
588
+ actor_id = message.actor_id
589
+ if actor_id in self._actors:
590
+ raise ActorAlreadyExist(
591
+ f"Actor {actor_id!r} already exist, cannot create"
592
+ )
593
+
594
+ actor = message.actor_cls(*message.args, **message.kwargs)
595
+ actor.uid = actor_id
596
+ actor.address = address = self.external_address
597
+ self._actors[actor_id] = actor
598
+ await self._run_coro(message.message_id, actor.__post_create__())
599
+
600
+ result = ActorRef(address, actor_id)
601
+ # assemble the result message
602
+ processor.result = ResultMessage(
603
+ message.message_id, result, protocol=message.protocol
604
+ )
605
+ return processor.result
606
+
607
+ @implements(AbstractActorPool.has_actor)
608
+ async def has_actor(self, message: HasActorMessage) -> ResultMessage:
609
+ result = ResultMessage(
610
+ message.message_id,
611
+ message.actor_ref.uid in self._actors,
612
+ protocol=message.protocol,
613
+ )
614
+ return result
615
+
616
+ @implements(AbstractActorPool.destroy_actor)
617
+ async def destroy_actor(self, message: DestroyActorMessage) -> ResultMessageType:
618
+ with _ErrorProcessor(
619
+ self.external_address, message.message_id, message.protocol
620
+ ) as processor:
621
+ actor_id = message.actor_ref.uid
622
+ try:
623
+ actor = self._actors[actor_id]
624
+ except KeyError:
625
+ raise ActorNotExist(f"Actor {actor_id} does not exist")
626
+ await self._run_coro(message.message_id, actor.__pre_destroy__())
627
+ del self._actors[actor_id]
628
+
629
+ processor.result = ResultMessage(
630
+ message.message_id, actor_id, protocol=message.protocol
631
+ )
632
+ return processor.result
633
+
634
+ @implements(AbstractActorPool.actor_ref)
635
+ async def actor_ref(self, message: ActorRefMessage) -> ResultMessageType:
636
+ with _ErrorProcessor(
637
+ self.external_address, message.message_id, message.protocol
638
+ ) as processor:
639
+ actor_id = message.actor_ref.uid
640
+ if actor_id not in self._actors:
641
+ raise ActorNotExist(f"Actor {actor_id} does not exist")
642
+ result = ResultMessage(
643
+ message.message_id,
644
+ ActorRef(self.external_address, actor_id),
645
+ protocol=message.protocol,
646
+ )
647
+ processor.result = result
648
+ return processor.result
649
+
650
+ @implements(AbstractActorPool.send)
651
+ async def send(self, message: SendMessage) -> ResultMessageType:
652
+ with _ErrorProcessor(
653
+ self.external_address, message.message_id, message.protocol
654
+ ) as processor, record_message_trace(message):
655
+ actor_id = message.actor_ref.uid
656
+ if actor_id not in self._actors:
657
+ raise ActorNotExist(f"Actor {actor_id} does not exist")
658
+ coro = self._actors[actor_id].__on_receive__(message.content)
659
+ result = await self._run_coro(message.message_id, coro)
660
+ processor.result = ResultMessage(
661
+ message.message_id,
662
+ result,
663
+ protocol=message.protocol,
664
+ profiling_context=message.profiling_context,
665
+ )
666
+ return processor.result
667
+
668
+ @implements(AbstractActorPool.tell)
669
+ async def tell(self, message: TellMessage) -> ResultMessageType:
670
+ with _ErrorProcessor(
671
+ self.external_address, message.message_id, message.protocol
672
+ ) as processor:
673
+ actor_id = message.actor_ref.uid
674
+ if actor_id not in self._actors: # pragma: no cover
675
+ raise ActorNotExist(f"Actor {actor_id} does not exist")
676
+ call = self._actors[actor_id].__on_receive__(message.content)
677
+ # asynchronously run, tell does not care about result
678
+ asyncio.create_task(call)
679
+ await asyncio.sleep(0)
680
+ processor.result = ResultMessage(
681
+ message.message_id,
682
+ None,
683
+ protocol=message.protocol,
684
+ profiling_context=message.profiling_context,
685
+ )
686
+ return processor.result
687
+
688
+ @implements(AbstractActorPool.cancel)
689
+ async def cancel(self, message: CancelMessage) -> ResultMessageType:
690
+ with _ErrorProcessor(
691
+ self.external_address, message.message_id, message.protocol
692
+ ) as processor:
693
+ future = self._process_messages.get(message.cancel_message_id)
694
+ if future is None or future.done(): # pragma: no cover
695
+ raise CannotCancelTask(
696
+ "Task not exists, maybe it is done or cancelled already"
697
+ )
698
+ future.cancel()
699
+ processor.result = ResultMessage(
700
+ message.message_id, True, protocol=message.protocol
701
+ )
702
+ return processor.result
703
+
704
+ @staticmethod
705
+ def _set_global_router(router: Router):
706
+ # be cautious about setting global router
707
+ # for instance, multiple main pools may be created in the same process
708
+
709
+ # get default router or create an empty one
710
+ default_router = Router.get_instance_or_empty()
711
+ Router.set_instance(default_router)
712
+ # append this router to global
713
+ default_router.add_router(router)
714
+
715
+ @staticmethod
716
+ def _update_stored_addresses(
717
+ servers: list[Server],
718
+ raw_addresses: list[str],
719
+ actor_pool_config: ActorPoolConfig,
720
+ kw: dict,
721
+ ):
722
+ process_index = kw["process_index"]
723
+ curr_pool_config = actor_pool_config.get_pool_config(process_index)
724
+ external_addresses = curr_pool_config["external_address"]
725
+ external_address_set = set(external_addresses)
726
+
727
+ kw["servers"] = servers
728
+
729
+ new_external_addresses = [
730
+ server.address
731
+ for server, raw_address in zip(servers, raw_addresses)
732
+ if raw_address in external_address_set
733
+ ]
734
+
735
+ if external_address_set != set(new_external_addresses):
736
+ external_addresses = new_external_addresses
737
+ actor_pool_config.reset_pool_external_address(
738
+ process_index, external_addresses
739
+ )
740
+ external_addresses = curr_pool_config["external_address"]
741
+
742
+ logger.debug(
743
+ "External address of process index %s updated to %s",
744
+ process_index,
745
+ external_addresses[0],
746
+ )
747
+ if kw["internal_address"] == kw["external_address"]:
748
+ # internal address may be the same as external address in Windows
749
+ kw["internal_address"] = external_addresses[0]
750
+ kw["external_address"] = external_addresses[0]
751
+
752
+ kw["router"] = Router(
753
+ external_addresses,
754
+ gen_local_address(process_index),
755
+ actor_pool_config.external_to_internal_address_map,
756
+ comm_config=actor_pool_config.get_comm_config(),
757
+ )
758
+
759
+ @classmethod
760
+ async def _create_servers(
761
+ cls, addresses: list[str], channel_handler: Callable, config: dict
762
+ ):
763
+ assert len(set(addresses)) == len(addresses)
764
+ # create servers
765
+ create_server_tasks = []
766
+ for addr in addresses:
767
+ server_type = get_server_type(addr)
768
+ extra_config = server_type.parse_config(config)
769
+ server_config = dict(address=addr, handle_channel=channel_handler)
770
+ server_config.update(extra_config)
771
+ task = asyncio.create_task(server_type.create(server_config))
772
+ create_server_tasks.append(task)
773
+
774
+ await asyncio.gather(*create_server_tasks)
775
+ return [f.result() for f in create_server_tasks]
776
+
777
+ @classmethod
778
+ @implements(AbstractActorPool.create)
779
+ async def create(cls, config: dict) -> "ActorPoolType":
780
+ config = config.copy()
781
+ kw: dict[str, Any] = dict()
782
+ cls._parse_config(config, kw)
783
+ process_index: int = kw["process_index"]
784
+ actor_pool_config = kw["config"] # type: ActorPoolConfig
785
+ cur_pool_config = actor_pool_config.get_pool_config(process_index)
786
+ external_addresses = cur_pool_config["external_address"]
787
+ internal_address = kw["internal_address"]
788
+
789
+ # import predefined modules
790
+ modules = cur_pool_config["modules"] or []
791
+ for mod in modules:
792
+ __import__(mod, globals(), locals(), [])
793
+ # make sure all lazy imports loaded
794
+ with _disable_log_temporally():
795
+ TypeDispatcher.reload_all_lazy_handlers()
796
+
797
+ def handle_channel(channel):
798
+ return pool.on_new_channel(channel)
799
+
800
+ # create servers
801
+ server_addresses = list(external_addresses)
802
+ if internal_address:
803
+ server_addresses.append(internal_address)
804
+ server_addresses.append(gen_local_address(process_index))
805
+ server_addresses = sorted(set(server_addresses))
806
+ servers = await cls._create_servers(
807
+ server_addresses, handle_channel, actor_pool_config.get_comm_config()
808
+ )
809
+ cls._update_stored_addresses(servers, server_addresses, actor_pool_config, kw)
810
+
811
+ # set default router
812
+ # actor context would be able to use exact client
813
+ cls._set_global_router(kw["router"])
814
+
815
+ # create pool
816
+ pool = cls(**kw)
817
+ return pool # type: ignore
818
+
819
+
820
+ ActorPoolType = TypeVar("ActorPoolType", bound=AbstractActorPool)
821
+ MainActorPoolType = TypeVar("MainActorPoolType", bound="MainActorPoolBase")
822
+ SubProcessHandle = multiprocessing.Process
823
+
824
+
825
+ class SubActorPoolBase(ActorPoolBase):
826
+ __slots__ = ("_main_address", "_watch_main_pool_task")
827
+ _watch_main_pool_task: Optional[asyncio.Task]
828
+
829
+ def __init__(
830
+ self,
831
+ process_index: int,
832
+ label: str,
833
+ external_address: str,
834
+ internal_address: str,
835
+ env: dict,
836
+ router: Router,
837
+ config: ActorPoolConfig,
838
+ servers: list[Server],
839
+ main_address: str,
840
+ main_pool_pid: Optional[int],
841
+ ):
842
+ super().__init__(
843
+ process_index,
844
+ label,
845
+ external_address,
846
+ internal_address,
847
+ env,
848
+ router,
849
+ config,
850
+ servers,
851
+ )
852
+ self._main_address = main_address
853
+ if main_pool_pid:
854
+ self._watch_main_pool_task = asyncio.create_task(
855
+ self._watch_main_pool(main_pool_pid)
856
+ )
857
+ else:
858
+ self._watch_main_pool_task = None
859
+
860
+ async def _watch_main_pool(self, main_pool_pid: int):
861
+ main_process = psutil.Process(main_pool_pid)
862
+ while not self.stopped:
863
+ try:
864
+ await asyncio.to_thread(main_process.status)
865
+ await asyncio.sleep(0.1)
866
+ continue
867
+ except (psutil.NoSuchProcess, ProcessLookupError, asyncio.CancelledError):
868
+ # main pool died
869
+ break
870
+
871
+ if not self.stopped:
872
+ await self.stop()
873
+
874
+ async def notify_main_pool_to_destroy(
875
+ self, message: DestroyActorMessage
876
+ ): # pragma: no cover
877
+ await self.call(self._main_address, message)
878
+
879
+ async def notify_main_pool_to_create(self, message: CreateActorMessage):
880
+ reg_message = ControlMessage(
881
+ new_message_id(),
882
+ self.external_address,
883
+ ControlMessageType.add_sub_pool_actor,
884
+ (self.external_address, message.allocate_strategy, message),
885
+ )
886
+ await self.call(self._main_address, reg_message)
887
+
888
+ @implements(AbstractActorPool.create_actor)
889
+ async def create_actor(self, message: CreateActorMessage) -> ResultMessageType:
890
+ result = await super().create_actor(message)
891
+ if not message.from_main:
892
+ await self.notify_main_pool_to_create(message)
893
+ return result
894
+
895
+ @implements(AbstractActorPool.actor_ref)
896
+ async def actor_ref(self, message: ActorRefMessage) -> ResultMessageType:
897
+ result = await super().actor_ref(message)
898
+ if isinstance(result, ErrorMessage):
899
+ # need a new message id to call main actor
900
+ main_message = ActorRefMessage(
901
+ new_message_id(),
902
+ create_actor_ref(self._main_address, message.actor_ref.uid),
903
+ )
904
+ result = await self.call(self._main_address, main_message)
905
+ # rewrite to message_id of the original request
906
+ result.message_id = message.message_id
907
+ return result
908
+
909
+ @implements(AbstractActorPool.destroy_actor)
910
+ async def destroy_actor(self, message: DestroyActorMessage) -> ResultMessageType:
911
+ result = await super().destroy_actor(message)
912
+ if isinstance(result, ResultMessage) and not message.from_main:
913
+ # sync back to main actor pool
914
+ await self.notify_main_pool_to_destroy(message)
915
+ return result
916
+
917
+ @implements(AbstractActorPool.handle_control_command)
918
+ async def handle_control_command(
919
+ self, message: ControlMessage
920
+ ) -> ResultMessageType:
921
+ if message.control_message_type == ControlMessageType.sync_config:
922
+ self._main_address = message.address
923
+ return await super().handle_control_command(message)
924
+
925
+ @staticmethod
926
+ def _parse_config(config: dict, kw: dict) -> dict:
927
+ main_pool_pid = config.pop("main_pool_pid", None)
928
+ kw = AbstractActorPool._parse_config(config, kw)
929
+ pool_config: ActorPoolConfig = kw["config"]
930
+ main_process_index = pool_config.get_process_indexes()[0]
931
+ kw["main_address"] = pool_config.get_pool_config(main_process_index)[
932
+ "external_address"
933
+ ][0]
934
+ kw["main_pool_pid"] = main_pool_pid
935
+ return kw
936
+
937
+ async def stop(self):
938
+ await super().stop()
939
+ if self._watch_main_pool_task:
940
+ self._watch_main_pool_task.cancel()
941
+ await self._watch_main_pool_task
942
+
943
+
944
+ class MainActorPoolBase(ActorPoolBase):
945
+ __slots__ = (
946
+ "_subprocess_start_method",
947
+ "_allocated_actors",
948
+ "sub_actor_pool_manager",
949
+ "_auto_recover",
950
+ "_monitor_task",
951
+ "_on_process_down",
952
+ "_on_process_recover",
953
+ "_recover_events",
954
+ "_allocation_lock",
955
+ "sub_processes",
956
+ )
957
+
958
+ def __init__(
959
+ self,
960
+ process_index: int,
961
+ label: str,
962
+ external_address: str,
963
+ internal_address: str,
964
+ env: dict,
965
+ router: Router,
966
+ config: ActorPoolConfig,
967
+ servers: list[Server],
968
+ subprocess_start_method: str | None = None,
969
+ auto_recover: str | bool = "actor",
970
+ on_process_down: Callable[[MainActorPoolType, str], None] | None = None,
971
+ on_process_recover: Callable[[MainActorPoolType, str], None] | None = None,
972
+ ):
973
+ super().__init__(
974
+ process_index,
975
+ label,
976
+ external_address,
977
+ internal_address,
978
+ env,
979
+ router,
980
+ config,
981
+ servers,
982
+ )
983
+ self._subprocess_start_method = subprocess_start_method
984
+
985
+ # auto recovering
986
+ self._auto_recover = auto_recover
987
+ self._monitor_task: Optional[asyncio.Task] = None
988
+ self._on_process_down = on_process_down
989
+ self._on_process_recover = on_process_recover
990
+ self._recover_events: dict[str, asyncio.Event] = dict()
991
+
992
+ # states
993
+ self._allocated_actors: allocated_type = {
994
+ addr: dict() for addr in self._config.get_external_addresses()
995
+ }
996
+ self._allocation_lock = threading.Lock()
997
+
998
+ self.sub_processes: dict[str, SubProcessHandle] = dict()
999
+
1000
+ _process_index_gen = itertools.count()
1001
+
1002
+ @classmethod
1003
+ def process_index_gen(cls, address):
1004
+ # make sure different processes do not share process indexes
1005
+ pid = os.getpid()
1006
+ for idx in cls._process_index_gen:
1007
+ yield pid << 16 + idx
1008
+
1009
+ @property
1010
+ def _sub_processes(self):
1011
+ return self.sub_processes
1012
+
1013
+ @implements(AbstractActorPool.create_actor)
1014
+ async def create_actor(self, message: CreateActorMessage) -> ResultMessageType:
1015
+ with _ErrorProcessor(
1016
+ address=self.external_address,
1017
+ message_id=message.message_id,
1018
+ protocol=message.protocol,
1019
+ ) as processor:
1020
+ allocate_strategy = message.allocate_strategy
1021
+ with self._allocation_lock:
1022
+ # get allocated address according to corresponding strategy
1023
+ address = allocate_strategy.get_allocated_address(
1024
+ self._config, self._allocated_actors
1025
+ )
1026
+ # set placeholder to make sure this label is occupied
1027
+ self._allocated_actors[address][None] = (allocate_strategy, message)
1028
+ if address == self.external_address:
1029
+ # creating actor on main actor pool
1030
+ result = await super().create_actor(message)
1031
+ if isinstance(result, ResultMessage):
1032
+ self._allocated_actors[self.external_address][result.result] = (
1033
+ allocate_strategy,
1034
+ message,
1035
+ )
1036
+ processor.result = result
1037
+ else:
1038
+ # creating actor on sub actor pool
1039
+ # rewrite allocate strategy to AddressSpecified
1040
+ new_allocate_strategy = AddressSpecified(address)
1041
+ new_create_actor_message = CreateActorMessage(
1042
+ message.message_id,
1043
+ message.actor_cls,
1044
+ message.actor_id,
1045
+ message.args,
1046
+ message.kwargs,
1047
+ allocate_strategy=new_allocate_strategy,
1048
+ from_main=True,
1049
+ protocol=message.protocol,
1050
+ message_trace=message.message_trace,
1051
+ )
1052
+ result = await self.call(address, new_create_actor_message)
1053
+ if isinstance(result, ResultMessage):
1054
+ self._allocated_actors[address][result.result] = (
1055
+ allocate_strategy,
1056
+ new_create_actor_message,
1057
+ )
1058
+ processor.result = result
1059
+
1060
+ # revert placeholder
1061
+ self._allocated_actors[address].pop(None, None)
1062
+
1063
+ return processor.result
1064
+
1065
+ @implements(AbstractActorPool.has_actor)
1066
+ async def has_actor(self, message: HasActorMessage) -> ResultMessage:
1067
+ actor_ref = message.actor_ref
1068
+ # lookup allocated
1069
+ for address, item in self._allocated_actors.items():
1070
+ ref = create_actor_ref(address, to_binary(actor_ref.uid))
1071
+ if ref in item:
1072
+ return ResultMessage(
1073
+ message.message_id, True, protocol=message.protocol
1074
+ )
1075
+
1076
+ return ResultMessage(message.message_id, False, protocol=message.protocol)
1077
+
1078
+ @implements(AbstractActorPool.destroy_actor)
1079
+ async def destroy_actor(self, message: DestroyActorMessage) -> ResultMessageType:
1080
+ actor_ref_message = ActorRefMessage(
1081
+ message.message_id, message.actor_ref, protocol=message.protocol
1082
+ )
1083
+ result = await self.actor_ref(actor_ref_message)
1084
+ if not isinstance(result, ResultMessage):
1085
+ return result
1086
+ real_actor_ref = result.result
1087
+ if real_actor_ref.address == self.external_address:
1088
+ result = await super().destroy_actor(message)
1089
+ if result.message_type == MessageType.error:
1090
+ return result
1091
+ del self._allocated_actors[self.external_address][real_actor_ref]
1092
+ return ResultMessage(
1093
+ message.message_id, real_actor_ref.uid, protocol=message.protocol
1094
+ )
1095
+ # remove allocated actor ref
1096
+ self._allocated_actors[real_actor_ref.address].pop(real_actor_ref, None)
1097
+ new_destroy_message = DestroyActorMessage(
1098
+ message.message_id,
1099
+ real_actor_ref,
1100
+ from_main=True,
1101
+ protocol=message.protocol,
1102
+ )
1103
+ return await self.call(real_actor_ref.address, new_destroy_message)
1104
+
1105
+ @implements(AbstractActorPool.send)
1106
+ async def send(self, message: SendMessage) -> ResultMessageType:
1107
+ if message.actor_ref.uid in self._actors:
1108
+ return await super().send(message)
1109
+ actor_ref_message = ActorRefMessage(
1110
+ message.message_id, message.actor_ref, protocol=message.protocol
1111
+ )
1112
+ result = await self.actor_ref(actor_ref_message)
1113
+ if not isinstance(result, ResultMessage):
1114
+ return result
1115
+ actor_ref = result.result
1116
+ new_send_message = SendMessage(
1117
+ message.message_id,
1118
+ actor_ref,
1119
+ message.content,
1120
+ protocol=message.protocol,
1121
+ message_trace=message.message_trace,
1122
+ )
1123
+ return await self.call(actor_ref.address, new_send_message)
1124
+
1125
+ @implements(AbstractActorPool.tell)
1126
+ async def tell(self, message: TellMessage) -> ResultMessageType:
1127
+ if message.actor_ref.uid in self._actors:
1128
+ return await super().tell(message)
1129
+ actor_ref_message = ActorRefMessage(
1130
+ message.message_id, message.actor_ref, protocol=message.protocol
1131
+ )
1132
+ result = await self.actor_ref(actor_ref_message)
1133
+ if not isinstance(result, ResultMessage):
1134
+ return result
1135
+ actor_ref = result.result
1136
+ new_tell_message = TellMessage(
1137
+ message.message_id,
1138
+ actor_ref,
1139
+ message.content,
1140
+ protocol=message.protocol,
1141
+ message_trace=message.message_trace,
1142
+ )
1143
+ return await self.call(actor_ref.address, new_tell_message)
1144
+
1145
+ @implements(AbstractActorPool.actor_ref)
1146
+ async def actor_ref(self, message: ActorRefMessage) -> ResultMessageType:
1147
+ actor_ref = message.actor_ref
1148
+ actor_ref.uid = to_binary(actor_ref.uid)
1149
+ if actor_ref.address == self.external_address and actor_ref.uid in self._actors:
1150
+ return ResultMessage(
1151
+ message.message_id, actor_ref, protocol=message.protocol
1152
+ )
1153
+
1154
+ # lookup allocated
1155
+ for address, item in self._allocated_actors.items():
1156
+ ref = create_actor_ref(address, actor_ref.uid)
1157
+ if ref in item:
1158
+ return ResultMessage(message.message_id, ref, protocol=message.protocol)
1159
+
1160
+ with _ErrorProcessor(
1161
+ self.external_address, message.message_id, protocol=message.protocol
1162
+ ) as processor:
1163
+ raise ActorNotExist(
1164
+ f"Actor {actor_ref.uid} does not exist in {actor_ref.address}"
1165
+ )
1166
+
1167
+ return processor.result
1168
+
1169
+ @implements(AbstractActorPool.cancel)
1170
+ async def cancel(self, message: CancelMessage) -> ResultMessageType:
1171
+ if message.address == self.external_address:
1172
+ # local message
1173
+ return await super().cancel(message)
1174
+ # redirect to sub pool
1175
+ return await self.call(message.address, message)
1176
+
1177
+ @implements(AbstractActorPool.handle_control_command)
1178
+ async def handle_control_command(
1179
+ self, message: ControlMessage
1180
+ ) -> ResultMessageType:
1181
+ with _ErrorProcessor(
1182
+ self.external_address, message.message_id, message.protocol
1183
+ ) as processor:
1184
+ if message.address == self.external_address:
1185
+ if message.control_message_type == ControlMessageType.sync_config:
1186
+ # sync config, need to notify all sub pools
1187
+ tasks = []
1188
+ for addr in self.sub_processes:
1189
+ control_message = ControlMessage(
1190
+ new_message_id(),
1191
+ message.address,
1192
+ message.control_message_type,
1193
+ message.content,
1194
+ protocol=message.protocol,
1195
+ message_trace=message.message_trace,
1196
+ )
1197
+ tasks.append(
1198
+ asyncio.create_task(self.call(addr, control_message))
1199
+ )
1200
+ # call super
1201
+ task = asyncio.create_task(super().handle_control_command(message))
1202
+ tasks.append(task)
1203
+ await asyncio.gather(*tasks)
1204
+ processor.result = await task
1205
+ else:
1206
+ processor.result = await super().handle_control_command(message)
1207
+ elif message.control_message_type == ControlMessageType.stop:
1208
+ timeout, force = (
1209
+ message.content if message.content is not None else (None, False)
1210
+ )
1211
+ await self.stop_sub_pool(
1212
+ message.address,
1213
+ self.sub_processes[message.address],
1214
+ timeout=timeout,
1215
+ force=force,
1216
+ )
1217
+ processor.result = ResultMessage(
1218
+ message.message_id, True, protocol=message.protocol
1219
+ )
1220
+ elif message.control_message_type == ControlMessageType.wait_pool_recovered:
1221
+ if self._auto_recover and message.address not in self._recover_events:
1222
+ self._recover_events[message.address] = asyncio.Event()
1223
+
1224
+ event = self._recover_events.get(message.address, None)
1225
+ if event is not None:
1226
+ await event.wait()
1227
+ processor.result = ResultMessage(
1228
+ message.message_id, True, protocol=message.protocol
1229
+ )
1230
+ elif message.control_message_type == ControlMessageType.add_sub_pool_actor:
1231
+ address, allocate_strategy, create_message = message.content
1232
+ create_message.from_main = True
1233
+ ref = create_actor_ref(address, to_binary(create_message.actor_id))
1234
+ self._allocated_actors[address][ref] = (
1235
+ allocate_strategy,
1236
+ create_message,
1237
+ )
1238
+ processor.result = ResultMessage(
1239
+ message.message_id, True, protocol=message.protocol
1240
+ )
1241
+ else:
1242
+ processor.result = await self.call(message.address, message)
1243
+ return processor.result
1244
+
1245
+ @staticmethod
1246
+ def _parse_config(config: dict, kw: dict) -> dict:
1247
+ kw["subprocess_start_method"] = config.pop("start_method", None)
1248
+ kw["auto_recover"] = config.pop("auto_recover", "actor")
1249
+ kw["on_process_down"] = config.pop("on_process_down", None)
1250
+ kw["on_process_recover"] = config.pop("on_process_recover", None)
1251
+ kw = AbstractActorPool._parse_config(config, kw)
1252
+ return kw
1253
+
1254
+ @classmethod
1255
+ @implements(AbstractActorPool.create)
1256
+ async def create(cls, config: dict) -> MainActorPoolType:
1257
+ config = config.copy()
1258
+ actor_pool_config: ActorPoolConfig = config.get("actor_pool_config") # type: ignore
1259
+ start_method = config.get("start_method", None)
1260
+ if "process_index" not in config:
1261
+ config["process_index"] = actor_pool_config.get_process_indexes()[0]
1262
+ curr_process_index = config.get("process_index")
1263
+ old_config_addresses = set(actor_pool_config.get_external_addresses())
1264
+
1265
+ tasks = []
1266
+ subpool_process_idxes = []
1267
+ # create sub actor pools
1268
+ n_sub_pool = actor_pool_config.n_pool - 1
1269
+ if n_sub_pool > 0:
1270
+ process_indexes = actor_pool_config.get_process_indexes()
1271
+ for process_index in process_indexes:
1272
+ if process_index == curr_process_index:
1273
+ continue
1274
+ create_pool_task = asyncio.create_task(
1275
+ cls.start_sub_pool(actor_pool_config, process_index, start_method)
1276
+ )
1277
+ await asyncio.sleep(0)
1278
+ # await create_pool_task
1279
+ tasks.append(create_pool_task)
1280
+ subpool_process_idxes.append(process_index)
1281
+
1282
+ processes, ext_addresses = await cls.wait_sub_pools_ready(tasks)
1283
+ if ext_addresses:
1284
+ for process_index, ext_address in zip(subpool_process_idxes, ext_addresses):
1285
+ actor_pool_config.reset_pool_external_address(
1286
+ process_index, ext_address
1287
+ )
1288
+
1289
+ # create main actor pool
1290
+ pool: MainActorPoolType = await super().create(config)
1291
+ addresses = actor_pool_config.get_external_addresses()[1:]
1292
+
1293
+ assert len(addresses) == len(
1294
+ processes
1295
+ ), f"addresses {addresses}, processes {processes}"
1296
+ for addr, proc in zip(addresses, processes):
1297
+ pool.attach_sub_process(addr, proc)
1298
+
1299
+ new_config_addresses = set(actor_pool_config.get_external_addresses())
1300
+ if old_config_addresses != new_config_addresses:
1301
+ control_message = ControlMessage(
1302
+ message_id=new_message_id(),
1303
+ address=pool.external_address,
1304
+ control_message_type=ControlMessageType.sync_config,
1305
+ content=actor_pool_config,
1306
+ )
1307
+ await pool.handle_control_command(control_message)
1308
+
1309
+ return pool
1310
+
1311
+ async def start_monitor(self):
1312
+ if self._monitor_task is None:
1313
+ self._monitor_task = asyncio.create_task(self.monitor_sub_pools())
1314
+ return self._monitor_task
1315
+
1316
+ @implements(AbstractActorPool.stop)
1317
+ async def stop(self):
1318
+ global_router = Router.get_instance()
1319
+ if global_router is not None:
1320
+ global_router.remove_router(self._router)
1321
+
1322
+ # turn off auto recover to avoid errors
1323
+ self._auto_recover = False
1324
+ self._stopped.set()
1325
+ if self._monitor_task and not self._monitor_task.done():
1326
+ await self._monitor_task
1327
+ self._monitor_task = None
1328
+ await self.stop_sub_pools()
1329
+ await super().stop()
1330
+
1331
+ @classmethod
1332
+ @abstractmethod
1333
+ async def start_sub_pool(
1334
+ cls,
1335
+ actor_pool_config: ActorPoolConfig,
1336
+ process_index: int,
1337
+ start_method: str | None = None,
1338
+ ):
1339
+ """Start a sub actor pool"""
1340
+
1341
+ @classmethod
1342
+ @abstractmethod
1343
+ async def wait_sub_pools_ready(cls, create_pool_tasks: list[asyncio.Task]):
1344
+ """Wait all sub pools ready"""
1345
+
1346
+ def attach_sub_process(self, external_address: str, process: SubProcessHandle):
1347
+ self.sub_processes[external_address] = process
1348
+
1349
+ async def stop_sub_pools(self):
1350
+ to_stop_processes: dict[str, SubProcessHandle] = dict() # type: ignore
1351
+ for address, process in self.sub_processes.items():
1352
+ if not await self.is_sub_pool_alive(process):
1353
+ continue
1354
+ to_stop_processes[address] = process
1355
+
1356
+ tasks = []
1357
+ for address, process in to_stop_processes.items():
1358
+ tasks.append(self.stop_sub_pool(address, process))
1359
+ await asyncio.gather(*tasks)
1360
+
1361
+ async def stop_sub_pool(
1362
+ self,
1363
+ address: str,
1364
+ process: SubProcessHandle,
1365
+ timeout: float | None = None,
1366
+ force: bool = False,
1367
+ ):
1368
+ if force:
1369
+ await self.kill_sub_pool(process, force=True)
1370
+ return
1371
+
1372
+ stop_message = ControlMessage(
1373
+ new_message_id(),
1374
+ address,
1375
+ ControlMessageType.stop,
1376
+ None,
1377
+ protocol=DEFAULT_PROTOCOL,
1378
+ )
1379
+ try:
1380
+ if timeout is None:
1381
+ message = await self.call(address, stop_message)
1382
+ if isinstance(message, ErrorMessage): # pragma: no cover
1383
+ raise message.as_instanceof_cause()
1384
+ else:
1385
+ call = asyncio.create_task(self.call(address, stop_message))
1386
+ try:
1387
+ await asyncio.wait_for(call, timeout)
1388
+ except (futures.TimeoutError, asyncio.TimeoutError): # pragma: no cover
1389
+ # timed out, let the kill below finish it
1390
+ force = True
1391
+ except (ConnectionError, ServerClosed): # pragma: no cover
1392
+ # the process may already be dead, ignore it
1393
+ pass
1394
+ # kill process
1395
+ await self.kill_sub_pool(process, force=force)
1396
+
1397
+ @abstractmethod
1398
+ async def kill_sub_pool(self, process: SubProcessHandle, force: bool = False):
1399
+ """Kill a sub actor pool"""
1400
+
1401
+ @abstractmethod
1402
+ async def is_sub_pool_alive(self, process: SubProcessHandle):
1403
+ """
1404
+ Check whether sub pool process is alive
1405
+ Parameters
1406
+ ----------
1407
+ process : SubProcessHandle
1408
+ sub pool process handle
1409
+ Returns
1410
+ -------
1411
+ bool
1412
+ """
1413
+
1414
+ @abstractmethod
1415
+ def recover_sub_pool(self, address):
1416
+ """Recover a sub actor pool"""
1417
+
1418
+ def process_sub_pool_lost(self, address: str):
1419
+ if self._auto_recover in (False, "process"):
1420
+ # when the process is down and auto_recover is disabled
1421
+ # or only recovers the process, remove all actors created on it
1422
+ self._allocated_actors[address] = dict()
1423
+
1424
+ async def monitor_sub_pools(self):
1425
+ try:
1426
+ while not self._stopped.is_set():
1427
+ # Copy sub_processes to avoid changes during recover.
1428
+ for address, process in list(self.sub_processes.items()):
1429
+ try:
1430
+ recover_events_discovered = address in self._recover_events
1431
+ if not await self.is_sub_pool_alive(
1432
+ process
1433
+ ): # pragma: no cover
1434
+ if self._on_process_down is not None:
1435
+ self._on_process_down(self, address)
1436
+ self.process_sub_pool_lost(address)
1437
+ if self._auto_recover:
1438
+ await self.recover_sub_pool(address)
1439
+ if self._on_process_recover is not None:
1440
+ self._on_process_recover(self, address)
1441
+ if recover_events_discovered:
1442
+ event = self._recover_events.pop(address)
1443
+ event.set()
1444
+ except asyncio.CancelledError:
1445
+ raise
1446
+ except RuntimeError as ex: # pragma: no cover
1447
+ if "cannot schedule new futures" not in str(ex):
1448
+ # to silence the log when the process exits, otherwise it
1449
+ # will raise "RuntimeError: cannot schedule new futures
1450
+ # after interpreter shutdown".
1451
+ logger.exception("Monitor sub pool %s failed", address)
1452
+ except Exception:
1453
+ # log the exception instead of stop monitoring the
1454
+ # sub pool silently.
1455
+ logger.exception("Monitor sub pool %s failed", address)
1456
+
1457
+ # check every half second
1458
+ await asyncio.sleep(0.5)
1459
+ except asyncio.CancelledError: # pragma: no cover
1460
+ # cancelled
1461
+ return
1462
+
1463
+ @classmethod
1464
+ @abstractmethod
1465
+ def get_external_addresses(
1466
+ cls,
1467
+ address: str,
1468
+ n_process: int | None = None,
1469
+ ports: list[int] | None = None,
1470
+ schemes: list[Optional[str]] | None = None,
1471
+ ):
1472
+ """Returns external addresses for n pool processes"""
1473
+
1474
+ @classmethod
1475
+ @abstractmethod
1476
+ def gen_internal_address(
1477
+ cls, process_index: int, external_address: str | None = None
1478
+ ) -> str | None:
1479
+ """Returns internal address for pool of specified process index"""
1480
+
1481
+
1482
+ async def create_actor_pool(
1483
+ address: str,
1484
+ *,
1485
+ pool_cls: Type[MainActorPoolType] | None = None,
1486
+ n_process: int | None = None,
1487
+ labels: list[str] | None = None,
1488
+ ports: list[int] | None = None,
1489
+ envs: list[dict] | None = None,
1490
+ external_address_schemes: list[Optional[str]] | None = None,
1491
+ enable_internal_addresses: list[bool] | None = None,
1492
+ subprocess_start_method: str | None = None,
1493
+ auto_recover: str | bool = "actor",
1494
+ modules: list[str] | None = None,
1495
+ suspend_sigint: bool | None = None,
1496
+ use_uvloop: str | bool = "auto",
1497
+ logging_conf: dict | None = None,
1498
+ on_process_down: Callable[[MainActorPoolType, str], None] | None = None,
1499
+ on_process_recover: Callable[[MainActorPoolType, str], None] | None = None,
1500
+ extra_conf: dict | None = None,
1501
+ **kwargs,
1502
+ ) -> MainActorPoolType:
1503
+ if n_process is None:
1504
+ n_process = multiprocessing.cpu_count()
1505
+ if labels and len(labels) != n_process + 1:
1506
+ raise ValueError(
1507
+ f"`labels` should be of size {n_process + 1}, got {len(labels)}"
1508
+ )
1509
+ if envs and len(envs) != n_process:
1510
+ raise ValueError(f"`envs` should be of size {n_process}, got {len(envs)}")
1511
+ if external_address_schemes and len(external_address_schemes) != n_process + 1:
1512
+ raise ValueError(
1513
+ f"`external_address_schemes` should be of size {n_process + 1}, "
1514
+ f"got {len(external_address_schemes)}"
1515
+ )
1516
+ if enable_internal_addresses and len(enable_internal_addresses) != n_process + 1:
1517
+ raise ValueError(
1518
+ f"`enable_internal_addresses` should be of size {n_process + 1}, "
1519
+ f"got {len(enable_internal_addresses)}"
1520
+ )
1521
+ elif not enable_internal_addresses:
1522
+ enable_internal_addresses = [True] * (n_process + 1)
1523
+ if auto_recover is True:
1524
+ auto_recover = "actor"
1525
+ if auto_recover not in ("actor", "process", False):
1526
+ raise ValueError(
1527
+ f'`auto_recover` should be one of "actor", "process", '
1528
+ f"True or False, got {auto_recover}"
1529
+ )
1530
+ if use_uvloop == "auto":
1531
+ try:
1532
+ import uvloop # noqa: F401 # pylint: disable=unused-variable
1533
+
1534
+ use_uvloop = True
1535
+ except ImportError:
1536
+ use_uvloop = False
1537
+
1538
+ assert pool_cls is not None
1539
+ external_addresses = pool_cls.get_external_addresses(
1540
+ address, n_process=n_process, ports=ports, schemes=external_address_schemes
1541
+ )
1542
+ actor_pool_config = ActorPoolConfig()
1543
+ actor_pool_config.add_metric_configs(kwargs.get("metrics", {}))
1544
+ # add main config
1545
+ process_index_gen = pool_cls.process_index_gen(address)
1546
+ main_process_index = next(process_index_gen)
1547
+ main_internal_address = (
1548
+ pool_cls.gen_internal_address(main_process_index, external_addresses[0])
1549
+ if enable_internal_addresses[0]
1550
+ else None
1551
+ )
1552
+ actor_pool_config.add_pool_conf(
1553
+ main_process_index,
1554
+ labels[0] if labels else None,
1555
+ main_internal_address,
1556
+ external_addresses[0],
1557
+ modules=modules,
1558
+ suspend_sigint=suspend_sigint,
1559
+ use_uvloop=use_uvloop, # type: ignore
1560
+ logging_conf=logging_conf,
1561
+ kwargs=kwargs,
1562
+ )
1563
+ # add sub configs
1564
+ for i in range(n_process):
1565
+ sub_process_index = next(process_index_gen)
1566
+ internal_address = (
1567
+ pool_cls.gen_internal_address(sub_process_index, external_addresses[i + 1])
1568
+ if enable_internal_addresses[i + 1]
1569
+ else None
1570
+ )
1571
+ actor_pool_config.add_pool_conf(
1572
+ sub_process_index,
1573
+ labels[i + 1] if labels else None,
1574
+ internal_address,
1575
+ external_addresses[i + 1],
1576
+ env=envs[i] if envs else None,
1577
+ modules=modules,
1578
+ suspend_sigint=suspend_sigint,
1579
+ use_uvloop=use_uvloop, # type: ignore
1580
+ logging_conf=logging_conf,
1581
+ kwargs=kwargs,
1582
+ )
1583
+ actor_pool_config.add_comm_config(extra_conf)
1584
+
1585
+ pool: MainActorPoolType = await pool_cls.create(
1586
+ {
1587
+ "actor_pool_config": actor_pool_config,
1588
+ "process_index": main_process_index,
1589
+ "start_method": subprocess_start_method,
1590
+ "auto_recover": auto_recover,
1591
+ "on_process_down": on_process_down,
1592
+ "on_process_recover": on_process_recover,
1593
+ }
1594
+ )
1595
+ await pool.start()
1596
+ return pool
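
The create_actor_pool coroutine at the end of this file wires an ActorPoolConfig together, spawns the sub pools, and returns an already started main pool. A minimal usage sketch follows; it assumes the concrete MainActorPool class is provided by the indigen backend listed above (xoscar/backends/indigen/pool.py), and the address and process count are placeholders.

import asyncio

from xoscar.backends.indigen.pool import MainActorPool  # assumed concrete pool class
from xoscar.backends.pool import create_actor_pool


async def main():
    # Start a main pool plus two sub-pool processes; create_actor_pool()
    # already calls pool.start() before returning.
    pool = await create_actor_pool(
        "127.0.0.1:11111",
        pool_cls=MainActorPool,
        n_process=2,
    )
    try:
        # join() waits on the pool's stopped event; with a timeout it simply
        # returns once the timeout elapses.
        await pool.join(timeout=1)
    finally:
        # stop() shuts down sub pools, servers and clients.
        await pool.stop()


asyncio.run(main())

Higher-level helpers in xoscar/api.py are expected to wrap this entry point, so applications normally do not call it directly.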