xoscar 0.3.1__cp311-cp311-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of xoscar might be problematic. Click here for more details.

Files changed (80) hide show
  1. xoscar/__init__.py +60 -0
  2. xoscar/_utils.cpython-311-darwin.so +0 -0
  3. xoscar/_utils.pxd +36 -0
  4. xoscar/_utils.pyx +241 -0
  5. xoscar/_version.py +693 -0
  6. xoscar/aio/__init__.py +25 -0
  7. xoscar/aio/_threads.py +35 -0
  8. xoscar/aio/base.py +86 -0
  9. xoscar/aio/file.py +59 -0
  10. xoscar/aio/lru.py +228 -0
  11. xoscar/aio/parallelism.py +39 -0
  12. xoscar/api.py +493 -0
  13. xoscar/backend.py +67 -0
  14. xoscar/backends/__init__.py +14 -0
  15. xoscar/backends/allocate_strategy.py +160 -0
  16. xoscar/backends/communication/__init__.py +30 -0
  17. xoscar/backends/communication/base.py +315 -0
  18. xoscar/backends/communication/core.py +69 -0
  19. xoscar/backends/communication/dummy.py +242 -0
  20. xoscar/backends/communication/errors.py +20 -0
  21. xoscar/backends/communication/socket.py +375 -0
  22. xoscar/backends/communication/ucx.py +520 -0
  23. xoscar/backends/communication/utils.py +97 -0
  24. xoscar/backends/config.py +145 -0
  25. xoscar/backends/context.py +404 -0
  26. xoscar/backends/core.py +193 -0
  27. xoscar/backends/indigen/__init__.py +16 -0
  28. xoscar/backends/indigen/backend.py +51 -0
  29. xoscar/backends/indigen/driver.py +26 -0
  30. xoscar/backends/indigen/pool.py +469 -0
  31. xoscar/backends/message.cpython-311-darwin.so +0 -0
  32. xoscar/backends/message.pyx +591 -0
  33. xoscar/backends/pool.py +1593 -0
  34. xoscar/backends/router.py +207 -0
  35. xoscar/backends/test/__init__.py +16 -0
  36. xoscar/backends/test/backend.py +38 -0
  37. xoscar/backends/test/pool.py +208 -0
  38. xoscar/batch.py +256 -0
  39. xoscar/collective/__init__.py +27 -0
  40. xoscar/collective/common.py +102 -0
  41. xoscar/collective/core.py +737 -0
  42. xoscar/collective/process_group.py +687 -0
  43. xoscar/collective/utils.py +41 -0
  44. xoscar/collective/xoscar_pygloo.cpython-311-darwin.so +0 -0
  45. xoscar/constants.py +21 -0
  46. xoscar/context.cpython-311-darwin.so +0 -0
  47. xoscar/context.pxd +21 -0
  48. xoscar/context.pyx +368 -0
  49. xoscar/core.cpython-311-darwin.so +0 -0
  50. xoscar/core.pxd +50 -0
  51. xoscar/core.pyx +658 -0
  52. xoscar/debug.py +188 -0
  53. xoscar/driver.py +42 -0
  54. xoscar/errors.py +63 -0
  55. xoscar/libcpp.pxd +31 -0
  56. xoscar/metrics/__init__.py +21 -0
  57. xoscar/metrics/api.py +288 -0
  58. xoscar/metrics/backends/__init__.py +13 -0
  59. xoscar/metrics/backends/console/__init__.py +13 -0
  60. xoscar/metrics/backends/console/console_metric.py +82 -0
  61. xoscar/metrics/backends/metric.py +149 -0
  62. xoscar/metrics/backends/prometheus/__init__.py +13 -0
  63. xoscar/metrics/backends/prometheus/prometheus_metric.py +70 -0
  64. xoscar/nvutils.py +717 -0
  65. xoscar/profiling.py +260 -0
  66. xoscar/serialization/__init__.py +20 -0
  67. xoscar/serialization/aio.py +138 -0
  68. xoscar/serialization/core.cpython-311-darwin.so +0 -0
  69. xoscar/serialization/core.pxd +28 -0
  70. xoscar/serialization/core.pyx +954 -0
  71. xoscar/serialization/cuda.py +111 -0
  72. xoscar/serialization/exception.py +48 -0
  73. xoscar/serialization/numpy.py +82 -0
  74. xoscar/serialization/pyfury.py +37 -0
  75. xoscar/serialization/scipy.py +72 -0
  76. xoscar/utils.py +502 -0
  77. xoscar-0.3.1.dist-info/METADATA +225 -0
  78. xoscar-0.3.1.dist-info/RECORD +80 -0
  79. xoscar-0.3.1.dist-info/WHEEL +5 -0
  80. xoscar-0.3.1.dist-info/top_level.txt +2 -0
@@ -0,0 +1,1593 @@
1
+ # Copyright 2022-2023 XProbe Inc.
2
+ # derived from copyright 1999-2021 Alibaba Group Holding Ltd.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from __future__ import annotations
17
+
18
+ import asyncio
19
+ import concurrent.futures as futures
20
+ import contextlib
21
+ import itertools
22
+ import logging
23
+ import multiprocessing
24
+ import os
25
+ import threading
26
+ import traceback
27
+ from abc import ABC, ABCMeta, abstractmethod
28
+ from typing import Any, Callable, Coroutine, Optional, Type, TypeVar
29
+
30
+ import psutil
31
+
32
+ from .._utils import TypeDispatcher, create_actor_ref, to_binary
33
+ from ..api import Actor
34
+ from ..core import ActorRef, BufferRef, FileObjectRef, register_local_pool
35
+ from ..debug import debug_async_timeout, record_message_trace
36
+ from ..errors import (
37
+ ActorAlreadyExist,
38
+ ActorNotExist,
39
+ CannotCancelTask,
40
+ SendMessageFailed,
41
+ ServerClosed,
42
+ )
43
+ from ..metrics import init_metrics
44
+ from ..utils import implements, register_asyncio_task_timeout_detector
45
+ from .allocate_strategy import AddressSpecified, allocated_type
46
+ from .communication import (
47
+ Channel,
48
+ Server,
49
+ UCXChannel,
50
+ gen_local_address,
51
+ get_server_type,
52
+ )
53
+ from .communication.errors import ChannelClosed
54
+ from .config import ActorPoolConfig
55
+ from .core import ActorCaller, ResultMessageType
56
+ from .message import (
57
+ DEFAULT_PROTOCOL,
58
+ ActorRefMessage,
59
+ CancelMessage,
60
+ ControlMessage,
61
+ ControlMessageType,
62
+ CreateActorMessage,
63
+ DestroyActorMessage,
64
+ ErrorMessage,
65
+ HasActorMessage,
66
+ MessageType,
67
+ ResultMessage,
68
+ SendMessage,
69
+ TellMessage,
70
+ _MessageBase,
71
+ new_message_id,
72
+ )
73
+ from .router import Router
74
+
75
+ logger = logging.getLogger(__name__)
76
+
77
+
78
@contextlib.contextmanager
def _disable_log_temporally():
    """Silence log records up to ``ERROR`` while CUDA devices are hidden.

    When the environment variable ``CUDA_VISIBLE_DEVICES`` equals ``"-1"``,
    logging at ``ERROR`` level and below is disabled for the duration of the
    ``with`` body.  For any other value the body runs with logging untouched.

    On exit the process-wide disable threshold that was active on entry is
    restored, so a threshold the application configured through
    ``logging.disable`` is not reset to ``NOTSET`` as a side effect.
    """
    if os.getenv("CUDA_VISIBLE_DEVICES") != "-1":
        yield
        return

    # disable logging when CUDA_VISIBLE_DEVICES == -1
    # many logging comes from ptxcompiler may distract users
    previous_level = logging.root.manager.disable
    logging.disable(level=logging.ERROR)
    try:
        yield
    finally:
        # put back whatever threshold was in force before entering
        logging.disable(level=previous_level)
90
+
91
+
92
class _ErrorProcessor:
    """Context manager that converts a failed body into an ``ErrorMessage``.

    The body is expected to assign ``result`` on success.  If ``result`` is
    still ``None`` when the block exits, an ``ErrorMessage`` is built from the
    exception information passed to ``__exit__`` and stored in ``result``.
    Any exception raised inside the block is suppressed (``__exit__`` always
    returns ``True``), so callers read ``result`` afterwards.
    """

    def __init__(self, address: str, message_id: bytes, protocol):
        self.result = None
        self._protocol = protocol
        self._message_id = message_id
        self._address = address

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        result_missing = self.result is None
        if result_missing:
            # positional fields: message id, address, pid, exception triple
            error_fields = (
                self._message_id,
                self._address,
                os.getpid(),
                exc_type,
                exc_val,
                exc_tb,
            )
            self.result = ErrorMessage(*error_fields, protocol=self._protocol)
        # always swallow the exception; it now lives in ``self.result``
        return True
114
+
115
+
116
def _register_message_handler(pool_type: Type["AbstractActorPool"]):
    """Attach a ``MessageType`` -> handler table to ``pool_type``.

    A fresh dict is stored as ``pool_type._message_handler`` and filled with
    the pool type's own handler functions, keyed by the message type each one
    serves.  The pool type itself is returned.
    """
    registry = pool_type._message_handler = dict()  # type: ignore
    registry.update(
        {
            MessageType.create_actor: pool_type.create_actor,
            MessageType.destroy_actor: pool_type.destroy_actor,
            MessageType.has_actor: pool_type.has_actor,
            MessageType.actor_ref: pool_type.actor_ref,
            MessageType.send: pool_type.send,
            MessageType.tell: pool_type.tell,
            MessageType.cancel: pool_type.cancel,
            MessageType.control: pool_type.handle_control_command,
            MessageType.copy_to_buffers: pool_type.handle_copy_to_buffers_message,
            MessageType.copy_to_fileobjs: pool_type.handle_copy_to_fileobjs_message,
        }
    )
    return pool_type
132
+
133
+
134
class AbstractActorPool(ABC):
    """Base class of actor pools.

    A pool keeps the communication servers it was given, the router used to
    reach other addresses, the actors created in it (``_actors``) and the
    in-flight message tasks (``_process_messages``).  Incoming messages are
    dispatched by ``message.message_type`` through the class-level
    ``_message_handler`` table.
    """

    __slots__ = (
        "process_index",
        "label",
        "external_address",
        "internal_address",
        "env",
        "_servers",
        "_router",
        "_config",
        "_stopped",
        "_actors",
        "_caller",
        "_process_messages",
        "_asyncio_task_timeout_detector_task",
    )

    _message_handler: dict[MessageType, Callable]
    _process_messages: dict[bytes, asyncio.Future | asyncio.Task | None]

    def __init__(
        self,
        process_index: int,
        label: str,
        external_address: str,
        internal_address: str,
        env: dict,
        router: Router,
        config: ActorPoolConfig,
        servers: list[Server],
    ):
        """Store the pool settings and initialise runtime state and metrics."""
        # register local pool for local actor lookup.
        # The pool is weakrefed, so we don't need to unregister it.
        register_local_pool(external_address, self)
        self.process_index = process_index
        self.label = label
        self.external_address = external_address
        self.internal_address = internal_address
        self.env = env
        self._router = router
        self._config = config
        self._servers = servers

        self._stopped = asyncio.Event()

        # states
        # actor id -> actor
        self._actors: dict[bytes, Actor] = dict()
        # message id -> future
        self._process_messages = dict()

        # manage async actor callers
        self._caller = ActorCaller()
        self._asyncio_task_timeout_detector_task = (
            register_asyncio_task_timeout_detector()
        )
        # init metrics
        metric_configs = self._config.get_metric_configs()
        metric_backend = metric_configs.get("backend")
        init_metrics(metric_backend, config=metric_configs.get(metric_backend))

    @property
    def router(self):
        """The router this pool was created with."""
        return self._router

    @abstractmethod
    async def create_actor(self, message: CreateActorMessage) -> ResultMessageType:
        """
        Create an actor.

        Parameters
        ----------
        message: CreateActorMessage
            message to create an actor.

        Returns
        -------
        result_message
            result or error message.
        """

    @abstractmethod
    async def has_actor(self, message: HasActorMessage) -> ResultMessage:
        """
        Check if an actor exists or not.

        Parameters
        ----------
        message: HasActorMessage
            message

        Returns
        -------
        result_message
            result message contains if an actor exists or not.
        """

    @abstractmethod
    async def destroy_actor(self, message: DestroyActorMessage) -> ResultMessageType:
        """
        Destroy an actor.

        Parameters
        ----------
        message: DestroyActorMessage
            message to destroy an actor.

        Returns
        -------
        result_message
            result or error message.
        """

    @abstractmethod
    async def actor_ref(self, message: ActorRefMessage) -> ResultMessageType:
        """
        Get an actor's ref.

        Parameters
        ----------
        message: ActorRefMessage
            message to get an actor's ref.

        Returns
        -------
        result_message
            result or error message.
        """

    @abstractmethod
    async def send(self, message: SendMessage) -> ResultMessageType:
        """
        Send a message to some actor.

        Parameters
        ----------
        message: SendMessage
            Message to send.

        Returns
        -------
        result_message
            result or error message.
        """

    @abstractmethod
    async def tell(self, message: TellMessage) -> ResultMessageType:
        """
        Tell message to some actor.

        Parameters
        ----------
        message: TellMessage
            Message to tell.

        Returns
        -------
        result_message
            result or error message.
        """

    @abstractmethod
    async def cancel(self, message: CancelMessage) -> ResultMessageType:
        """
        Cancel message that sent

        Parameters
        ----------
        message: CancelMessage
            Cancel message.

        Returns
        -------
        result_message
            result or error message
        """

    def _sync_pool_config(self, actor_pool_config: ActorPoolConfig):
        """Adopt a new pool config and refresh this pool's router mapping."""
        self._config = actor_pool_config
        # remove router from global one
        global_router = Router.get_instance()
        global_router.remove_router(self._router)  # type: ignore
        # update router
        self._router.set_mapping(actor_pool_config.external_to_internal_address_map)
        # update global router
        global_router.add_router(self._router)  # type: ignore

    async def handle_control_command(
        self, message: ControlMessage
    ) -> ResultMessageType:
        """
        Handle control command.

        Parameters
        ----------
        message: ControlMessage
            Control message.

        Returns
        -------
        result_message
            result or error message.
        """
        with _ErrorProcessor(
            self.external_address, message.message_id, protocol=message.protocol
        ) as processor:
            # ``True`` acknowledges commands that carry no payload back
            content: bool | str | ActorPoolConfig = True
            if message.control_message_type == ControlMessageType.stop:
                await self.stop()
            elif message.control_message_type == ControlMessageType.sync_config:
                self._sync_pool_config(message.content)
            elif message.control_message_type == ControlMessageType.get_config:
                if message.content == "main_pool_address":
                    # first process index, first external address
                    main_process_index = self._config.get_process_indexes()[0]
                    content = self._config.get_pool_config(main_process_index)[
                        "external_address"
                    ][0]
                else:
                    content = self._config
            else:  # pragma: no cover
                raise TypeError(
                    f"Unable to handle control message "
                    f"with type {message.control_message_type}"
                )
            processor.result = ResultMessage(
                message.message_id, content, protocol=message.protocol
            )

        return processor.result

    async def _run_coro(self, message_id: bytes, coro: Coroutine):
        """Await ``coro`` while the current task is registered under ``message_id``.

        The registration lets ``cancel`` find the running task; it is removed
        again whether the coroutine succeeds, fails or is cancelled.
        """
        self._process_messages[message_id] = asyncio.tasks.current_task()
        try:
            return await coro
        finally:
            self._process_messages.pop(message_id, None)

    async def _send_channel(
        self, result: _MessageBase, channel: Channel, resend_failure: bool = True
    ):
        """Send ``result`` over ``channel``.

        A closed or reset channel is ignored once the pool has been stopped
        and re-raised otherwise.  Any other failure is logged; unless
        ``resend_failure`` is false, a ``SendMessageFailed`` error message
        describing the failure is then sent instead (without a further retry).
        """
        try:
            await channel.send(result)
        except (ChannelClosed, ConnectionResetError):
            if not self._stopped.is_set():
                raise
        except Exception as ex:
            logger.exception(
                "Error when sending message %s from %s to %s",
                result.message_id.hex(),
                channel.local_address,
                channel.dest_address,
            )
            if not resend_failure:  # pragma: no cover
                raise

            with _ErrorProcessor(
                self.external_address, result.message_id, result.protocol
            ) as processor:
                error_msg = (
                    f"Error when sending message {result.message_id.hex()}. "
                    f"Caused by {ex!r}. "
                )
                if isinstance(result, ErrorMessage):
                    # ``format_tb`` entries already end with a newline, so
                    # they are concatenated as-is.
                    format_tb = "".join(traceback.format_tb(result.traceback))
                    error_msg += (
                        f"\nOriginal error: {result.error!r}"
                        f"\nTraceback: \n{format_tb}"
                    )
                else:
                    error_msg += "See server logs for more details"
                raise SendMessageFailed(error_msg) from None
            await self._send_channel(processor.result, channel, resend_failure=False)

    async def process_message(self, message: _MessageBase, channel: Channel):
        """Dispatch ``message`` to its handler and send the outcome back."""
        handler = self._message_handler[message.message_type]
        with _ErrorProcessor(
            self.external_address, message.message_id, message.protocol
        ) as processor:
            # use `%.500` to avoid print too long messages
            with debug_async_timeout(
                "process_message_timeout",
                "Process message %.500s of channel %s timeout.",
                message,
                channel,
            ):
                processor.result = await self._run_coro(
                    message.message_id, handler(self, message)
                )

        await self._send_channel(processor.result, channel)

    async def call(self, dest_address: str, message: _MessageBase) -> ResultMessageType:
        """Send ``message`` to ``dest_address`` through this pool's caller."""
        return await self._caller.call(self._router, dest_address, message)  # type: ignore

    @staticmethod
    def _parse_config(config: dict, kw: dict) -> dict:
        """Move pool settings from ``config`` into constructor kwargs ``kw``.

        ``config`` is consumed: ``actor_pool_config`` and ``process_index``
        are popped from it, and anything left over raises ``TypeError``.
        ``kw`` is filled in place and also returned.
        """
        actor_pool_config: ActorPoolConfig = config.pop("actor_pool_config")
        kw["config"] = actor_pool_config
        kw["process_index"] = process_index = config.pop("process_index")
        curr_pool_config = actor_pool_config.get_pool_config(process_index)
        kw["label"] = curr_pool_config["label"]
        external_addresses = curr_pool_config["external_address"]
        kw["external_address"] = external_addresses[0]
        kw["internal_address"] = curr_pool_config["internal_address"]
        kw["router"] = Router(
            external_addresses,
            gen_local_address(process_index),
            actor_pool_config.external_to_internal_address_map,
            comm_config=actor_pool_config.get_comm_config(),
        )
        kw["env"] = curr_pool_config["env"]

        if config:  # pragma: no cover
            raise TypeError(
                f"Creating pool got unexpected arguments: {','.join(config)}"
            )

        return kw

    @classmethod
    @abstractmethod
    async def create(cls, config: dict) -> "AbstractActorPool":
        """
        Create an actor pool.

        Parameters
        ----------
        config: dict
            configurations.

        Returns
        -------
        actor_pool:
            Actor pool.
        """

    async def start(self):
        """Start every server of the pool.

        Raises
        ------
        RuntimeError
            if the pool has already been stopped.
        """
        if self._stopped.is_set():
            raise RuntimeError("pool has been stopped, cannot start again")
        start_servers = [server.start() for server in self._servers]
        await asyncio.gather(*start_servers)

    async def join(self, timeout: float | None = None):
        """Wait until the pool is stopped, giving up silently after ``timeout``."""
        wait_stopped = asyncio.create_task(self._stopped.wait())

        try:
            await asyncio.wait_for(wait_stopped, timeout=timeout)
        except (futures.TimeoutError, asyncio.TimeoutError):  # pragma: no cover
            wait_stopped.cancel()

    async def stop(self):
        """Stop servers and the caller; the stopped flag is set even on failure."""
        try:
            # clean global router
            router = Router.get_instance()
            if router is not None:
                router.remove_router(self._router)
            stop_tasks = []
            # stop all servers
            stop_tasks.extend([server.stop() for server in self._servers])
            # stop all clients
            stop_tasks.append(self._caller.stop())
            await asyncio.gather(*stop_tasks)

            self._servers = []
            if self._asyncio_task_timeout_detector_task:  # pragma: no cover
                self._asyncio_task_timeout_detector_task.cancel()
        finally:
            self._stopped.set()

    async def handle_copy_to_buffers_message(self, message) -> ResultMessage:
        """Write each ``(addr, uid, start, length, data)`` item into its buffer."""
        for addr, uid, start, _len, data in message.content:
            buffer = BufferRef.get_buffer(BufferRef(addr, uid))
            buffer[start : start + _len] = data
        # answer with the request's protocol, like the other handlers do
        return ResultMessage(
            message_id=message.message_id, result=True, protocol=message.protocol
        )

    async def handle_copy_to_fileobjs_message(self, message) -> ResultMessage:
        """Write each ``(addr, uid, data)`` item into its local file object."""
        for addr, uid, data in message.content:
            file_obj = FileObjectRef.get_local_file_object(FileObjectRef(addr, uid))
            await file_obj.write(data)
        # answer with the request's protocol, like the other handlers do
        return ResultMessage(
            message_id=message.message_id, result=True, protocol=message.protocol
        )

    @property
    def stopped(self) -> bool:
        """Whether the stopped flag has been set."""
        return self._stopped.is_set()

    async def _handle_ucx_meta_message(
        self, message: _MessageBase, channel: Channel
    ) -> bool:
        """Handle a ``switch_to_copy_to`` control message on a UCX channel.

        Returns ``True`` when the message was such a request (buffers are
        received from the channel and a reply is scheduled), ``False`` when
        the message must go through the regular dispatch.
        """
        if (
            isinstance(message, ControlMessage)
            and message.message_type == MessageType.control
            and message.control_message_type == ControlMessageType.switch_to_copy_to
            and isinstance(channel, UCXChannel)
        ):
            with _ErrorProcessor(
                self.external_address, message.message_id, message.protocol
            ) as processor:
                # use `%.500` to avoid print too long messages
                with debug_async_timeout(
                    "process_message_timeout",
                    "Process message %.500s of channel %s timeout.",
                    message,
                    channel,
                ):
                    buffers = [
                        BufferRef.get_buffer(BufferRef(addr, uid))
                        for addr, uid in message.content
                    ]
                    await channel.recv_buffers(buffers)
                processor.result = ResultMessage(
                    message_id=message.message_id,
                    result=True,
                    protocol=message.protocol,
                )
            # NOTE(review): the task object is not kept anywhere; confirm the
            # reply cannot be garbage-collected before it completes.
            asyncio.create_task(self._send_channel(processor.result, channel))
            return True
        return False

    async def on_new_channel(self, channel: Channel):
        """Receive messages from ``channel`` until EOF or the pool stops.

        Each message is processed in its own task so a slow handler does not
        block reading the next message.
        """
        while not self._stopped.is_set():
            try:
                message = await channel.recv()
            except EOFError:
                # no data to read, check channel
                try:
                    await channel.close()
                except (ConnectionError, EOFError):
                    # close failed, ignore
                    pass
                return
            if await self._handle_ucx_meta_message(message, channel):
                continue
            asyncio.create_task(self.process_message(message, channel))
            # delete to release the reference of message
            del message
            # yield to the event loop so the new task can start
            await asyncio.sleep(0)

    async def __aenter__(self):
        await self.start()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.stop()
575
+
576
+
577
class ActorPoolBase(AbstractActorPool, metaclass=ABCMeta):
    """Concrete message handlers and the pool construction helpers.

    Every handler wraps its work in ``_ErrorProcessor`` so that a failure is
    returned to the sender as an error message instead of being raised.
    """

    __slots__ = ()

    @implements(AbstractActorPool.create_actor)
    async def create_actor(self, message: CreateActorMessage) -> ResultMessageType:
        with _ErrorProcessor(
            self.external_address, message.message_id, message.protocol
        ) as processor:
            actor_id = message.actor_id
            if actor_id in self._actors:
                raise ActorAlreadyExist(
                    f"Actor {actor_id!r} already exist, cannot create"
                )

            actor = message.actor_cls(*message.args, **message.kwargs)
            actor.uid = actor_id
            actor.address = address = self.external_address
            # the actor is registered before ``__post_create__`` runs
            self._actors[actor_id] = actor
            await self._run_coro(message.message_id, actor.__post_create__())

            result = ActorRef(address, actor_id)
            # ensemble result message
            processor.result = ResultMessage(
                message.message_id, result, protocol=message.protocol
            )
        return processor.result

    @implements(AbstractActorPool.has_actor)
    async def has_actor(self, message: HasActorMessage) -> ResultMessage:
        result = ResultMessage(
            message.message_id,
            message.actor_ref.uid in self._actors,
            protocol=message.protocol,
        )
        return result

    @implements(AbstractActorPool.destroy_actor)
    async def destroy_actor(self, message: DestroyActorMessage) -> ResultMessageType:
        with _ErrorProcessor(
            self.external_address, message.message_id, message.protocol
        ) as processor:
            actor_id = message.actor_ref.uid
            # same membership check as actor_ref/send/tell; avoids raising
            # from inside an ``except KeyError`` block
            if actor_id not in self._actors:
                raise ActorNotExist(f"Actor {actor_id} does not exist")
            actor = self._actors[actor_id]
            await self._run_coro(message.message_id, actor.__pre_destroy__())
            # only forget the actor once ``__pre_destroy__`` has finished
            del self._actors[actor_id]

            processor.result = ResultMessage(
                message.message_id, actor_id, protocol=message.protocol
            )
        return processor.result

    @implements(AbstractActorPool.actor_ref)
    async def actor_ref(self, message: ActorRefMessage) -> ResultMessageType:
        with _ErrorProcessor(
            self.external_address, message.message_id, message.protocol
        ) as processor:
            actor_id = message.actor_ref.uid
            if actor_id not in self._actors:
                raise ActorNotExist(f"Actor {actor_id} does not exist")
            result = ResultMessage(
                message.message_id,
                ActorRef(self.external_address, actor_id),
                protocol=message.protocol,
            )
            processor.result = result
        return processor.result

    @implements(AbstractActorPool.send)
    async def send(self, message: SendMessage) -> ResultMessageType:
        with _ErrorProcessor(
            self.external_address, message.message_id, message.protocol
        ) as processor, record_message_trace(message):
            actor_id = message.actor_ref.uid
            if actor_id not in self._actors:
                raise ActorNotExist(f"Actor {actor_id} does not exist")
            coro = self._actors[actor_id].__on_receive__(message.content)
            result = await self._run_coro(message.message_id, coro)
            processor.result = ResultMessage(
                message.message_id,
                result,
                protocol=message.protocol,
                profiling_context=message.profiling_context,
            )
        return processor.result

    @implements(AbstractActorPool.tell)
    async def tell(self, message: TellMessage) -> ResultMessageType:
        with _ErrorProcessor(
            self.external_address, message.message_id, message.protocol
        ) as processor:
            actor_id = message.actor_ref.uid
            if actor_id not in self._actors:  # pragma: no cover
                raise ActorNotExist(f"Actor {actor_id} does not exist")
            call = self._actors[actor_id].__on_receive__(message.content)
            # asynchronously run, tell does not care about result
            asyncio.create_task(call)
            # yield once so the scheduled call gets a chance to start
            await asyncio.sleep(0)
            processor.result = ResultMessage(
                message.message_id,
                None,
                protocol=message.protocol,
                profiling_context=message.profiling_context,
            )
        return processor.result

    @implements(AbstractActorPool.cancel)
    async def cancel(self, message: CancelMessage) -> ResultMessageType:
        with _ErrorProcessor(
            self.external_address, message.message_id, message.protocol
        ) as processor:
            # entries are registered by ``_run_coro`` while a message runs
            future = self._process_messages.get(message.cancel_message_id)
            if future is None or future.done():  # pragma: no cover
                raise CannotCancelTask(
                    "Task not exists, maybe it is done or cancelled already"
                )
            future.cancel()
            processor.result = ResultMessage(
                message.message_id, True, protocol=message.protocol
            )
        return processor.result

    @staticmethod
    def _set_global_router(router: Router):
        """Add ``router`` to the process-wide router, creating it if needed."""
        # be cautious about setting global router
        # for instance, multiple main pool may be created in the same process

        # get default router or create an empty one
        default_router = Router.get_instance_or_empty()
        Router.set_instance(default_router)
        # append this router to global
        default_router.add_router(router)

    @staticmethod
    def _update_stored_addresses(
        servers: list[Server],
        raw_addresses: list[str],
        actor_pool_config: ActorPoolConfig,
        kw: dict,
    ):
        """Sync configured addresses with the addresses the servers report.

        ``servers`` and ``raw_addresses`` are paired by position.  When the
        external addresses reported by the servers differ from the configured
        ones, the pool config and the address entries of ``kw`` are updated.
        ``kw["servers"]`` and ``kw["router"]`` are always (re)assigned.
        """
        process_index = kw["process_index"]
        curr_pool_config = actor_pool_config.get_pool_config(process_index)
        external_addresses = curr_pool_config["external_address"]
        external_address_set = set(external_addresses)

        kw["servers"] = servers

        new_external_addresses = [
            server.address
            for server, raw_address in zip(servers, raw_addresses)
            if raw_address in external_address_set
        ]

        if external_address_set != set(new_external_addresses):
            external_addresses = new_external_addresses
            actor_pool_config.reset_pool_external_address(
                process_index, external_addresses
            )
            # re-read the addresses as stored by the config after the reset
            external_addresses = curr_pool_config["external_address"]

            logger.debug(
                "External address of process index %s updated to %s",
                process_index,
                external_addresses[0],
            )
            if kw["internal_address"] == kw["external_address"]:
                # internal address may be the same as external address in Windows
                kw["internal_address"] = external_addresses[0]
            kw["external_address"] = external_addresses[0]

        kw["router"] = Router(
            external_addresses,
            gen_local_address(process_index),
            actor_pool_config.external_to_internal_address_map,
            comm_config=actor_pool_config.get_comm_config(),
        )

    @classmethod
    async def _create_servers(
        cls, addresses: list[str], channel_handler: Callable, config: dict
    ):
        """Create one server per address concurrently.

        Returns the servers in the same order as ``addresses``, which must
        not contain duplicates.
        """
        assert len(set(addresses)) == len(addresses)
        # create servers
        create_server_tasks = []
        for addr in addresses:
            server_type = get_server_type(addr)
            extra_config = server_type.parse_config(config)
            server_config = dict(address=addr, handle_channel=channel_handler)
            server_config.update(extra_config)
            task = asyncio.create_task(server_type.create(server_config))
            create_server_tasks.append(task)

        # ``gather`` returns the results in the order the tasks were passed
        return list(await asyncio.gather(*create_server_tasks))

    @classmethod
    @implements(AbstractActorPool.create)
    async def create(cls, config: dict) -> "ActorPoolType":
        config = config.copy()
        kw: dict[str, Any] = dict()
        cls._parse_config(config, kw)
        process_index: int = kw["process_index"]
        actor_pool_config = kw["config"]  # type: ActorPoolConfig
        cur_pool_config = actor_pool_config.get_pool_config(process_index)
        external_addresses = cur_pool_config["external_address"]
        internal_address = kw["internal_address"]

        # import predefined modules
        modules = cur_pool_config["modules"] or []
        for mod in modules:
            __import__(mod, globals(), locals(), [])
        # make sure all lazy imports loaded
        with _disable_log_temporally():
            TypeDispatcher.reload_all_lazy_handlers()

        def handle_channel(channel):
            # ``pool`` is bound further down, before any channel can arrive
            return pool.on_new_channel(channel)

        # create servers
        server_addresses = list(external_addresses)
        if internal_address:
            server_addresses.append(internal_address)
        server_addresses.append(gen_local_address(process_index))
        # drop duplicates; sorting keeps the order deterministic
        server_addresses = sorted(set(server_addresses))
        servers = await cls._create_servers(
            server_addresses, handle_channel, actor_pool_config.get_comm_config()
        )
        cls._update_stored_addresses(servers, server_addresses, actor_pool_config, kw)

        # set default router
        # actor context would be able to use exact client
        cls._set_global_router(kw["router"])

        # create pool
        pool = cls(**kw)
        return pool  # type: ignore
815
+
816
+
817
# Handle type used for sub pool processes.
SubProcessHandle = multiprocessing.Process

# Type variables for annotating pool factories: any actor pool, and any
# main actor pool (forward reference, resolved by name).
ActorPoolType = TypeVar("ActorPoolType", bound=AbstractActorPool)
MainActorPoolType = TypeVar("MainActorPoolType", bound="MainActorPoolBase")
820
+
821
+
822
+ class SubActorPoolBase(ActorPoolBase):
823
+ __slots__ = ("_main_address", "_watch_main_pool_task")
824
+ _watch_main_pool_task: Optional[asyncio.Task]
825
+
826
def __init__(
    self,
    process_index: int,
    label: str,
    external_address: str,
    internal_address: str,
    env: dict,
    router: Router,
    config: ActorPoolConfig,
    servers: list[Server],
    main_address: str,
    main_pool_pid: Optional[int],
):
    """Initialise the sub pool and optionally start watching the main pool.

    Besides the base pool arguments, ``main_address`` is remembered as the
    address of the main pool.  When ``main_pool_pid`` is truthy a background
    task is started that watches that process; otherwise no watcher exists.
    """
    base_args = (
        process_index,
        label,
        external_address,
        internal_address,
        env,
        router,
        config,
        servers,
    )
    super().__init__(*base_args)
    self._main_address = main_address
    self._watch_main_pool_task = (
        asyncio.create_task(self._watch_main_pool(main_pool_pid))
        if main_pool_pid
        else None
    )
856
+
857
async def _watch_main_pool(self, main_pool_pid: int):
    """Poll the main pool process and stop this pool once it is gone.

    The process status is queried in a worker thread roughly every 0.1
    seconds while the pool is running.  When the process cannot be found
    (including when it has already disappeared before the first poll) or
    the watcher is cancelled, polling ends and the pool is stopped unless it
    already is.
    """
    try:
        main_process = psutil.Process(main_pool_pid)
    except (psutil.NoSuchProcess, ProcessLookupError):
        # main pool died before we could start watching it
        main_process = None

    while main_process is not None and not self.stopped:
        try:
            await asyncio.to_thread(main_process.status)
            await asyncio.sleep(0.1)
        except (psutil.NoSuchProcess, ProcessLookupError, asyncio.CancelledError):
            # main pool died
            break

    if not self.stopped:
        await self.stop()
870
+
871
async def notify_main_pool_to_destroy(
    self, message: DestroyActorMessage
):  # pragma: no cover
    """Forward the destroy ``message`` to the main pool; the reply is dropped."""
    main_address = self._main_address
    await self.call(main_address, message)
875
+
876
+ async def notify_main_pool_to_create(self, message: CreateActorMessage):
877
+ reg_message = ControlMessage(
878
+ new_message_id(),
879
+ self.external_address,
880
+ ControlMessageType.add_sub_pool_actor,
881
+ (self.external_address, message.allocate_strategy, message),
882
+ )
883
+ await self.call(self._main_address, reg_message)
884
+
885
+ @implements(AbstractActorPool.create_actor)
886
+ async def create_actor(self, message: CreateActorMessage) -> ResultMessageType:
887
+ result = await super().create_actor(message)
888
+ if not message.from_main:
889
+ await self.notify_main_pool_to_create(message)
890
+ return result
891
+
892
+ @implements(AbstractActorPool.actor_ref)
893
+ async def actor_ref(self, message: ActorRefMessage) -> ResultMessageType:
894
+ result = await super().actor_ref(message)
895
+ if isinstance(result, ErrorMessage):
896
+ # need a new message id to call main actor
897
+ main_message = ActorRefMessage(
898
+ new_message_id(),
899
+ create_actor_ref(self._main_address, message.actor_ref.uid),
900
+ )
901
+ result = await self.call(self._main_address, main_message)
902
+ # rewrite to message_id of the original request
903
+ result.message_id = message.message_id
904
+ return result
905
+
906
+ @implements(AbstractActorPool.destroy_actor)
907
+ async def destroy_actor(self, message: DestroyActorMessage) -> ResultMessageType:
908
+ result = await super().destroy_actor(message)
909
+ if isinstance(result, ResultMessage) and not message.from_main:
910
+ # sync back to main actor pool
911
+ await self.notify_main_pool_to_destroy(message)
912
+ return result
913
+
914
+ @implements(AbstractActorPool.handle_control_command)
915
+ async def handle_control_command(
916
+ self, message: ControlMessage
917
+ ) -> ResultMessageType:
918
+ if message.control_message_type == ControlMessageType.sync_config:
919
+ self._main_address = message.address
920
+ return await super().handle_control_command(message)
921
+
922
+ @staticmethod
923
+ def _parse_config(config: dict, kw: dict) -> dict:
924
+ main_pool_pid = config.pop("main_pool_pid", None)
925
+ kw = AbstractActorPool._parse_config(config, kw)
926
+ pool_config: ActorPoolConfig = kw["config"]
927
+ main_process_index = pool_config.get_process_indexes()[0]
928
+ kw["main_address"] = pool_config.get_pool_config(main_process_index)[
929
+ "external_address"
930
+ ][0]
931
+ kw["main_pool_pid"] = main_pool_pid
932
+ return kw
933
+
934
+ async def stop(self):
935
+ await super().stop()
936
+ if self._watch_main_pool_task:
937
+ self._watch_main_pool_task.cancel()
938
+ await self._watch_main_pool_task
939
+
940
+
941
class MainActorPoolBase(ActorPoolBase):
    """Actor pool acting as the main pool in front of a set of sub pools.

    The main pool records which actors were allocated on which pool address,
    routes messages for actors it does not host itself to the owning sub
    pool, and can monitor sub pool processes and recover them.
    """

    __slots__ = (
        "_subprocess_start_method",
        "_allocated_actors",
        "sub_actor_pool_manager",
        "_auto_recover",
        "_monitor_task",
        "_on_process_down",
        "_on_process_recover",
        "_recover_events",
        "_allocation_lock",
        "sub_processes",
    )

    def __init__(
        self,
        process_index: int,
        label: str,
        external_address: str,
        internal_address: str,
        env: dict,
        router: Router,
        config: ActorPoolConfig,
        servers: list[Server],
        subprocess_start_method: str | None = None,
        auto_recover: str | bool = "actor",
        on_process_down: Callable[[MainActorPoolType, str], None] | None = None,
        on_process_recover: Callable[[MainActorPoolType, str], None] | None = None,
    ):
        """Initialise the main pool.

        The first eight arguments are forwarded unchanged to the base class.

        Parameters
        ----------
        subprocess_start_method : str | None
            Stored on the instance; not used elsewhere in this class.
        auto_recover : str | bool
            Recovery mode. A falsy value disables recovery; ``False`` and
            ``"process"`` make a lost sub pool forget its allocated actors.
        on_process_down : callable | None
            Called as ``callback(pool, address)`` when a sub pool is found dead.
        on_process_recover : callable | None
            Called as ``callback(pool, address)`` after a sub pool recovered.
        """
        super().__init__(
            process_index,
            label,
            external_address,
            internal_address,
            env,
            router,
            config,
            servers,
        )
        self._subprocess_start_method = subprocess_start_method

        # auto recovering
        self._auto_recover = auto_recover
        self._monitor_task: Optional[asyncio.Task] = None
        self._on_process_down = on_process_down
        self._on_process_recover = on_process_recover
        # address -> event set by the monitor loop for `wait_pool_recovered`
        self._recover_events: dict[str, asyncio.Event] = dict()

        # states
        # address -> {actor ref: (allocate strategy, create message)}
        self._allocated_actors: allocated_type = {
            addr: dict() for addr in self._config.get_external_addresses()
        }
        self._allocation_lock = threading.Lock()

        # external address -> process handle of the sub pool
        self.sub_processes: dict[str, SubProcessHandle] = dict()

    # Counter shared by every generator returned from `process_index_gen`.
    _process_index_gen = itertools.count()

    @classmethod
    def process_index_gen(cls, address):
        """Yield process indexes that embed the current pid.

        ``address`` is not used by this implementation. Each value is
        ``(pid << 16) + idx`` where ``idx`` comes from a class-level counter.
        """
        # make sure different processes do not share process indexes
        pid = os.getpid()
        for idx in cls._process_index_gen:
            # Parenthesised on purpose: `+` binds tighter than `<<`, so
            # `pid << 16 + idx` would shift by `16 + idx` bits instead.
            yield (pid << 16) + idx

    @property
    def _sub_processes(self):
        """Alias of ``sub_processes``."""
        return self.sub_processes

    @implements(AbstractActorPool.create_actor)
    async def create_actor(self, message: CreateActorMessage) -> ResultMessageType:
        """Allocate an address for the actor and create it there.

        The address is chosen by the message's allocate strategy. While the
        creation is in flight a ``None`` placeholder occupies the address in
        ``_allocated_actors``; it is removed again whether creation succeeds
        or raises. On success the returned actor ref is recorded as allocated.
        """
        with _ErrorProcessor(
            address=self.external_address,
            message_id=message.message_id,
            protocol=message.protocol,
        ) as processor:
            allocate_strategy = message.allocate_strategy
            with self._allocation_lock:
                # get allocated address according to corresponding strategy
                address = allocate_strategy.get_allocated_address(
                    self._config, self._allocated_actors
                )
                # set placeholder to make sure this label is occupied
                self._allocated_actors[address][None] = (allocate_strategy, message)
            try:
                if address == self.external_address:
                    # creating actor on main actor pool
                    result = await super().create_actor(message)
                    if isinstance(result, ResultMessage):
                        self._allocated_actors[self.external_address][
                            result.result
                        ] = (
                            allocate_strategy,
                            message,
                        )
                    processor.result = result
                else:
                    # creating actor on sub actor pool
                    # rewrite allocate strategy to AddressSpecified
                    new_allocate_strategy = AddressSpecified(address)
                    new_create_actor_message = CreateActorMessage(
                        message.message_id,
                        message.actor_cls,
                        message.actor_id,
                        message.args,
                        message.kwargs,
                        allocate_strategy=new_allocate_strategy,
                        from_main=True,
                        protocol=message.protocol,
                        message_trace=message.message_trace,
                    )
                    result = await self.call(address, new_create_actor_message)
                    if isinstance(result, ResultMessage):
                        self._allocated_actors[address][result.result] = (
                            allocate_strategy,
                            new_create_actor_message,
                        )
                    processor.result = result
            finally:
                # revert placeholder, also when the creation raised, so a
                # failed attempt does not keep the address marked as occupied
                self._allocated_actors[address].pop(None, None)

        return processor.result

    @implements(AbstractActorPool.has_actor)
    async def has_actor(self, message: HasActorMessage) -> ResultMessage:
        """Return a result message telling whether the actor is allocated anywhere."""
        actor_ref = message.actor_ref
        # lookup allocated
        for address, item in self._allocated_actors.items():
            ref = create_actor_ref(address, to_binary(actor_ref.uid))
            if ref in item:
                return ResultMessage(
                    message.message_id, True, protocol=message.protocol
                )

        return ResultMessage(message.message_id, False, protocol=message.protocol)

    @implements(AbstractActorPool.destroy_actor)
    async def destroy_actor(self, message: DestroyActorMessage) -> ResultMessageType:
        """Destroy an actor on whichever pool hosts it.

        The actor is first resolved through ``actor_ref``; a non-result
        answer is returned as is. Actors on this pool are destroyed locally,
        others are removed from the allocation record and the destroy
        request is forwarded to the owning address with ``from_main=True``.
        """
        actor_ref_message = ActorRefMessage(
            message.message_id, message.actor_ref, protocol=message.protocol
        )
        result = await self.actor_ref(actor_ref_message)
        if not isinstance(result, ResultMessage):
            return result
        real_actor_ref = result.result
        if real_actor_ref.address == self.external_address:
            result = await super().destroy_actor(message)
            if result.message_type == MessageType.error:
                return result
            del self._allocated_actors[self.external_address][real_actor_ref]
            return ResultMessage(
                message.message_id, real_actor_ref.uid, protocol=message.protocol
            )
        # remove allocated actor ref
        self._allocated_actors[real_actor_ref.address].pop(real_actor_ref, None)
        new_destroy_message = DestroyActorMessage(
            message.message_id,
            real_actor_ref,
            from_main=True,
            protocol=message.protocol,
        )
        return await self.call(real_actor_ref.address, new_destroy_message)

    @implements(AbstractActorPool.send)
    async def send(self, message: SendMessage) -> ResultMessageType:
        """Handle a send locally, or forward it to the pool hosting the actor."""
        if message.actor_ref.uid in self._actors:
            return await super().send(message)
        actor_ref_message = ActorRefMessage(
            message.message_id, message.actor_ref, protocol=message.protocol
        )
        result = await self.actor_ref(actor_ref_message)
        if not isinstance(result, ResultMessage):
            return result
        actor_ref = result.result
        # same message id and content, re-addressed to the resolved ref
        new_send_message = SendMessage(
            message.message_id,
            actor_ref,
            message.content,
            protocol=message.protocol,
            message_trace=message.message_trace,
        )
        return await self.call(actor_ref.address, new_send_message)

    @implements(AbstractActorPool.tell)
    async def tell(self, message: TellMessage) -> ResultMessageType:
        """Handle a tell locally, or forward it to the pool hosting the actor."""
        if message.actor_ref.uid in self._actors:
            return await super().tell(message)
        actor_ref_message = ActorRefMessage(
            message.message_id, message.actor_ref, protocol=message.protocol
        )
        result = await self.actor_ref(actor_ref_message)
        if not isinstance(result, ResultMessage):
            return result
        actor_ref = result.result
        # same message id and content, re-addressed to the resolved ref
        new_tell_message = TellMessage(
            message.message_id,
            actor_ref,
            message.content,
            protocol=message.protocol,
            message_trace=message.message_trace,
        )
        return await self.call(actor_ref.address, new_tell_message)

    @implements(AbstractActorPool.actor_ref)
    async def actor_ref(self, message: ActorRefMessage) -> ResultMessageType:
        """Resolve an actor reference to the address the actor lives on.

        Note that the uid of ``message.actor_ref`` is converted with
        ``to_binary`` in place. When the actor is neither hosted here nor
        recorded as allocated, the ``ActorNotExist`` error is turned into
        the returned message by ``_ErrorProcessor``.
        """
        actor_ref = message.actor_ref
        actor_ref.uid = to_binary(actor_ref.uid)
        if actor_ref.address == self.external_address and actor_ref.uid in self._actors:
            return ResultMessage(
                message.message_id, actor_ref, protocol=message.protocol
            )

        # lookup allocated
        for address, item in self._allocated_actors.items():
            ref = create_actor_ref(address, actor_ref.uid)
            if ref in item:
                return ResultMessage(message.message_id, ref, protocol=message.protocol)

        with _ErrorProcessor(
            self.external_address, message.message_id, protocol=message.protocol
        ) as processor:
            raise ActorNotExist(
                f"Actor {actor_ref.uid} does not exist in {actor_ref.address}"
            )

        return processor.result

    @implements(AbstractActorPool.cancel)
    async def cancel(self, message: CancelMessage) -> ResultMessageType:
        """Cancel locally when addressed to this pool, else forward the message."""
        if message.address == self.external_address:
            # local message
            return await super().cancel(message)
        # redirect to sub pool
        return await self.call(message.address, message)

    @implements(AbstractActorPool.handle_control_command)
    async def handle_control_command(
        self, message: ControlMessage
    ) -> ResultMessageType:
        """Dispatch a control message.

        Messages addressed to this pool are handled by the base class;
        ``sync_config`` is additionally sent to every sub pool. For other
        addresses, ``stop``, ``wait_pool_recovered`` and
        ``add_sub_pool_actor`` are handled here and everything else is
        forwarded to the target address.
        """
        with _ErrorProcessor(
            self.external_address, message.message_id, message.protocol
        ) as processor:
            if message.address == self.external_address:
                if message.control_message_type == ControlMessageType.sync_config:
                    # sync config, need to notify all sub pools
                    tasks = []
                    for addr in self.sub_processes:
                        control_message = ControlMessage(
                            new_message_id(),
                            message.address,
                            message.control_message_type,
                            message.content,
                            protocol=message.protocol,
                            message_trace=message.message_trace,
                        )
                        tasks.append(
                            asyncio.create_task(self.call(addr, control_message))
                        )
                    # call super
                    task = asyncio.create_task(super().handle_control_command(message))
                    tasks.append(task)
                    await asyncio.gather(*tasks)
                    # the task is already finished; this only fetches its result
                    processor.result = await task
                else:
                    processor.result = await super().handle_control_command(message)
            elif message.control_message_type == ControlMessageType.stop:
                timeout, force = (
                    message.content if message.content is not None else (None, False)
                )
                await self.stop_sub_pool(
                    message.address,
                    self.sub_processes[message.address],
                    timeout=timeout,
                    force=force,
                )
                processor.result = ResultMessage(
                    message.message_id, True, protocol=message.protocol
                )
            elif message.control_message_type == ControlMessageType.wait_pool_recovered:
                if self._auto_recover and message.address not in self._recover_events:
                    self._recover_events[message.address] = asyncio.Event()

                # the event is set by `monitor_sub_pools`
                event = self._recover_events.get(message.address, None)
                if event is not None:
                    await event.wait()
                processor.result = ResultMessage(
                    message.message_id, True, protocol=message.protocol
                )
            elif message.control_message_type == ControlMessageType.add_sub_pool_actor:
                # a sub pool reports an actor it created on its own
                address, allocate_strategy, create_message = message.content
                create_message.from_main = True
                ref = create_actor_ref(address, to_binary(create_message.actor_id))
                self._allocated_actors[address][ref] = (
                    allocate_strategy,
                    create_message,
                )
                processor.result = ResultMessage(
                    message.message_id, True, protocol=message.protocol
                )
            else:
                processor.result = await self.call(message.address, message)
        return processor.result

    @staticmethod
    def _parse_config(config: dict, kw: dict) -> dict:
        """Pop the main pool options from ``config`` into ``kw``, then run the base parser."""
        kw["subprocess_start_method"] = config.pop("start_method", None)
        kw["auto_recover"] = config.pop("auto_recover", "actor")
        kw["on_process_down"] = config.pop("on_process_down", None)
        kw["on_process_recover"] = config.pop("on_process_recover", None)
        kw = AbstractActorPool._parse_config(config, kw)
        return kw

    @classmethod
    @implements(AbstractActorPool.create)
    async def create(cls, config: dict) -> MainActorPoolType:
        """Start the sub pools, create the main pool and attach the sub processes.

        ``config`` is copied before use. When the set of external addresses
        in the pool config changed while the sub pools started, a
        ``sync_config`` control message is handled by the new pool.
        """
        config = config.copy()
        actor_pool_config: ActorPoolConfig = config.get("actor_pool_config")  # type: ignore
        start_method = config.get("start_method", None)
        if "process_index" not in config:
            config["process_index"] = actor_pool_config.get_process_indexes()[0]
        curr_process_index = config.get("process_index")
        old_config_addresses = set(actor_pool_config.get_external_addresses())

        tasks = []
        subpool_process_idxes = []
        # create sub actor pools
        n_sub_pool = actor_pool_config.n_pool - 1
        if n_sub_pool > 0:
            process_indexes = actor_pool_config.get_process_indexes()
            for process_index in process_indexes:
                if process_index == curr_process_index:
                    continue
                create_pool_task = asyncio.create_task(
                    cls.start_sub_pool(actor_pool_config, process_index, start_method)
                )
                # yield to the event loop once before scheduling the next one
                await asyncio.sleep(0)
                # await create_pool_task
                tasks.append(create_pool_task)
                subpool_process_idxes.append(process_index)

        processes, ext_addresses = await cls.wait_sub_pools_ready(tasks)
        if ext_addresses:
            for process_index, ext_address in zip(subpool_process_idxes, ext_addresses):
                actor_pool_config.reset_pool_external_address(
                    process_index, ext_address
                )

        # create main actor pool
        pool: MainActorPoolType = await super().create(config)
        # every address but the first is paired with a sub pool process below
        addresses = actor_pool_config.get_external_addresses()[1:]

        assert len(addresses) == len(
            processes
        ), f"addresses {addresses}, processes {processes}"
        for addr, proc in zip(addresses, processes):
            pool.attach_sub_process(addr, proc)

        new_config_addresses = set(actor_pool_config.get_external_addresses())
        if old_config_addresses != new_config_addresses:
            control_message = ControlMessage(
                message_id=new_message_id(),
                address=pool.external_address,
                control_message_type=ControlMessageType.sync_config,
                content=actor_pool_config,
            )
            await pool.handle_control_command(control_message)

        return pool

    async def start_monitor(self):
        """Create the sub pool monitor task on first use and return it."""
        if self._monitor_task is None:
            self._monitor_task = asyncio.create_task(self.monitor_sub_pools())
        return self._monitor_task

    @implements(AbstractActorPool.stop)
    async def stop(self):
        """Stop monitoring, stop all sub pools, then stop this pool."""
        global_router = Router.get_instance()
        if global_router is not None:
            global_router.remove_router(self._router)

        # turn off auto recover to avoid errors
        self._auto_recover = False
        self._stopped.set()
        if self._monitor_task and not self._monitor_task.done():
            # the monitor loop exits by itself once `_stopped` is set
            await self._monitor_task
            self._monitor_task = None
        await self.stop_sub_pools()
        await super().stop()

    @classmethod
    @abstractmethod
    async def start_sub_pool(
        cls,
        actor_pool_config: ActorPoolConfig,
        process_index: int,
        start_method: str | None = None,
    ):
        """Start a sub actor pool"""

    @classmethod
    @abstractmethod
    async def wait_sub_pools_ready(cls, create_pool_tasks: list[asyncio.Task]):
        """Wait all sub pools ready"""

    def attach_sub_process(self, external_address: str, process: SubProcessHandle):
        """Record ``process`` as the sub pool process for ``external_address``."""
        self.sub_processes[external_address] = process

    async def stop_sub_pools(self):
        """Stop every sub pool that is still alive, concurrently."""
        to_stop_processes: dict[str, SubProcessHandle] = dict()  # type: ignore
        for address, process in self.sub_processes.items():
            if not await self.is_sub_pool_alive(process):
                continue
            to_stop_processes[address] = process

        tasks = []
        for address, process in to_stop_processes.items():
            tasks.append(self.stop_sub_pool(address, process))
        await asyncio.gather(*tasks)

    async def stop_sub_pool(
        self,
        address: str,
        process: SubProcessHandle,
        timeout: float | None = None,
        force: bool = False,
    ):
        """Stop a single sub pool and kill its process.

        Parameters
        ----------
        address : str
            Address the stop message is sent to.
        process : SubProcessHandle
            Process handle passed to ``kill_sub_pool``.
        timeout : float | None
            Seconds to wait for the stop call. ``None`` waits without limit
            and raises the cause of an error reply.
        force : bool
            When true, skip the stop message and kill right away. It is also
            switched on when the stop call times out.
        """
        if force:
            await self.kill_sub_pool(process, force=True)
            return

        stop_message = ControlMessage(
            new_message_id(),
            address,
            ControlMessageType.stop,
            None,
            protocol=DEFAULT_PROTOCOL,
        )
        try:
            if timeout is None:
                message = await self.call(address, stop_message)
                if isinstance(message, ErrorMessage):  # pragma: no cover
                    raise message.as_instanceof_cause()
            else:
                call = asyncio.create_task(self.call(address, stop_message))
                try:
                    await asyncio.wait_for(call, timeout)
                except (futures.TimeoutError, asyncio.TimeoutError):  # pragma: no cover
                    # timeout, just let kill to finish it
                    force = True
        except (ConnectionError, ServerClosed):  # pragma: no cover
            # process dead maybe, ignore it
            pass
        # kill process
        await self.kill_sub_pool(process, force=force)

    @abstractmethod
    async def kill_sub_pool(self, process: SubProcessHandle, force: bool = False):
        """Kill a sub actor pool"""

    @abstractmethod
    async def is_sub_pool_alive(self, process: SubProcessHandle):
        """
        Check whether sub pool process is alive
        Parameters
        ----------
        process : SubProcessHandle
            sub pool process handle
        Returns
        -------
        bool
        """

    @abstractmethod
    def recover_sub_pool(self, address):
        """Recover a sub actor pool"""

    def process_sub_pool_lost(self, address: str):
        """Forget the actors allocated on ``address`` unless actors are recovered."""
        if self._auto_recover in (False, "process"):
            # process down, when not auto_recover
            # or only recover process, remove all created actors
            self._allocated_actors[address] = dict()

    async def monitor_sub_pools(self):
        """Check all sub pools every half second until the pool is stopped.

        A dead sub pool triggers the process-down callback, the allocation
        cleanup and, when auto recovery is on, recovery followed by the
        recover callback. A recover event that was registered for the
        address before the check is popped and set afterwards. Errors other
        than cancellation are logged so that monitoring continues.
        """
        try:
            while not self._stopped.is_set():
                # Copy sub_processes to avoid changes during recover.
                for address, process in list(self.sub_processes.items()):
                    try:
                        recover_events_discovered = address in self._recover_events
                        if not await self.is_sub_pool_alive(
                            process
                        ):  # pragma: no cover
                            if self._on_process_down is not None:
                                self._on_process_down(self, address)
                            self.process_sub_pool_lost(address)
                            if self._auto_recover:
                                await self.recover_sub_pool(address)
                                if self._on_process_recover is not None:
                                    self._on_process_recover(self, address)
                        if recover_events_discovered:
                            event = self._recover_events.pop(address)
                            event.set()
                    except asyncio.CancelledError:
                        raise
                    except RuntimeError as ex:  # pragma: no cover
                        if "cannot schedule new futures" not in str(ex):
                            # to silence log when process exit, otherwise it
                            # will raise "RuntimeError: cannot schedule new futures
                            # after interpreter shutdown".
                            logger.exception("Monitor sub pool %s failed", address)
                    except Exception:
                        # log the exception instead of stop monitoring the
                        # sub pool silently.
                        logger.exception("Monitor sub pool %s failed", address)

                # check every half second
                await asyncio.sleep(0.5)
        except asyncio.CancelledError:  # pragma: no cover
            # cancelled
            return

    @classmethod
    @abstractmethod
    def get_external_addresses(
        cls,
        address: str,
        n_process: int | None = None,
        ports: list[int] | None = None,
        schemes: list[Optional[str]] | None = None,
    ):
        """Returns external addresses for n pool processes"""

    @classmethod
    @abstractmethod
    def gen_internal_address(
        cls, process_index: int, external_address: str | None = None
    ) -> str | None:
        """Returns internal address for pool of specified process index"""
1477
+
1478
+
1479
async def create_actor_pool(
    address: str,
    *,
    pool_cls: Type[MainActorPoolType] | None = None,
    n_process: int | None = None,
    labels: list[str] | None = None,
    ports: list[int] | None = None,
    envs: list[dict] | None = None,
    external_address_schemes: list[Optional[str]] | None = None,
    enable_internal_addresses: list[bool] | None = None,
    subprocess_start_method: str | None = None,
    auto_recover: str | bool = "actor",
    modules: list[str] | None = None,
    suspend_sigint: bool | None = None,
    use_uvloop: str | bool = "auto",
    logging_conf: dict | None = None,
    on_process_down: Callable[[MainActorPoolType, str], None] | None = None,
    on_process_recover: Callable[[MainActorPoolType, str], None] | None = None,
    extra_conf: dict | None = None,
    **kwargs,
) -> MainActorPoolType:
    """Build the pool configuration, then create and start a main actor pool.

    One main pool plus ``n_process`` sub pools are configured, so per-pool
    lists (``labels``, ``external_address_schemes``,
    ``enable_internal_addresses``) need ``n_process + 1`` entries while
    ``envs`` only covers the sub pools and needs ``n_process`` entries.

    Parameters
    ----------
    address : str
        Passed to ``pool_cls.get_external_addresses`` and
        ``pool_cls.process_index_gen``.
    pool_cls : type
        Main pool class to instantiate; must not be ``None``.
    n_process : int | None
        Number of sub pools; defaults to ``multiprocessing.cpu_count()``.
    auto_recover : str | bool
        ``"actor"``, ``"process"`` or ``False``; ``True`` means ``"actor"``.
    use_uvloop : str | bool
        ``"auto"`` resolves to whether ``uvloop`` can be imported.
    extra_conf : dict | None
        Passed to ``ActorPoolConfig.add_comm_config``.
    **kwargs
        Stored in every pool config; ``kwargs["metrics"]`` is also passed to
        ``ActorPoolConfig.add_metric_configs``.

    Raises
    ------
    ValueError
        If a per-pool list has the wrong size or ``auto_recover`` is invalid.
    """
    if n_process is None:
        n_process = multiprocessing.cpu_count()
    n_total = n_process + 1  # main pool + sub pools

    # --- argument validation -------------------------------------------
    if labels and len(labels) != n_total:
        raise ValueError(f"`labels` should be of size {n_total}, got {len(labels)}")
    if envs and len(envs) != n_process:
        raise ValueError(f"`envs` should be of size {n_process}, got {len(envs)}")
    if external_address_schemes and len(external_address_schemes) != n_total:
        raise ValueError(
            f"`external_address_schemes` should be of size {n_total}, "
            f"got {len(external_address_schemes)}"
        )
    if not enable_internal_addresses:
        # nothing given: internal addresses are enabled for every pool
        enable_internal_addresses = [True] * n_total
    elif len(enable_internal_addresses) != n_total:
        raise ValueError(
            f"`enable_internal_addresses` should be of size {n_total}, "
            f"got {len(enable_internal_addresses)}"
        )

    if auto_recover is True:
        auto_recover = "actor"
    if auto_recover not in ("actor", "process", False):
        raise ValueError(
            f'`auto_recover` should be one of "actor", "process", '
            f"True or False, got {auto_recover}"
        )

    if use_uvloop == "auto":
        try:
            import uvloop  # noqa: F401 # pylint: disable=unused-variable
        except ImportError:
            use_uvloop = False
        else:
            use_uvloop = True

    assert pool_cls is not None

    # --- pool configuration --------------------------------------------
    external_addresses = pool_cls.get_external_addresses(
        address, n_process=n_process, ports=ports, schemes=external_address_schemes
    )
    actor_pool_config = ActorPoolConfig()
    actor_pool_config.add_metric_configs(kwargs.get("metrics", {}))

    # keyword arguments shared by the main pool and every sub pool
    shared_conf = dict(
        modules=modules,
        suspend_sigint=suspend_sigint,
        use_uvloop=use_uvloop,  # type: ignore
        logging_conf=logging_conf,
        kwargs=kwargs,
    )

    def _internal_address_for(slot: int, process_index: int):
        # slot 0 is the main pool, slots 1..n_process are the sub pools
        if not enable_internal_addresses[slot]:
            return None
        return pool_cls.gen_internal_address(process_index, external_addresses[slot])

    # add main config
    process_index_gen = pool_cls.process_index_gen(address)
    main_process_index = next(process_index_gen)
    main_internal_address = _internal_address_for(0, main_process_index)
    actor_pool_config.add_pool_conf(
        main_process_index,
        labels[0] if labels else None,
        main_internal_address,
        external_addresses[0],
        **shared_conf,
    )

    # add sub configs
    for slot in range(1, n_total):
        sub_process_index = next(process_index_gen)
        internal_address = _internal_address_for(slot, sub_process_index)
        actor_pool_config.add_pool_conf(
            sub_process_index,
            labels[slot] if labels else None,
            internal_address,
            external_addresses[slot],
            env=envs[slot - 1] if envs else None,
            **shared_conf,
        )
    actor_pool_config.add_comm_config(extra_conf)

    # --- create and start ----------------------------------------------
    create_config = {
        "actor_pool_config": actor_pool_config,
        "process_index": main_process_index,
        "start_method": subprocess_start_method,
        "auto_recover": auto_recover,
        "on_process_down": on_process_down,
        "on_process_recover": on_process_recover,
    }
    pool: MainActorPoolType = await pool_cls.create(create_config)
    await pool.start()
    return pool