xoscar 0.9.0__cp312-cp312-macosx_10_13_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (94)
  1. xoscar/__init__.py +61 -0
  2. xoscar/_utils.cpython-312-darwin.so +0 -0
  3. xoscar/_utils.pxd +36 -0
  4. xoscar/_utils.pyx +246 -0
  5. xoscar/_version.py +693 -0
  6. xoscar/aio/__init__.py +16 -0
  7. xoscar/aio/base.py +86 -0
  8. xoscar/aio/file.py +59 -0
  9. xoscar/aio/lru.py +228 -0
  10. xoscar/aio/parallelism.py +39 -0
  11. xoscar/api.py +527 -0
  12. xoscar/backend.py +67 -0
  13. xoscar/backends/__init__.py +14 -0
  14. xoscar/backends/allocate_strategy.py +160 -0
  15. xoscar/backends/communication/__init__.py +30 -0
  16. xoscar/backends/communication/base.py +315 -0
  17. xoscar/backends/communication/core.py +69 -0
  18. xoscar/backends/communication/dummy.py +253 -0
  19. xoscar/backends/communication/errors.py +20 -0
  20. xoscar/backends/communication/socket.py +444 -0
  21. xoscar/backends/communication/ucx.py +538 -0
  22. xoscar/backends/communication/utils.py +97 -0
  23. xoscar/backends/config.py +157 -0
  24. xoscar/backends/context.py +437 -0
  25. xoscar/backends/core.py +352 -0
  26. xoscar/backends/indigen/__init__.py +16 -0
  27. xoscar/backends/indigen/__main__.py +19 -0
  28. xoscar/backends/indigen/backend.py +51 -0
  29. xoscar/backends/indigen/driver.py +26 -0
  30. xoscar/backends/indigen/fate_sharing.py +221 -0
  31. xoscar/backends/indigen/pool.py +515 -0
  32. xoscar/backends/indigen/shared_memory.py +548 -0
  33. xoscar/backends/message.cpython-312-darwin.so +0 -0
  34. xoscar/backends/message.pyi +255 -0
  35. xoscar/backends/message.pyx +646 -0
  36. xoscar/backends/pool.py +1630 -0
  37. xoscar/backends/router.py +285 -0
  38. xoscar/backends/test/__init__.py +16 -0
  39. xoscar/backends/test/backend.py +38 -0
  40. xoscar/backends/test/pool.py +233 -0
  41. xoscar/batch.py +256 -0
  42. xoscar/collective/__init__.py +27 -0
  43. xoscar/collective/backend/__init__.py +13 -0
  44. xoscar/collective/backend/nccl_backend.py +160 -0
  45. xoscar/collective/common.py +102 -0
  46. xoscar/collective/core.py +737 -0
  47. xoscar/collective/process_group.py +687 -0
  48. xoscar/collective/utils.py +41 -0
  49. xoscar/collective/xoscar_pygloo.cpython-312-darwin.so +0 -0
  50. xoscar/collective/xoscar_pygloo.pyi +239 -0
  51. xoscar/constants.py +23 -0
  52. xoscar/context.cpython-312-darwin.so +0 -0
  53. xoscar/context.pxd +21 -0
  54. xoscar/context.pyx +368 -0
  55. xoscar/core.cpython-312-darwin.so +0 -0
  56. xoscar/core.pxd +51 -0
  57. xoscar/core.pyx +664 -0
  58. xoscar/debug.py +188 -0
  59. xoscar/driver.py +42 -0
  60. xoscar/errors.py +63 -0
  61. xoscar/libcpp.pxd +31 -0
  62. xoscar/metrics/__init__.py +21 -0
  63. xoscar/metrics/api.py +288 -0
  64. xoscar/metrics/backends/__init__.py +13 -0
  65. xoscar/metrics/backends/console/__init__.py +13 -0
  66. xoscar/metrics/backends/console/console_metric.py +82 -0
  67. xoscar/metrics/backends/metric.py +149 -0
  68. xoscar/metrics/backends/prometheus/__init__.py +13 -0
  69. xoscar/metrics/backends/prometheus/prometheus_metric.py +70 -0
  70. xoscar/nvutils.py +717 -0
  71. xoscar/profiling.py +260 -0
  72. xoscar/serialization/__init__.py +20 -0
  73. xoscar/serialization/aio.py +141 -0
  74. xoscar/serialization/core.cpython-312-darwin.so +0 -0
  75. xoscar/serialization/core.pxd +28 -0
  76. xoscar/serialization/core.pyi +57 -0
  77. xoscar/serialization/core.pyx +944 -0
  78. xoscar/serialization/cuda.py +111 -0
  79. xoscar/serialization/exception.py +48 -0
  80. xoscar/serialization/mlx.py +67 -0
  81. xoscar/serialization/numpy.py +82 -0
  82. xoscar/serialization/pyfury.py +37 -0
  83. xoscar/serialization/scipy.py +72 -0
  84. xoscar/serialization/torch.py +180 -0
  85. xoscar/utils.py +522 -0
  86. xoscar/virtualenv/__init__.py +34 -0
  87. xoscar/virtualenv/core.py +268 -0
  88. xoscar/virtualenv/platform.py +56 -0
  89. xoscar/virtualenv/utils.py +100 -0
  90. xoscar/virtualenv/uv.py +321 -0
  91. xoscar-0.9.0.dist-info/METADATA +230 -0
  92. xoscar-0.9.0.dist-info/RECORD +94 -0
  93. xoscar-0.9.0.dist-info/WHEEL +6 -0
  94. xoscar-0.9.0.dist-info/top_level.txt +2 -0
xoscar/backends/pool.py
@@ -0,0 +1,1630 @@
1
+ # Copyright 2022-2023 XProbe Inc.
2
+ # derived from copyright 1999-2021 Alibaba Group Holding Ltd.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from __future__ import annotations
17
+
18
+ import asyncio
19
+ import asyncio.subprocess
20
+ import concurrent.futures as futures
21
+ import contextlib
22
+ import itertools
23
+ import logging
24
+ import multiprocessing
25
+ import os
26
+ import threading
27
+ import traceback
28
+ from abc import ABC, ABCMeta, abstractmethod
29
+ from typing import Any, Callable, Coroutine, Optional, Type, TypeVar
30
+
31
+ import psutil
32
+
33
+ from .._utils import TypeDispatcher, create_actor_ref, to_binary
34
+ from ..api import Actor
35
+ from ..core import ActorRef, BufferRef, FileObjectRef, register_local_pool
36
+ from ..debug import debug_async_timeout, record_message_trace
37
+ from ..errors import (
38
+ ActorAlreadyExist,
39
+ ActorNotExist,
40
+ CannotCancelTask,
41
+ SendMessageFailed,
42
+ ServerClosed,
43
+ )
44
+ from ..metrics import init_metrics
45
+ from ..utils import implements, is_zero_ip, register_asyncio_task_timeout_detector
46
+ from .allocate_strategy import AddressSpecified, allocated_type
47
+ from .communication import (
48
+ Channel,
49
+ Server,
50
+ UCXChannel,
51
+ gen_local_address,
52
+ get_server_type,
53
+ )
54
+ from .communication.errors import ChannelClosed
55
+ from .config import ActorPoolConfig
56
+ from .core import ActorCaller, ResultMessageType
57
+ from .message import (
58
+ DEFAULT_PROTOCOL,
59
+ ActorRefMessage,
60
+ CancelMessage,
61
+ ControlMessage,
62
+ ControlMessageType,
63
+ CreateActorMessage,
64
+ DestroyActorMessage,
65
+ ErrorMessage,
66
+ ForwardMessage,
67
+ HasActorMessage,
68
+ MessageType,
69
+ ResultMessage,
70
+ SendMessage,
71
+ TellMessage,
72
+ _MessageBase,
73
+ new_message_id,
74
+ )
75
+ from .router import Router
76
+
77
+ logger = logging.getLogger(__name__)
78
+
79
+
80
+ @contextlib.contextmanager
81
+ def _disable_log_temporally():
82
+ if os.getenv("CUDA_VISIBLE_DEVICES") == "-1":
83
+ # disable logging when CUDA_VISIBLE_DEVICES == -1
84
+ # many log messages from ptxcompiler may distract users
85
+ try:
86
+ logging.disable(level=logging.ERROR)
87
+ yield
88
+ finally:
89
+ logging.disable(level=logging.NOTSET)
90
+ else:
91
+ yield
92
+
93
+
94
+ class _ErrorProcessor:
95
+ def __init__(self, address: str, message_id: bytes, protocol):
96
+ self._address = address
97
+ self._message_id = message_id
98
+ self._protocol = protocol
99
+ self.result = None
100
+
101
+ def __enter__(self):
102
+ return self
103
+
104
+ def __exit__(self, exc_type, exc_val, exc_tb):
105
+ if self.result is None:
106
+ self.result = ErrorMessage(
107
+ self._message_id,
108
+ self._address,
109
+ os.getpid(),
110
+ exc_type,
111
+ exc_val,
112
+ exc_tb,
113
+ protocol=self._protocol,
114
+ )
115
+ return True
116
+
117
+
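_ErrorProcessor above is the error boundary used throughout this module: any exception raised inside the `with` block is captured by __exit__ as an ErrorMessage instead of propagating, so handlers always hand back a message. A minimal standalone sketch of that behaviour (the address is illustrative, not part of the packaged file):

    from xoscar.backends.message import DEFAULT_PROTOCOL, ErrorMessage, new_message_id
    from xoscar.backends.pool import _ErrorProcessor

    with _ErrorProcessor("127.0.0.1:1234", new_message_id(), protocol=DEFAULT_PROTOCOL) as processor:
        raise ValueError("boom")  # swallowed by __exit__ and recorded as an ErrorMessage
    assert isinstance(processor.result, ErrorMessage)
    assert isinstance(processor.result.error, ValueError)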
118
+ def _register_message_handler(pool_type: Type["AbstractActorPool"]):
119
+ pool_type._message_handler = dict()
120
+ for message_type, handler in [
121
+ (MessageType.create_actor, pool_type.create_actor),
122
+ (MessageType.destroy_actor, pool_type.destroy_actor),
123
+ (MessageType.has_actor, pool_type.has_actor),
124
+ (MessageType.actor_ref, pool_type.actor_ref),
125
+ (MessageType.send, pool_type.send),
126
+ (MessageType.tell, pool_type.tell),
127
+ (MessageType.cancel, pool_type.cancel),
128
+ (MessageType.forward, pool_type.forward),
129
+ (MessageType.control, pool_type.handle_control_command),
130
+ (MessageType.copy_to_buffers, pool_type.handle_copy_to_buffers_message),
131
+ (MessageType.copy_to_fileobjs, pool_type.handle_copy_to_fileobjs_message),
132
+ ]:
133
+ pool_type._message_handler[message_type] = handler # type: ignore
134
+ return pool_type
135
+
136
+
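_register_message_handler builds a class-level table mapping each MessageType to the corresponding handler, and process_message later dispatches on message.message_type through that table. The same pattern in miniature, with hypothetical names rather than the real classes:

    import asyncio
    from enum import Enum

    class MsgType(Enum):
        ping = 1

    def register(cls):
        # mirror of _register_message_handler: message type -> unbound handler
        cls._message_handler = {MsgType.ping: cls.ping}
        return cls

    @register
    class MiniPool:
        async def ping(self, message):
            return f"pong: {message}"

    async def demo():
        pool = MiniPool()
        handler = MiniPool._message_handler[MsgType.ping]
        print(await handler(pool, "hello"))  # -> pong: hello

    asyncio.run(demo())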
137
+ class AbstractActorPool(ABC):
138
+ __slots__ = (
139
+ "process_index",
140
+ "label",
141
+ "external_address",
142
+ "internal_address",
143
+ "env",
144
+ "_servers",
145
+ "_router",
146
+ "_config",
147
+ "_stopped",
148
+ "_actors",
149
+ "_caller",
150
+ "_process_messages",
151
+ "_asyncio_task_timeout_detector_task",
152
+ )
153
+
154
+ _message_handler: dict[MessageType, Callable]
155
+ _process_messages: dict[bytes, asyncio.Future | asyncio.Task | None]
156
+
157
+ def __init__(
158
+ self,
159
+ process_index: int,
160
+ label: str,
161
+ external_address: str,
162
+ internal_address: str,
163
+ env: dict,
164
+ router: Router,
165
+ config: ActorPoolConfig,
166
+ servers: list[Server],
167
+ ):
168
+ # register local pool for local actor lookup.
169
+ # The pool is weakrefed, so we don't need to unregister it.
170
+ if not is_zero_ip(external_address):
171
+ # Only call register_local_pool when we listen on a non-zero ip (an all-zero ip is a wildcard address),
172
+ # to avoid being mistaken for another remote service listening on the same port.
173
+ register_local_pool(external_address, self)
174
+ self.process_index = process_index
175
+ self.label = label
176
+ self.external_address = external_address
177
+ self.internal_address = internal_address
178
+ self.env = env
179
+ self._router = router
180
+ self._config = config
181
+ self._servers = servers
182
+
183
+ self._stopped = asyncio.Event()
184
+
185
+ # states
186
+ # actor id -> actor
187
+ self._actors: dict[bytes, Actor] = dict()
188
+ # message id -> future
189
+ self._process_messages = dict()
190
+
191
+ # manage async actor callers
192
+ self._caller = ActorCaller()
193
+ self._asyncio_task_timeout_detector_task = (
194
+ register_asyncio_task_timeout_detector()
195
+ )
196
+ # init metrics
197
+ metric_configs = self._config.get_metric_configs()
198
+ metric_backend = metric_configs.get("backend")
199
+ init_metrics(metric_backend, config=metric_configs.get(metric_backend))
200
+
201
+ @property
202
+ def router(self):
203
+ return self._router
204
+
205
+ @abstractmethod
206
+ async def create_actor(self, message: CreateActorMessage) -> ResultMessageType:
207
+ """
208
+ Create an actor.
209
+
210
+ Parameters
211
+ ----------
212
+ message: CreateActorMessage
213
+ message to create an actor.
214
+
215
+ Returns
216
+ -------
217
+ result_message
218
+ result or error message.
219
+ """
220
+
221
+ @abstractmethod
222
+ async def has_actor(self, message: HasActorMessage) -> ResultMessage:
223
+ """
224
+ Check if an actor exists or not.
225
+
226
+ Parameters
227
+ ----------
228
+ message: HasActorMessage
229
+ message
230
+
231
+ Returns
232
+ -------
233
+ result_message
234
+ result message contains if an actor exists or not.
235
+ """
236
+
237
+ @abstractmethod
238
+ async def destroy_actor(self, message: DestroyActorMessage) -> ResultMessageType:
239
+ """
240
+ Destroy an actor.
241
+
242
+ Parameters
243
+ ----------
244
+ message: DestroyActorMessage
245
+ message to destroy an actor.
246
+
247
+ Returns
248
+ -------
249
+ result_message
250
+ result or error message.
251
+ """
252
+
253
+ @abstractmethod
254
+ async def actor_ref(self, message: ActorRefMessage) -> ResultMessageType:
255
+ """
256
+ Get an actor's ref.
257
+
258
+ Parameters
259
+ ----------
260
+ message: ActorRefMessage
261
+ message to get an actor's ref.
262
+
263
+ Returns
264
+ -------
265
+ result_message
266
+ result or error message.
267
+ """
268
+
269
+ @abstractmethod
270
+ async def send(self, message: SendMessage) -> ResultMessageType:
271
+ """
272
+ Send a message to some actor.
273
+
274
+ Parameters
275
+ ----------
276
+ message: SendMessage
277
+ Message to send.
278
+
279
+ Returns
280
+ -------
281
+ result_message
282
+ result or error message.
283
+ """
284
+
285
+ @abstractmethod
286
+ async def tell(self, message: TellMessage) -> ResultMessageType:
287
+ """
288
+ Tell message to some actor.
289
+
290
+ Parameters
291
+ ----------
292
+ message: TellMessage
293
+ Message to tell.
294
+
295
+ Returns
296
+ -------
297
+ result_message
298
+ result or error message.
299
+ """
300
+
301
+ @abstractmethod
302
+ async def cancel(self, message: CancelMessage) -> ResultMessageType:
303
+ """
304
+ Cancel a message that was previously sent.
305
+
306
+ Parameters
307
+ ----------
308
+ message: CancelMessage
309
+ Cancel message.
310
+
311
+ Returns
312
+ -------
313
+ result_message
314
+ result or error message
315
+ """
316
+
317
+ async def forward(self, message: ForwardMessage) -> ResultMessageType:
318
+ """
319
+ Forward message
320
+
321
+ Parameters
322
+ ----------
323
+ message: ForwardMessage
324
+ Forward message.
325
+
326
+ Returns
327
+ -------
328
+ result_message
329
+ result or error message
330
+ """
331
+ return await self.call(message.address, message.raw_message)
332
+
333
+ def _sync_pool_config(self, actor_pool_config: ActorPoolConfig):
334
+ self._config = actor_pool_config
335
+ # remove router from global one
336
+ global_router = Router.get_instance()
337
+ global_router.remove_router(self._router) # type: ignore
338
+ # update router
339
+ self._router.set_mapping(actor_pool_config.external_to_internal_address_map)
340
+ # update global router
341
+ global_router.add_router(self._router) # type: ignore
342
+
343
+ async def handle_control_command(
344
+ self, message: ControlMessage
345
+ ) -> ResultMessageType:
346
+ """
347
+ Handle control command.
348
+
349
+ Parameters
350
+ ----------
351
+ message: ControlMessage
352
+ Control message.
353
+
354
+ Returns
355
+ -------
356
+ result_message
357
+ result or error message.
358
+ """
359
+ with _ErrorProcessor(
360
+ self.external_address, message.message_id, protocol=message.protocol
361
+ ) as processor:
362
+ content: bool | ActorPoolConfig = True
363
+ if message.control_message_type == ControlMessageType.stop:
364
+ await self.stop()
365
+ elif message.control_message_type == ControlMessageType.sync_config:
366
+ self._sync_pool_config(message.content)
367
+ elif message.control_message_type == ControlMessageType.get_config:
368
+ if message.content == "main_pool_address":
369
+ main_process_index = self._config.get_process_indexes()[0]
370
+ content = self._config.get_pool_config(main_process_index)[
371
+ "external_address"
372
+ ][0]
373
+ else:
374
+ content = self._config
375
+ else: # pragma: no cover
376
+ raise TypeError(
377
+ f"Unable to handle control message "
378
+ f"with type {message.control_message_type}"
379
+ )
380
+ processor.result = ResultMessage(
381
+ message.message_id, content, protocol=message.protocol
382
+ )
383
+
384
+ return processor.result
385
+
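handle_control_command answers stop, sync_config and get_config control messages. As a sketch of a get_config request, this fragment (meant to run inside a coroutine, with this module's names in scope and pool standing for an already-created pool instance) asks a pool for its main pool address:

    msg = ControlMessage(
        new_message_id(),
        pool.external_address,            # the pool being asked
        ControlMessageType.get_config,
        "main_pool_address",              # handled as a special case above
        protocol=DEFAULT_PROTOCOL,
    )
    reply = await pool.handle_control_command(msg)
    main_address = reply.result           # ResultMessage carrying the address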
386
+ async def _run_coro(self, message_id: bytes, coro: Coroutine):
387
+ self._process_messages[message_id] = asyncio.tasks.current_task()
388
+ try:
389
+ return await coro
390
+ finally:
391
+ self._process_messages.pop(message_id, None)
392
+
393
+ async def _send_channel(
394
+ self, result: _MessageBase, channel: Channel, resend_failure: bool = True
395
+ ):
396
+ try:
397
+ await channel.send(result)
398
+ except (ChannelClosed, ConnectionResetError):
399
+ if not self._stopped.is_set() and not channel.closed:
400
+ raise
401
+ except Exception as ex:
402
+ logger.exception(
403
+ "Error when sending message %s from %s to %s",
404
+ result.message_id.hex(),
405
+ channel.local_address,
406
+ channel.dest_address,
407
+ )
408
+ if not resend_failure: # pragma: no cover
409
+ raise
410
+
411
+ with _ErrorProcessor(
412
+ self.external_address, result.message_id, result.protocol
413
+ ) as processor:
414
+ error_msg = (
415
+ f"Error when sending message {result.message_id.hex()}. "
416
+ f"Caused by {ex!r}. "
417
+ )
418
+ if isinstance(result, ErrorMessage):
419
+ format_tb = "\n".join(traceback.format_tb(result.traceback))
420
+ error_msg += (
421
+ f"\nOriginal error: {result.error!r}"
422
+ f"Traceback: \n{format_tb}"
423
+ )
424
+ else:
425
+ error_msg += "See server logs for more details"
426
+ raise SendMessageFailed(error_msg) from None
427
+ await self._send_channel(processor.result, channel, resend_failure=False)
428
+
429
+ async def process_message(self, message: _MessageBase, channel: Channel):
430
+ handler = self._message_handler[message.message_type]
431
+ with _ErrorProcessor(
432
+ self.external_address, message.message_id, message.protocol
433
+ ) as processor:
434
+ # use `%.500s` to avoid printing overly long messages
435
+ with debug_async_timeout(
436
+ "process_message_timeout",
437
+ "Process message %.500s of channel %s timeout.",
438
+ message,
439
+ channel,
440
+ ):
441
+ processor.result = await self._run_coro(
442
+ message.message_id, handler(self, message)
443
+ )
444
+
445
+ await self._send_channel(processor.result, channel)
446
+
447
+ async def call(self, dest_address: str, message: _MessageBase) -> ResultMessageType:
448
+ return await self._caller.call(self._router, dest_address, message) # type: ignore
449
+
450
+ @staticmethod
451
+ def _parse_config(config: dict, kw: dict) -> dict:
452
+ actor_pool_config: ActorPoolConfig = config.pop("actor_pool_config")
453
+ kw["config"] = actor_pool_config
454
+ kw["process_index"] = process_index = config.pop("process_index")
455
+ curr_pool_config = actor_pool_config.get_pool_config(process_index)
456
+ kw["label"] = curr_pool_config["label"]
457
+ external_addresses = curr_pool_config["external_address"]
458
+ kw["external_address"] = external_addresses[0]
459
+ kw["internal_address"] = curr_pool_config["internal_address"]
460
+ kw["router"] = Router(
461
+ external_addresses,
462
+ gen_local_address(process_index),
463
+ actor_pool_config.external_to_internal_address_map,
464
+ comm_config=actor_pool_config.get_comm_config(),
465
+ proxy_config=actor_pool_config.get_proxy_config(),
466
+ )
467
+ kw["env"] = curr_pool_config["env"]
468
+
469
+ if config: # pragma: no cover
470
+ raise TypeError(
471
+ f"Creating pool got unexpected " f'arguments: {",".join(config)}'
472
+ )
473
+
474
+ return kw
475
+
476
+ @classmethod
477
+ @abstractmethod
478
+ async def create(cls, config: dict) -> "AbstractActorPool":
479
+ """
480
+ Create an actor pool.
481
+
482
+ Parameters
483
+ ----------
484
+ config: dict
485
+ configurations.
486
+
487
+ Returns
488
+ -------
489
+ actor_pool:
490
+ Actor pool.
491
+ """
492
+
493
+ async def start(self):
494
+ if self._stopped.is_set():
495
+ raise RuntimeError("pool has been stopped, cannot start again")
496
+ start_servers = [server.start() for server in self._servers]
497
+ await asyncio.gather(*start_servers)
498
+
499
+ async def join(self, timeout: float | None = None):
500
+ wait_stopped = asyncio.create_task(self._stopped.wait())
501
+
502
+ try:
503
+ await asyncio.wait_for(wait_stopped, timeout=timeout)
504
+ except (futures.TimeoutError, asyncio.TimeoutError): # pragma: no cover
505
+ wait_stopped.cancel()
506
+
507
+ async def stop(self):
508
+ try:
509
+ # clean global router
510
+ router = Router.get_instance()
511
+ if router is not None:
512
+ router.remove_router(self._router)
513
+ stop_tasks = []
514
+ # stop all servers
515
+ stop_tasks.extend([server.stop() for server in self._servers])
516
+ # stop all clients
517
+ stop_tasks.append(self._caller.stop())
518
+ await asyncio.gather(*stop_tasks)
519
+
520
+ self._servers = []
521
+ if self._asyncio_task_timeout_detector_task: # pragma: no cover
522
+ self._asyncio_task_timeout_detector_task.cancel()
523
+ finally:
524
+ self._stopped.set()
525
+
526
+ async def handle_copy_to_buffers_message(self, message) -> ResultMessage:
527
+ for addr, uid, start, _len, data in message.content:
528
+ buffer = BufferRef.get_buffer(BufferRef(addr, uid))
529
+ buffer[start : start + _len] = data
530
+ return ResultMessage(message_id=message.message_id, result=True)
531
+
532
+ async def handle_copy_to_fileobjs_message(self, message) -> ResultMessage:
533
+ for addr, uid, data in message.content:
534
+ file_obj = FileObjectRef.get_local_file_object(FileObjectRef(addr, uid))
535
+ await file_obj.write(data)
536
+ return ResultMessage(message_id=message.message_id, result=True)
537
+
538
+ @property
539
+ def stopped(self) -> bool:
540
+ return self._stopped.is_set()
541
+
542
+ async def _handle_ucx_meta_message(
543
+ self, message: _MessageBase, channel: Channel
544
+ ) -> bool:
545
+ if (
546
+ isinstance(message, ControlMessage)
547
+ and message.message_type == MessageType.control
548
+ and message.control_message_type == ControlMessageType.switch_to_copy_to
549
+ and isinstance(channel, UCXChannel)
550
+ ):
551
+ with _ErrorProcessor(
552
+ self.external_address, message.message_id, message.protocol
553
+ ) as processor:
554
+ # use `%.500s` to avoid printing overly long messages
555
+ with debug_async_timeout(
556
+ "process_message_timeout",
557
+ "Process message %.500s of channel %s timeout.",
558
+ message,
559
+ channel,
560
+ ):
561
+ buffers = [
562
+ BufferRef.get_buffer(BufferRef(addr, uid))
563
+ for addr, uid in message.content
564
+ ]
565
+ await channel.recv_buffers(buffers)
566
+ processor.result = ResultMessage(
567
+ message_id=message.message_id, result=True
568
+ )
569
+ asyncio.create_task(self._send_channel(processor.result, channel))
570
+ return True
571
+ return False
572
+
573
+ async def on_new_channel(self, channel: Channel):
574
+ try:
575
+ while not self._stopped.is_set():
576
+ try:
577
+ message = await channel.recv()
578
+ except (EOFError, ConnectionError, BrokenPipeError) as e:
579
+ logger.debug(f"pool: close connection due to {e}")
580
+ # no data to read, check channel
581
+ try:
582
+ await channel.close()
583
+ except (ConnectionError, EOFError):
584
+ # close failed, ignore
585
+ pass
586
+ return
587
+ if await self._handle_ucx_meta_message(message, channel):
588
+ continue
589
+ asyncio.create_task(self.process_message(message, channel))
590
+ # delete to release the reference of message
591
+ del message
592
+ await asyncio.sleep(0)
593
+ finally:
594
+ try:
595
+ await channel.close()
596
+ except: # noqa: E722 # nosec # pylint: disable=bare-except
597
+ # ignore any error if the final close fails
598
+ pass
599
+
600
+ async def __aenter__(self):
601
+ await self.start()
602
+ return self
603
+
604
+ async def __aexit__(self, exc_type, exc_val, exc_tb):
605
+ await self.stop()
606
+
607
+
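start/join/stop together with __aenter__/__aexit__ give every pool a simple lifecycle. A rough sketch (fragment to run inside a coroutine; ConcretePool stands for any concrete subclass and config for a prepared ActorPoolConfig):

    pool = await ConcretePool.create({"actor_pool_config": config, "process_index": 0})
    async with pool:            # __aenter__ -> start(), __aexit__ -> stop()
        await pool.join(timeout=1.0)
    assert pool.stopped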
608
+ class ActorPoolBase(AbstractActorPool, metaclass=ABCMeta):
609
+ __slots__ = ()
610
+
611
+ @implements(AbstractActorPool.create_actor)
612
+ async def create_actor(self, message: CreateActorMessage) -> ResultMessageType:
613
+ with _ErrorProcessor(
614
+ self.external_address, message.message_id, message.protocol
615
+ ) as processor:
616
+ actor_id = message.actor_id
617
+ if actor_id in self._actors:
618
+ raise ActorAlreadyExist(
619
+ f"Actor {actor_id!r} already exist, cannot create"
620
+ )
621
+
622
+ actor = message.actor_cls(*message.args, **message.kwargs)
623
+ actor.uid = actor_id
624
+ actor.address = address = self.external_address
625
+ self._actors[actor_id] = actor
626
+ await self._run_coro(message.message_id, actor.__post_create__())
627
+
628
+ proxies = self._router.get_proxies(address)
629
+ result = ActorRef(address, actor_id, proxy_addresses=proxies)
630
+ # assemble result message
631
+ processor.result = ResultMessage(
632
+ message.message_id, result, protocol=message.protocol
633
+ )
634
+ return processor.result
635
+
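create_actor above unpacks a CreateActorMessage, instantiates the actor class, runs __post_create__ and answers with an ActorRef. A sketch of the message a caller would build (fragment for a coroutine; MyActor is a hypothetical Actor subclass and pool an assumed pool instance):

    msg = CreateActorMessage(
        new_message_id(),
        MyActor,                       # an Actor subclass
        b"my_actor_uid",               # actor id
        (),                            # positional constructor args
        {},                            # keyword constructor args
        allocate_strategy=AddressSpecified(pool.external_address),
        protocol=DEFAULT_PROTOCOL,
    )
    reply = await pool.create_actor(msg)   # ResultMessage wrapping an ActorRef on success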
636
+ @implements(AbstractActorPool.has_actor)
637
+ async def has_actor(self, message: HasActorMessage) -> ResultMessage:
638
+ result = ResultMessage(
639
+ message.message_id,
640
+ message.actor_ref.uid in self._actors,
641
+ protocol=message.protocol,
642
+ )
643
+ return result
644
+
645
+ @implements(AbstractActorPool.destroy_actor)
646
+ async def destroy_actor(self, message: DestroyActorMessage) -> ResultMessageType:
647
+ with _ErrorProcessor(
648
+ self.external_address, message.message_id, message.protocol
649
+ ) as processor:
650
+ actor_id = message.actor_ref.uid
651
+ try:
652
+ actor = self._actors[actor_id]
653
+ except KeyError:
654
+ raise ActorNotExist(f"Actor {actor_id} does not exist")
655
+ await self._run_coro(message.message_id, actor.__pre_destroy__())
656
+ del self._actors[actor_id]
657
+
658
+ processor.result = ResultMessage(
659
+ message.message_id, actor_id, protocol=message.protocol
660
+ )
661
+ return processor.result
662
+
663
+ @implements(AbstractActorPool.actor_ref)
664
+ async def actor_ref(self, message: ActorRefMessage) -> ResultMessageType:
665
+ with _ErrorProcessor(
666
+ self.external_address, message.message_id, message.protocol
667
+ ) as processor:
668
+ actor_id = message.actor_ref.uid
669
+ if actor_id not in self._actors:
670
+ raise ActorNotExist(f"Actor {actor_id} does not exist")
671
+ proxies = self._router.get_proxies(self.external_address)
672
+ result = ResultMessage(
673
+ message.message_id,
674
+ ActorRef(self.external_address, actor_id, proxy_addresses=proxies),
675
+ protocol=message.protocol,
676
+ )
677
+ processor.result = result
678
+ return processor.result
679
+
680
+ @implements(AbstractActorPool.send)
681
+ async def send(self, message: SendMessage) -> ResultMessageType:
682
+ with _ErrorProcessor(
683
+ self.external_address, message.message_id, message.protocol
684
+ ) as processor, record_message_trace(message):
685
+ actor_id = message.actor_ref.uid
686
+ if actor_id not in self._actors:
687
+ raise ActorNotExist(f"Actor {actor_id} does not exist")
688
+ coro = self._actors[actor_id].__on_receive__(message.content)
689
+ result = await self._run_coro(message.message_id, coro)
690
+ processor.result = ResultMessage(
691
+ message.message_id,
692
+ result,
693
+ protocol=message.protocol,
694
+ profiling_context=message.profiling_context,
695
+ )
696
+ return processor.result
697
+
698
+ @implements(AbstractActorPool.tell)
699
+ async def tell(self, message: TellMessage) -> ResultMessageType:
700
+ with _ErrorProcessor(
701
+ self.external_address, message.message_id, message.protocol
702
+ ) as processor:
703
+ actor_id = message.actor_ref.uid
704
+ if actor_id not in self._actors: # pragma: no cover
705
+ raise ActorNotExist(f"Actor {actor_id} does not exist")
706
+ call = self._actors[actor_id].__on_receive__(message.content)
707
+ # run asynchronously; tell does not care about the result
708
+ asyncio.create_task(call)
709
+ await asyncio.sleep(0)
710
+ processor.result = ResultMessage(
711
+ message.message_id,
712
+ None,
713
+ protocol=message.protocol,
714
+ profiling_context=message.profiling_context,
715
+ )
716
+ return processor.result
717
+
718
+ @implements(AbstractActorPool.cancel)
719
+ async def cancel(self, message: CancelMessage) -> ResultMessageType:
720
+ with _ErrorProcessor(
721
+ self.external_address, message.message_id, message.protocol
722
+ ) as processor:
723
+ future = self._process_messages.get(message.cancel_message_id)
724
+ if future is None or future.done(): # pragma: no cover
725
+ raise CannotCancelTask(
726
+ "Task not exists, maybe it is done or cancelled already"
727
+ )
728
+ future.cancel()
729
+ processor.result = ResultMessage(
730
+ message.message_id, True, protocol=message.protocol
731
+ )
732
+ return processor.result
733
+
734
+ @staticmethod
735
+ def _set_global_router(router: Router):
736
+ # be cautious about setting global router
737
+ # for instance, multiple main pools may be created in the same process
738
+
739
+ # get default router or create an empty one
740
+ default_router = Router.get_instance_or_empty()
741
+ Router.set_instance(default_router)
742
+ # append this router to global
743
+ default_router.add_router(router)
744
+
745
+ @staticmethod
746
+ def _update_stored_addresses(
747
+ servers: list[Server],
748
+ raw_addresses: list[str],
749
+ actor_pool_config: ActorPoolConfig,
750
+ kw: dict,
751
+ ):
752
+ process_index = kw["process_index"]
753
+ curr_pool_config = actor_pool_config.get_pool_config(process_index)
754
+ external_addresses = curr_pool_config["external_address"]
755
+ external_address_set = set(external_addresses)
756
+
757
+ kw["servers"] = servers
758
+
759
+ new_external_addresses = [
760
+ server.address
761
+ for server, raw_address in zip(servers, raw_addresses)
762
+ if raw_address in external_address_set
763
+ ]
764
+
765
+ if external_address_set != set(new_external_addresses):
766
+ external_addresses = new_external_addresses
767
+ actor_pool_config.reset_pool_external_address(
768
+ process_index, external_addresses
769
+ )
770
+ external_addresses = curr_pool_config["external_address"]
771
+
772
+ logger.debug(
773
+ "External address of process index %s updated to %s",
774
+ process_index,
775
+ external_addresses[0],
776
+ )
777
+ if kw["internal_address"] == kw["external_address"]:
778
+ # internal address may be the same as external address on Windows
779
+ kw["internal_address"] = external_addresses[0]
780
+ kw["external_address"] = external_addresses[0]
781
+
782
+ kw["router"] = Router(
783
+ external_addresses,
784
+ gen_local_address(process_index),
785
+ actor_pool_config.external_to_internal_address_map,
786
+ comm_config=actor_pool_config.get_comm_config(),
787
+ proxy_config=actor_pool_config.get_proxy_config(),
788
+ )
789
+
790
+ @classmethod
791
+ async def _create_servers(
792
+ cls, addresses: list[str], channel_handler: Callable, config: dict
793
+ ):
794
+ assert len(set(addresses)) == len(addresses)
795
+ # create servers
796
+ create_server_tasks = []
797
+ for addr in addresses:
798
+ server_type = get_server_type(addr)
799
+ extra_config = server_type.parse_config(config)
800
+ server_config = dict(address=addr, handle_channel=channel_handler)
801
+ server_config.update(extra_config)
802
+ task = asyncio.create_task(server_type.create(server_config))
803
+ create_server_tasks.append(task)
804
+
805
+ await asyncio.gather(*create_server_tasks)
806
+ return [f.result() for f in create_server_tasks]
807
+
808
+ @classmethod
809
+ @implements(AbstractActorPool.create)
810
+ async def create(cls, config: dict) -> "ActorPoolType":
811
+ config = config.copy()
812
+ kw: dict[str, Any] = dict()
813
+ cls._parse_config(config, kw)
814
+ process_index: int = kw["process_index"]
815
+ actor_pool_config = kw["config"] # type: ActorPoolConfig
816
+ cur_pool_config = actor_pool_config.get_pool_config(process_index)
817
+ external_addresses = cur_pool_config["external_address"]
818
+ internal_address = kw["internal_address"]
819
+
820
+ # import predefined modules
821
+ modules = cur_pool_config["modules"] or []
822
+ for mod in modules:
823
+ __import__(mod, globals(), locals(), [])
824
+ # make sure all lazy imports loaded
825
+ with _disable_log_temporally():
826
+ TypeDispatcher.reload_all_lazy_handlers()
827
+
828
+ def handle_channel(channel):
829
+ return pool.on_new_channel(channel)
830
+
831
+ # create servers
832
+ server_addresses = list(external_addresses)
833
+ if internal_address:
834
+ server_addresses.append(internal_address)
835
+ server_addresses.append(gen_local_address(process_index))
836
+ server_addresses = sorted(set(server_addresses))
837
+ servers = await cls._create_servers(
838
+ server_addresses, handle_channel, actor_pool_config.get_comm_config()
839
+ )
840
+ cls._update_stored_addresses(servers, server_addresses, actor_pool_config, kw)
841
+
842
+ # set default router
843
+ # so that actor contexts are able to use the exact client
844
+ cls._set_global_router(kw["router"])
845
+
846
+ # create pool
847
+ pool = cls(**kw)
848
+ return pool # type: ignore
849
+
850
+
851
+ ActorPoolType = TypeVar("ActorPoolType", bound=AbstractActorPool)
852
+ MainActorPoolType = TypeVar("MainActorPoolType", bound="MainActorPoolBase")
853
+ SubProcessHandle = asyncio.subprocess.Process
854
+
855
+
856
+ class SubActorPoolBase(ActorPoolBase):
857
+ __slots__ = ("_main_address", "_watch_main_pool_task")
858
+ _watch_main_pool_task: Optional[asyncio.Task]
859
+
860
+ def __init__(
861
+ self,
862
+ process_index: int,
863
+ label: str,
864
+ external_address: str,
865
+ internal_address: str,
866
+ env: dict,
867
+ router: Router,
868
+ config: ActorPoolConfig,
869
+ servers: list[Server],
870
+ main_address: str,
871
+ main_pool_pid: Optional[int],
872
+ ):
873
+ super().__init__(
874
+ process_index,
875
+ label,
876
+ external_address,
877
+ internal_address,
878
+ env,
879
+ router,
880
+ config,
881
+ servers,
882
+ )
883
+ self._main_address = main_address
884
+ if main_pool_pid:
885
+ self._watch_main_pool_task = asyncio.create_task(
886
+ self._watch_main_pool(main_pool_pid)
887
+ )
888
+ else:
889
+ self._watch_main_pool_task = None
890
+
891
+ async def _watch_main_pool(self, main_pool_pid: int):
892
+ main_process = psutil.Process(main_pool_pid)
893
+ while not self.stopped:
894
+ try:
895
+ await asyncio.to_thread(main_process.status)
896
+ await asyncio.sleep(0.1)
897
+ continue
898
+ except (psutil.NoSuchProcess, ProcessLookupError, asyncio.CancelledError):
899
+ # main pool died
900
+ break
901
+
902
+ if not self.stopped:
903
+ await self.stop()
904
+
905
+ async def notify_main_pool_to_destroy(
906
+ self, message: DestroyActorMessage
907
+ ): # pragma: no cover
908
+ await self.call(self._main_address, message)
909
+
910
+ async def notify_main_pool_to_create(self, message: CreateActorMessage):
911
+ reg_message = ControlMessage(
912
+ new_message_id(),
913
+ self.external_address,
914
+ ControlMessageType.add_sub_pool_actor,
915
+ (self.external_address, message.allocate_strategy, message),
916
+ )
917
+ await self.call(self._main_address, reg_message)
918
+
919
+ @implements(AbstractActorPool.create_actor)
920
+ async def create_actor(self, message: CreateActorMessage) -> ResultMessageType:
921
+ result = await super().create_actor(message)
922
+ if not message.from_main:
923
+ await self.notify_main_pool_to_create(message)
924
+ return result
925
+
926
+ @implements(AbstractActorPool.actor_ref)
927
+ async def actor_ref(self, message: ActorRefMessage) -> ResultMessageType:
928
+ result = await super().actor_ref(message)
929
+ if isinstance(result, ErrorMessage):
930
+ # need a new message id to call main actor
931
+ main_message = ActorRefMessage(
932
+ new_message_id(),
933
+ create_actor_ref(self._main_address, message.actor_ref.uid),
934
+ )
935
+ result = await self.call(self._main_address, main_message)
936
+ # rewrite to message_id of the original request
937
+ result.message_id = message.message_id
938
+ return result
939
+
940
+ @implements(AbstractActorPool.destroy_actor)
941
+ async def destroy_actor(self, message: DestroyActorMessage) -> ResultMessageType:
942
+ result = await super().destroy_actor(message)
943
+ if isinstance(result, ResultMessage) and not message.from_main:
944
+ # sync back to main actor pool
945
+ await self.notify_main_pool_to_destroy(message)
946
+ return result
947
+
948
+ @implements(AbstractActorPool.handle_control_command)
949
+ async def handle_control_command(
950
+ self, message: ControlMessage
951
+ ) -> ResultMessageType:
952
+ if message.control_message_type == ControlMessageType.sync_config:
953
+ self._main_address = message.address
954
+ return await super().handle_control_command(message)
955
+
956
+ @staticmethod
957
+ def _parse_config(config: dict, kw: dict) -> dict:
958
+ main_pool_pid = config.pop("main_pool_pid", None)
959
+ kw = AbstractActorPool._parse_config(config, kw)
960
+ pool_config: ActorPoolConfig = kw["config"]
961
+ main_process_index = pool_config.get_process_indexes()[0]
962
+ kw["main_address"] = pool_config.get_pool_config(main_process_index)[
963
+ "external_address"
964
+ ][0]
965
+ kw["main_pool_pid"] = main_pool_pid
966
+ return kw
967
+
968
+ async def stop(self):
969
+ await super().stop()
970
+ if self._watch_main_pool_task:
971
+ self._watch_main_pool_task.cancel()
972
+ await self._watch_main_pool_task
973
+
974
+
975
+ class MainActorPoolBase(ActorPoolBase):
976
+ __slots__ = (
977
+ "_allocated_actors",
978
+ "sub_actor_pool_manager",
979
+ "_auto_recover",
980
+ "_monitor_task",
981
+ "_on_process_down",
982
+ "_on_process_recover",
983
+ "_recover_events",
984
+ "_allocation_lock",
985
+ "sub_processes",
986
+ )
987
+
988
+ def __init__(
989
+ self,
990
+ process_index: int,
991
+ label: str,
992
+ external_address: str,
993
+ internal_address: str,
994
+ env: dict,
995
+ router: Router,
996
+ config: ActorPoolConfig,
997
+ servers: list[Server],
998
+ auto_recover: str | bool = "actor",
999
+ on_process_down: Callable[[MainActorPoolType, str], None] | None = None,
1000
+ on_process_recover: Callable[[MainActorPoolType, str], None] | None = None,
1001
+ ):
1002
+ super().__init__(
1003
+ process_index,
1004
+ label,
1005
+ external_address,
1006
+ internal_address,
1007
+ env,
1008
+ router,
1009
+ config,
1010
+ servers,
1011
+ )
1012
+
1013
+ # auto recovering
1014
+ self._auto_recover = auto_recover
1015
+ self._monitor_task: Optional[asyncio.Task] = None
1016
+ self._on_process_down = on_process_down
1017
+ self._on_process_recover = on_process_recover
1018
+ self._recover_events: dict[str, asyncio.Event] = dict()
1019
+
1020
+ # states
1021
+ self._allocated_actors: allocated_type = {
1022
+ addr: dict() for addr in self._config.get_external_addresses()
1023
+ }
1024
+ self._allocation_lock = threading.Lock()
1025
+
1026
+ self.sub_processes: dict[str, SubProcessHandle] = dict()
1027
+
1028
+ _process_index_gen = itertools.count()
1029
+
1030
+ @classmethod
1031
+ def process_index_gen(cls, address):
1032
+ # make sure different processes do not share process indexes
1033
+ pid = os.getpid()
1034
+ for idx in cls._process_index_gen:
1035
+ yield pid << 16 + idx
1036
+
1037
+ @property
1038
+ def _sub_processes(self):
1039
+ return self.sub_processes
1040
+
1041
+ @implements(AbstractActorPool.create_actor)
1042
+ async def create_actor(self, message: CreateActorMessage) -> ResultMessageType:
1043
+ with _ErrorProcessor(
1044
+ address=self.external_address,
1045
+ message_id=message.message_id,
1046
+ protocol=message.protocol,
1047
+ ) as processor:
1048
+ allocate_strategy = message.allocate_strategy
1049
+ with self._allocation_lock:
1050
+ # get allocated address according to corresponding strategy
1051
+ address = allocate_strategy.get_allocated_address(
1052
+ self._config, self._allocated_actors
1053
+ )
1054
+ # set placeholder to make sure this label is occupied
1055
+ self._allocated_actors[address][None] = (allocate_strategy, message)
1056
+ if address == self.external_address:
1057
+ # creating actor on main actor pool
1058
+ result = await super().create_actor(message)
1059
+ if isinstance(result, ResultMessage):
1060
+ self._allocated_actors[self.external_address][result.result] = (
1061
+ allocate_strategy,
1062
+ message,
1063
+ )
1064
+ processor.result = result
1065
+ else:
1066
+ # creating actor on sub actor pool
1067
+ # rewrite allocate strategy to AddressSpecified
1068
+ new_allocate_strategy = AddressSpecified(address)
1069
+ new_create_actor_message = CreateActorMessage(
1070
+ message.message_id,
1071
+ message.actor_cls,
1072
+ message.actor_id,
1073
+ message.args,
1074
+ message.kwargs,
1075
+ allocate_strategy=new_allocate_strategy,
1076
+ from_main=True,
1077
+ protocol=message.protocol,
1078
+ message_trace=message.message_trace,
1079
+ )
1080
+ result = await self.call(address, new_create_actor_message)
1081
+ if isinstance(result, ResultMessage):
1082
+ self._allocated_actors[address][result.result] = (
1083
+ allocate_strategy,
1084
+ new_create_actor_message,
1085
+ )
1086
+ processor.result = result
1087
+
1088
+ # revert placeholder
1089
+ self._allocated_actors[address].pop(None, None)
1090
+
1091
+ return processor.result
1092
+
1093
+ @implements(AbstractActorPool.has_actor)
1094
+ async def has_actor(self, message: HasActorMessage) -> ResultMessage:
1095
+ actor_ref = message.actor_ref
1096
+ # lookup allocated
1097
+ for address, item in self._allocated_actors.items():
1098
+ ref = create_actor_ref(address, to_binary(actor_ref.uid))
1099
+ if ref in item:
1100
+ return ResultMessage(
1101
+ message.message_id, True, protocol=message.protocol
1102
+ )
1103
+
1104
+ return ResultMessage(message.message_id, False, protocol=message.protocol)
1105
+
1106
+ @implements(AbstractActorPool.destroy_actor)
1107
+ async def destroy_actor(self, message: DestroyActorMessage) -> ResultMessageType:
1108
+ actor_ref_message = ActorRefMessage(
1109
+ message.message_id, message.actor_ref, protocol=message.protocol
1110
+ )
1111
+ result = await self.actor_ref(actor_ref_message)
1112
+ if not isinstance(result, ResultMessage):
1113
+ return result
1114
+ real_actor_ref = result.result
1115
+ if real_actor_ref.address == self.external_address:
1116
+ result = await super().destroy_actor(message)
1117
+ if result.message_type == MessageType.error:
1118
+ return result
1119
+ del self._allocated_actors[self.external_address][real_actor_ref]
1120
+ return ResultMessage(
1121
+ message.message_id, real_actor_ref.uid, protocol=message.protocol
1122
+ )
1123
+ # remove allocated actor ref
1124
+ self._allocated_actors[real_actor_ref.address].pop(real_actor_ref, None)
1125
+ new_destroy_message = DestroyActorMessage(
1126
+ message.message_id,
1127
+ real_actor_ref,
1128
+ from_main=True,
1129
+ protocol=message.protocol,
1130
+ )
1131
+ return await self.call(real_actor_ref.address, new_destroy_message)
1132
+
1133
+ @implements(AbstractActorPool.send)
1134
+ async def send(self, message: SendMessage) -> ResultMessageType:
1135
+ if message.actor_ref.uid in self._actors:
1136
+ return await super().send(message)
1137
+ actor_ref_message = ActorRefMessage(
1138
+ message.message_id, message.actor_ref, protocol=message.protocol
1139
+ )
1140
+ result = await self.actor_ref(actor_ref_message)
1141
+ if not isinstance(result, ResultMessage):
1142
+ return result
1143
+ actor_ref = result.result
1144
+ new_send_message = SendMessage(
1145
+ message.message_id,
1146
+ actor_ref,
1147
+ message.content,
1148
+ protocol=message.protocol,
1149
+ message_trace=message.message_trace,
1150
+ )
1151
+ return await self.call(actor_ref.address, new_send_message)
1152
+
1153
+ @implements(AbstractActorPool.tell)
1154
+ async def tell(self, message: TellMessage) -> ResultMessageType:
1155
+ if message.actor_ref.uid in self._actors:
1156
+ return await super().tell(message)
1157
+ actor_ref_message = ActorRefMessage(
1158
+ message.message_id, message.actor_ref, protocol=message.protocol
1159
+ )
1160
+ result = await self.actor_ref(actor_ref_message)
1161
+ if not isinstance(result, ResultMessage):
1162
+ return result
1163
+ actor_ref = result.result
1164
+ new_tell_message = TellMessage(
1165
+ message.message_id,
1166
+ actor_ref,
1167
+ message.content,
1168
+ protocol=message.protocol,
1169
+ message_trace=message.message_trace,
1170
+ )
1171
+ return await self.call(actor_ref.address, new_tell_message)
1172
+
1173
+ @implements(AbstractActorPool.actor_ref)
1174
+ async def actor_ref(self, message: ActorRefMessage) -> ResultMessageType:
1175
+ actor_ref = message.actor_ref
1176
+ actor_ref.uid = to_binary(actor_ref.uid)
1177
+ if actor_ref.address == self.external_address and actor_ref.uid in self._actors:
1178
+ actor_ref.proxy_addresses = self._router.get_proxies(actor_ref.address)
1179
+ return ResultMessage(
1180
+ message.message_id, actor_ref, protocol=message.protocol
1181
+ )
1182
+
1183
+ # lookup allocated
1184
+ for address, item in self._allocated_actors.items():
1185
+ ref = create_actor_ref(address, actor_ref.uid)
1186
+ if ref in item:
1187
+ ref.proxy_addresses = self._router.get_proxies(ref.address)
1188
+ return ResultMessage(message.message_id, ref, protocol=message.protocol)
1189
+
1190
+ with _ErrorProcessor(
1191
+ self.external_address, message.message_id, protocol=message.protocol
1192
+ ) as processor:
1193
+ raise ActorNotExist(
1194
+ f"Actor {actor_ref.uid} does not exist in {actor_ref.address}"
1195
+ )
1196
+
1197
+ return processor.result
1198
+
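The main pool resolves references first against its own actors and then against the allocation table of every sub pool. A lookup sketch (fragment for a coroutine; main_pool and the uid are assumed):

    lookup = ActorRefMessage(
        new_message_id(),
        create_actor_ref(main_pool.external_address, b"my_actor_uid"),
        protocol=DEFAULT_PROTOCOL,
    )
    reply = await main_pool.actor_ref(lookup)
    # ResultMessage whose .result is an ActorRef pointing at the owning pool,
    # or an ErrorMessage wrapping ActorNotExist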
1199
+ @implements(AbstractActorPool.cancel)
1200
+ async def cancel(self, message: CancelMessage) -> ResultMessageType:
1201
+ if message.address == self.external_address:
1202
+ # local message
1203
+ return await super().cancel(message)
1204
+ # redirect to sub pool
1205
+ return await self.call(message.address, message)
1206
+
1207
+ @implements(AbstractActorPool.handle_control_command)
1208
+ async def handle_control_command(
1209
+ self, message: ControlMessage
1210
+ ) -> ResultMessageType:
1211
+ with _ErrorProcessor(
1212
+ self.external_address, message.message_id, message.protocol
1213
+ ) as processor:
1214
+ if message.address == self.external_address:
1215
+ if message.control_message_type == ControlMessageType.sync_config:
1216
+ # sync config, need to notify all sub pools
1217
+ tasks = []
1218
+ for addr in self.sub_processes:
1219
+ control_message = ControlMessage(
1220
+ new_message_id(),
1221
+ message.address,
1222
+ message.control_message_type,
1223
+ message.content,
1224
+ protocol=message.protocol,
1225
+ message_trace=message.message_trace,
1226
+ )
1227
+ tasks.append(
1228
+ asyncio.create_task(self.call(addr, control_message))
1229
+ )
1230
+ # call super
1231
+ task = asyncio.create_task(super().handle_control_command(message))
1232
+ tasks.append(task)
1233
+ await asyncio.gather(*tasks)
1234
+ processor.result = await task
1235
+ else:
1236
+ processor.result = await super().handle_control_command(message)
1237
+ elif message.control_message_type == ControlMessageType.stop:
1238
+ timeout, force = (
1239
+ message.content if message.content is not None else (None, False)
1240
+ )
1241
+ await self.stop_sub_pool(
1242
+ message.address,
1243
+ self.sub_processes[message.address],
1244
+ timeout=timeout,
1245
+ force=force,
1246
+ )
1247
+ processor.result = ResultMessage(
1248
+ message.message_id, True, protocol=message.protocol
1249
+ )
1250
+ elif message.control_message_type == ControlMessageType.wait_pool_recovered:
1251
+ if self._auto_recover and message.address not in self._recover_events:
1252
+ self._recover_events[message.address] = asyncio.Event()
1253
+
1254
+ event = self._recover_events.get(message.address, None)
1255
+ if event is not None:
1256
+ await event.wait()
1257
+ processor.result = ResultMessage(
1258
+ message.message_id, True, protocol=message.protocol
1259
+ )
1260
+ elif message.control_message_type == ControlMessageType.add_sub_pool_actor:
1261
+ address, allocate_strategy, create_message = message.content
1262
+ create_message.from_main = True
1263
+ ref = create_actor_ref(address, to_binary(create_message.actor_id))
1264
+ self._allocated_actors[address][ref] = (
1265
+ allocate_strategy,
1266
+ create_message,
1267
+ )
1268
+ processor.result = ResultMessage(
1269
+ message.message_id, True, protocol=message.protocol
1270
+ )
1271
+ else:
1272
+ processor.result = await self.call(message.address, message)
1273
+ return processor.result
1274
+
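When a stop control message addresses one of the sub pools rather than the main pool itself, the handler above unpacks an optional (timeout, force) pair and delegates to stop_sub_pool. A sketch (fragment for a coroutine; main_pool and sub_pool_address are assumed):

    stop_msg = ControlMessage(
        new_message_id(),
        sub_pool_address,              # must be a key of main_pool.sub_processes
        ControlMessageType.stop,
        (5.0, False),                  # timeout in seconds, force flag
        protocol=DEFAULT_PROTOCOL,
    )
    await main_pool.handle_control_command(stop_msg)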
1275
+ @staticmethod
1276
+ def _parse_config(config: dict, kw: dict) -> dict:
1277
+ kw["auto_recover"] = config.pop("auto_recover", "actor")
1278
+ kw["on_process_down"] = config.pop("on_process_down", None)
1279
+ kw["on_process_recover"] = config.pop("on_process_recover", None)
1280
+ kw = AbstractActorPool._parse_config(config, kw)
1281
+ return kw
1282
+
1283
+ @classmethod
1284
+ @implements(AbstractActorPool.create)
1285
+ async def create(cls, config: dict) -> MainActorPoolType:
1286
+ config = config.copy()
1287
+ actor_pool_config: ActorPoolConfig = config.get("actor_pool_config") # type: ignore
1288
+ if "process_index" not in config:
1289
+ config["process_index"] = actor_pool_config.get_process_indexes()[0]
1290
+ curr_process_index = config.get("process_index")
1291
+ old_config_addresses = set(actor_pool_config.get_external_addresses())
1292
+
1293
+ tasks = []
1294
+ subpool_process_idxes = []
1295
+ # create sub actor pools
1296
+ n_sub_pool = actor_pool_config.n_pool - 1
1297
+ if n_sub_pool > 0:
1298
+ process_indexes = actor_pool_config.get_process_indexes()
1299
+ for process_index in process_indexes:
1300
+ if process_index == curr_process_index:
1301
+ continue
1302
+ create_pool_task = asyncio.create_task(
1303
+ cls.start_sub_pool(actor_pool_config, process_index)
1304
+ )
1305
+ await asyncio.sleep(0)
1306
+ # await create_pool_task
1307
+ tasks.append(create_pool_task)
1308
+ subpool_process_idxes.append(process_index)
1309
+
1310
+ processes, ext_addresses = await cls.wait_sub_pools_ready(tasks)
1311
+ if ext_addresses:
1312
+ for process_index, ext_address in zip(subpool_process_idxes, ext_addresses):
1313
+ actor_pool_config.reset_pool_external_address(
1314
+ process_index, ext_address
1315
+ )
1316
+
1317
+ # create main actor pool
1318
+ pool: MainActorPoolType = await super().create(config)
1319
+ addresses = actor_pool_config.get_external_addresses()[1:]
1320
+
1321
+ assert len(addresses) == len(
1322
+ processes
1323
+ ), f"addresses {addresses}, processes {processes}"
1324
+ for addr, proc in zip(addresses, processes):
1325
+ pool.attach_sub_process(addr, proc)
1326
+
1327
+ new_config_addresses = set(actor_pool_config.get_external_addresses())
1328
+ if old_config_addresses != new_config_addresses:
1329
+ control_message = ControlMessage(
1330
+ message_id=new_message_id(),
1331
+ address=pool.external_address,
1332
+ control_message_type=ControlMessageType.sync_config,
1333
+ content=actor_pool_config,
1334
+ )
1335
+ await pool.handle_control_command(control_message)
1336
+
1337
+ return pool
1338
+
1339
+ async def start_monitor(self):
1340
+ # Only start monitor if there are sub processes to monitor
1341
+ # This prevents hanging when n_process=0
1342
+ if self._monitor_task is None and self.sub_processes:
1343
+ self._monitor_task = asyncio.create_task(self.monitor_sub_pools())
1344
+ return self._monitor_task
1345
+
1346
+ @implements(AbstractActorPool.stop)
1347
+ async def stop(self):
1348
+ global_router = Router.get_instance()
1349
+ if global_router is not None:
1350
+ global_router.remove_router(self._router)
1351
+
1352
+ # turn off auto recover to avoid errors
1353
+ self._auto_recover = False
1354
+ self._stopped.set()
1355
+ if self._monitor_task and not self._monitor_task.done():
1356
+ # Cancel the monitor task to ensure it exits immediately
1357
+ self._monitor_task.cancel()
1358
+ try:
1359
+ await self._monitor_task
1360
+ except asyncio.CancelledError:
1361
+ pass # Expected when cancelling the task
1362
+ self._monitor_task = None
1363
+ await self.stop_sub_pools()
1364
+ await super().stop()
1365
+
1366
+ @classmethod
1367
+ @abstractmethod
1368
+ async def start_sub_pool(
1369
+ cls,
1370
+ actor_pool_config: ActorPoolConfig,
1371
+ process_index: int,
1372
+ start_python: str | None = None,
1373
+ ):
1374
+ """Start a sub actor pool"""
1375
+
1376
+ @classmethod
1377
+ @abstractmethod
1378
+ async def wait_sub_pools_ready(cls, create_pool_tasks: list[asyncio.Task]):
1379
+ """Wait all sub pools ready"""
1380
+
1381
+ def attach_sub_process(self, external_address: str, process: SubProcessHandle):
1382
+ self.sub_processes[external_address] = process
1383
+
1384
+ async def stop_sub_pools(self):
1385
+ to_stop_processes: dict[str, SubProcessHandle] = dict() # type: ignore
1386
+ for address, process in self.sub_processes.items():
1387
+ if not await self.is_sub_pool_alive(process):
1388
+ continue
1389
+ to_stop_processes[address] = process
1390
+
1391
+ tasks = []
1392
+ for address, process in to_stop_processes.items():
1393
+ tasks.append(self.stop_sub_pool(address, process))
1394
+ await asyncio.gather(*tasks)
1395
+
1396
+ async def stop_sub_pool(
1397
+ self,
1398
+ address: str,
1399
+ process: SubProcessHandle,
1400
+ timeout: float | None = None,
1401
+ force: bool = False,
1402
+ ):
1403
+ if force:
1404
+ await self.kill_sub_pool(process, force=True)
1405
+ return
1406
+
1407
+ stop_message = ControlMessage(
1408
+ new_message_id(),
1409
+ address,
1410
+ ControlMessageType.stop,
1411
+ None,
1412
+ protocol=DEFAULT_PROTOCOL,
1413
+ )
1414
+ try:
1415
+ if timeout is None:
1416
+ # Use a short timeout for graceful shutdown to avoid hanging
1417
+ timeout = 2.0
1418
+
1419
+ call = asyncio.create_task(self.call(address, stop_message))
1420
+ try:
1421
+ await asyncio.wait_for(call, timeout)
1422
+ except (futures.TimeoutError, asyncio.TimeoutError):
1423
+ force = True
1424
+ except (ConnectionError, ServerClosed):
1425
+ # the process may already be dead; ignore it
1426
+ force = True
1427
+ # kill process
1428
+ await self.kill_sub_pool(process, force=force)
1429
+
1430
+ @abstractmethod
1431
+ async def kill_sub_pool(self, process: SubProcessHandle, force: bool = False):
1432
+ """Kill a sub actor pool"""
1433
+
1434
+ @abstractmethod
1435
+ async def is_sub_pool_alive(self, process: SubProcessHandle):
1436
+ """
1437
+ Check whether sub pool process is alive
1438
+ Parameters
1439
+ ----------
1440
+ process : SubProcessHandle
1441
+ sub pool process handle
1442
+ Returns
1443
+ -------
1444
+ bool
1445
+ """
1446
+
1447
+ @abstractmethod
1448
+ def recover_sub_pool(self, address):
1449
+ """Recover a sub actor pool"""
1450
+
1451
+ def process_sub_pool_lost(self, address: str):
1452
+ if self._auto_recover in (False, "process"):
1453
+ # the process is down; when auto_recover is disabled or only recovers
1454
+ # the process itself, remove all actors created on it
1455
+ self._allocated_actors[address] = dict()
1456
+
1457
+ async def monitor_sub_pools(self):
1458
+ try:
1459
+ while not self._stopped.is_set():
1460
+ # Copy sub_processes to avoid changes during recover.
1461
+ for address, process in list(self.sub_processes.items()):
1462
+ try:
1463
+ recover_events_discovered = address in self._recover_events
1464
+ if not await self.is_sub_pool_alive(
1465
+ process
1466
+ ): # pragma: no cover
1467
+ if self._on_process_down is not None:
1468
+ self._on_process_down(self, address)
1469
+ self.process_sub_pool_lost(address)
1470
+ if self._auto_recover:
1471
+ await self.recover_sub_pool(address)
1472
+ if self._on_process_recover is not None:
1473
+ self._on_process_recover(self, address)
1474
+ if recover_events_discovered:
1475
+ event = self._recover_events.pop(address)
1476
+ event.set()
1477
+ except asyncio.CancelledError:
1478
+ raise
1479
+ except RuntimeError as ex: # pragma: no cover
1480
+ if "cannot schedule new futures" not in str(ex):
1481
+ # to silence log when process exit, otherwise it
1482
+ # will raise "RuntimeError: cannot schedule new futures
1483
+ # after interpreter shutdown".
1484
+ logger.exception("Monitor sub pool %s failed", address)
1485
+ except Exception:
1486
+ # log the exception instead of stop monitoring the
1487
+ # sub pool silently.
1488
+ logger.exception("Monitor sub pool %s failed", address)
1489
+
1490
+ # check every half second
1491
+ await asyncio.sleep(0.5)
1492
+ except asyncio.CancelledError: # pragma: no cover
1493
+ # cancelled
1494
+ return
1495
+
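monitor_sub_pools pairs with the wait_pool_recovered control message handled earlier: a caller can block until a crashed sub pool has been restarted before resending work. Sketch (fragment for a coroutine; main_pool and sub_pool_address are assumed):

    wait_msg = ControlMessage(
        new_message_id(),
        sub_pool_address,
        ControlMessageType.wait_pool_recovered,
        None,
        protocol=DEFAULT_PROTOCOL,
    )
    await main_pool.handle_control_command(wait_msg)   # returns once the pool is back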
1496
+ @classmethod
1497
+ @abstractmethod
1498
+ def get_external_addresses(
1499
+ cls,
1500
+ address: str,
1501
+ n_process: int | None = None,
1502
+ ports: list[int] | None = None,
1503
+ schemes: list[Optional[str]] | None = None,
1504
+ ):
1505
+ """Returns external addresses for n pool processes"""
1506
+
1507
+ @classmethod
1508
+ @abstractmethod
1509
+ def gen_internal_address(
1510
+ cls, process_index: int, external_address: str | None = None
1511
+ ) -> str | None:
1512
+ """Returns internal address for pool of specified process index"""
1513
+
1514
+
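MainActorPoolBase leaves process management to its backends: a concrete pool supplies start_sub_pool, wait_sub_pools_ready, kill_sub_pool, is_sub_pool_alive, recover_sub_pool and the two address helpers (the shipped indigen backend does this with real OS processes). A schematic skeleton under those assumptions, not the real implementation:

    from xoscar.backends.pool import MainActorPoolBase

    class SchematicMainPool(MainActorPoolBase):
        @classmethod
        async def start_sub_pool(cls, actor_pool_config, process_index, start_python=None):
            ...  # spawn a worker for this process index and return a handle

        @classmethod
        async def wait_sub_pools_ready(cls, create_pool_tasks):
            ...  # return (process handles, external addresses)

        async def kill_sub_pool(self, process, force=False):
            ...

        async def is_sub_pool_alive(self, process):
            ...

        def recover_sub_pool(self, address):
            ...

        @classmethod
        def get_external_addresses(cls, address, n_process=None, ports=None, schemes=None):
            ...  # one address for the main pool plus one per sub process

        @classmethod
        def gen_internal_address(cls, process_index, external_address=None):
            ...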
1515
+ async def create_actor_pool(
1516
+ address: str,
1517
+ *,
1518
+ pool_cls: Type[MainActorPoolType] | None = None,
1519
+ n_process: int | None = None,
1520
+ labels: list[str] | None = None,
1521
+ ports: list[int] | None = None,
1522
+ envs: list[dict] | None = None,
1523
+ external_address_schemes: list[Optional[str]] | None = None,
1524
+ enable_internal_addresses: list[bool] | None = None,
1525
+ auto_recover: str | bool = "actor",
1526
+ modules: list[str] | None = None,
1527
+ suspend_sigint: bool | None = None,
1528
+ use_uvloop: str | bool = "auto",
1529
+ logging_conf: dict | None = None,
1530
+ proxy_conf: dict | None = None,
1531
+ on_process_down: Callable[[MainActorPoolType, str], None] | None = None,
1532
+ on_process_recover: Callable[[MainActorPoolType, str], None] | None = None,
1533
+ extra_conf: dict | None = None,
1534
+ **kwargs,
1535
+ ) -> MainActorPoolType:
1536
+ if n_process is None:
1537
+ n_process = multiprocessing.cpu_count()
1538
+ if labels and len(labels) != n_process + 1:
1539
+ raise ValueError(
1540
+ f"`labels` should be of size {n_process + 1}, got {len(labels)}"
1541
+ )
1542
+ if envs and len(envs) != n_process:
1543
+ raise ValueError(f"`envs` should be of size {n_process}, got {len(envs)}")
1544
+ if external_address_schemes and len(external_address_schemes) != n_process + 1:
1545
+ raise ValueError(
1546
+ f"`external_address_schemes` should be of size {n_process + 1}, "
1547
+ f"got {len(external_address_schemes)}"
1548
+ )
1549
+ if enable_internal_addresses and len(enable_internal_addresses) != n_process + 1:
1550
+ raise ValueError(
1551
+ f"`enable_internal_addresses` should be of size {n_process + 1}, "
1552
+ f"got {len(enable_internal_addresses)}"
1553
+ )
1554
+ elif not enable_internal_addresses:
1555
+ enable_internal_addresses = [True] * (n_process + 1)
1556
+ if auto_recover is True:
1557
+ auto_recover = "actor"
1558
+ if auto_recover not in ("actor", "process", False):
1559
+ raise ValueError(
1560
+ f'`auto_recover` should be one of "actor", "process", '
1561
+ f"True or False, got {auto_recover}"
1562
+ )
1563
+ if use_uvloop == "auto":
1564
+ try:
1565
+ import uvloop # noqa: F401 # pylint: disable=unused-variable
1566
+
1567
+ use_uvloop = True
1568
+ except ImportError:
1569
+ use_uvloop = False
1570
+
1571
+ assert pool_cls is not None
1572
+ external_addresses = pool_cls.get_external_addresses(
1573
+ address, n_process=n_process, ports=ports, schemes=external_address_schemes
1574
+ )
1575
+ actor_pool_config = ActorPoolConfig()
1576
+ actor_pool_config.add_metric_configs(kwargs.get("metrics", {}))
1577
+ # add proxy config
1578
+ actor_pool_config.add_proxy_config(proxy_conf)
1579
+ # add main config
1580
+ process_index_gen = pool_cls.process_index_gen(address)
1581
+ main_process_index = next(process_index_gen)
1582
+ main_internal_address = (
1583
+ pool_cls.gen_internal_address(main_process_index, external_addresses[0])
1584
+ if enable_internal_addresses[0]
1585
+ else None
1586
+ )
1587
+ actor_pool_config.add_pool_conf(
1588
+ main_process_index,
1589
+ labels[0] if labels else None,
1590
+ main_internal_address,
1591
+ external_addresses[0],
1592
+ modules=modules,
1593
+ suspend_sigint=suspend_sigint,
1594
+ use_uvloop=use_uvloop, # type: ignore
1595
+ logging_conf=logging_conf,
1596
+ kwargs=kwargs,
1597
+ )
1598
+ # add sub configs
1599
+ for i in range(n_process):
1600
+ sub_process_index = next(process_index_gen)
1601
+ internal_address = (
1602
+ pool_cls.gen_internal_address(sub_process_index, external_addresses[i + 1])
1603
+ if enable_internal_addresses[i + 1]
1604
+ else None
1605
+ )
1606
+ actor_pool_config.add_pool_conf(
1607
+ sub_process_index,
1608
+ labels[i + 1] if labels else None,
1609
+ internal_address,
1610
+ external_addresses[i + 1],
1611
+ env=envs[i] if envs else None,
1612
+ modules=modules,
1613
+ suspend_sigint=suspend_sigint,
1614
+ use_uvloop=use_uvloop, # type: ignore
1615
+ logging_conf=logging_conf,
1616
+ kwargs=kwargs,
1617
+ )
1618
+ actor_pool_config.add_comm_config(extra_conf)
1619
+
1620
+ pool: MainActorPoolType = await pool_cls.create(
1621
+ {
1622
+ "actor_pool_config": actor_pool_config,
1623
+ "process_index": main_process_index,
1624
+ "auto_recover": auto_recover,
1625
+ "on_process_down": on_process_down,
1626
+ "on_process_recover": on_process_recover,
1627
+ }
1628
+ )
1629
+ await pool.start()
1630
+ return pool
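A rough end-to-end usage sketch of create_actor_pool as defined above, assuming the MainActorPool class from xoscar.backends.indigen.pool (the process-based backend shipped in this wheel); the address and process count are illustrative:

    import asyncio

    from xoscar.backends.indigen.pool import MainActorPool
    from xoscar.backends.pool import create_actor_pool

    async def main():
        pool = await create_actor_pool(
            "127.0.0.1:11111",
            pool_cls=MainActorPool,
            n_process=2,
        )
        async with pool:
            await pool.join(timeout=1.0)

    asyncio.run(main())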