cocoindex-0.2.3-cp311-abi3-macosx_10_12_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cocoindex/__init__.py +92 -0
- cocoindex/_engine.abi3.so +0 -0
- cocoindex/auth_registry.py +51 -0
- cocoindex/cli.py +697 -0
- cocoindex/convert.py +621 -0
- cocoindex/flow.py +1205 -0
- cocoindex/functions.py +357 -0
- cocoindex/index.py +29 -0
- cocoindex/lib.py +32 -0
- cocoindex/llm.py +46 -0
- cocoindex/op.py +628 -0
- cocoindex/py.typed +0 -0
- cocoindex/runtime.py +37 -0
- cocoindex/setting.py +181 -0
- cocoindex/setup.py +92 -0
- cocoindex/sources.py +102 -0
- cocoindex/subprocess_exec.py +279 -0
- cocoindex/targets.py +135 -0
- cocoindex/tests/__init__.py +0 -0
- cocoindex/tests/conftest.py +38 -0
- cocoindex/tests/test_convert.py +1543 -0
- cocoindex/tests/test_optional_database.py +249 -0
- cocoindex/tests/test_transform_flow.py +207 -0
- cocoindex/tests/test_typing.py +429 -0
- cocoindex/tests/test_validation.py +134 -0
- cocoindex/typing.py +473 -0
- cocoindex/user_app_loader.py +51 -0
- cocoindex/utils.py +20 -0
- cocoindex/validation.py +104 -0
- cocoindex-0.2.3.dist-info/METADATA +262 -0
- cocoindex-0.2.3.dist-info/RECORD +34 -0
- cocoindex-0.2.3.dist-info/WHEEL +4 -0
- cocoindex-0.2.3.dist-info/entry_points.txt +2 -0
- cocoindex-0.2.3.dist-info/licenses/LICENSE +201 -0
cocoindex/setting.py
ADDED
@@ -0,0 +1,181 @@
"""
Data types for settings of the cocoindex library.
"""

import os

from typing import Callable, Self, Any, overload
from dataclasses import dataclass
from .validation import validate_app_namespace_name

_app_namespace: str = ""


def get_app_namespace(*, trailing_delimiter: str | None = None) -> str:
    """Get the application namespace. Append the `trailing_delimiter` if not empty."""
    if _app_namespace == "" or trailing_delimiter is None:
        return _app_namespace
    return f"{_app_namespace}{trailing_delimiter}"


def split_app_namespace(full_name: str, delimiter: str) -> tuple[str, str]:
    """Split the full name into the application namespace and the rest."""
    parts = full_name.split(delimiter, 1)
    if len(parts) == 1:
        return "", parts[0]
    return (parts[0], parts[1])


def set_app_namespace(app_namespace: str) -> None:
    """Set the application namespace."""
    if app_namespace:
        validate_app_namespace_name(app_namespace)
    global _app_namespace  # pylint: disable=global-statement
    _app_namespace = app_namespace


@dataclass
class DatabaseConnectionSpec:
    """
    Connection spec for relational database.
    Used by both internal and target storage.
    """

    url: str
    user: str | None = None
    password: str | None = None
    max_connections: int = 25
    min_connections: int = 5


@dataclass
class GlobalExecutionOptions:
    """Global execution options."""

    # The maximum number of concurrent inflight requests, shared among all sources from all flows.
    source_max_inflight_rows: int | None = 1024
    source_max_inflight_bytes: int | None = None


def _load_field(
    target: dict[str, Any],
    name: str,
    env_name: str,
    required: bool = False,
    parse: Callable[[str], Any] | None = None,
) -> None:
    value = os.getenv(env_name)
    if value is None:
        if required:
            raise ValueError(f"{env_name} is not set")
    else:
        if parse is None:
            target[name] = value
        else:
            try:
                target[name] = parse(value)
            except Exception as e:
                raise ValueError(
                    f"failed to parse environment variable {env_name}: {value}"
                ) from e


@dataclass
class Settings:
    """Settings for the cocoindex library."""

    database: DatabaseConnectionSpec | None = None
    app_namespace: str = ""
    global_execution_options: GlobalExecutionOptions | None = None

    @classmethod
    def from_env(cls) -> Self:
        """Load settings from environment variables."""

        database_url = os.getenv("COCOINDEX_DATABASE_URL")
        if database_url is not None:
            db_kwargs: dict[str, Any] = dict()
            _load_field(db_kwargs, "url", "COCOINDEX_DATABASE_URL", required=True)
            _load_field(db_kwargs, "user", "COCOINDEX_DATABASE_USER")
            _load_field(db_kwargs, "password", "COCOINDEX_DATABASE_PASSWORD")
            _load_field(
                db_kwargs,
                "max_connections",
                "COCOINDEX_DATABASE_MAX_CONNECTIONS",
                parse=int,
            )
            _load_field(
                db_kwargs,
                "min_connections",
                "COCOINDEX_DATABASE_MIN_CONNECTIONS",
                parse=int,
            )
            database = DatabaseConnectionSpec(**db_kwargs)
        else:
            database = None

        exec_kwargs: dict[str, Any] = dict()
        _load_field(
            exec_kwargs,
            "source_max_inflight_rows",
            "COCOINDEX_SOURCE_MAX_INFLIGHT_ROWS",
            parse=int,
        )
        _load_field(
            exec_kwargs,
            "source_max_inflight_bytes",
            "COCOINDEX_SOURCE_MAX_INFLIGHT_BYTES",
            parse=int,
        )
        global_execution_options = GlobalExecutionOptions(**exec_kwargs)

        app_namespace = os.getenv("COCOINDEX_APP_NAMESPACE", "")

        return cls(
            database=database,
            app_namespace=app_namespace,
            global_execution_options=global_execution_options,
        )


@dataclass
class ServerSettings:
    """Settings for the cocoindex server."""

    # The address to bind the server to.
    address: str = "127.0.0.1:49344"

    # The origins of the clients (e.g. CocoInsight UI) to allow CORS from.
    cors_origins: list[str] | None = None

    @classmethod
    def from_env(cls) -> Self:
        """Load settings from environment variables."""
        kwargs: dict[str, Any] = dict()
        _load_field(kwargs, "address", "COCOINDEX_SERVER_ADDRESS")
        _load_field(
            kwargs,
            "cors_origins",
            "COCOINDEX_SERVER_CORS_ORIGINS",
            parse=ServerSettings.parse_cors_origins,
        )
        return cls(**kwargs)

    @overload
    @staticmethod
    def parse_cors_origins(s: str) -> list[str]: ...

    @overload
    @staticmethod
    def parse_cors_origins(s: str | None) -> list[str] | None: ...

    @staticmethod
    def parse_cors_origins(s: str | None) -> list[str] | None:
        """
        Parse the CORS origins from a string.
        """
        return (
            [o for e in s.split(",") if (o := e.strip()) != ""]
            if s is not None
            else None
        )
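As a usage sketch (not part of the diff), the following shows how the environment variables above feed `Settings.from_env()`; the variable values are illustrative, not defaults shipped by the package:

import os
from cocoindex.setting import Settings

# Hypothetical environment, e.g. exported by the shell or a .env loader.
os.environ["COCOINDEX_DATABASE_URL"] = "postgres://localhost:5432/cocoindex"
os.environ["COCOINDEX_DATABASE_MAX_CONNECTIONS"] = "50"
os.environ["COCOINDEX_APP_NAMESPACE"] = "demo"

settings = Settings.from_env()
assert settings.database is not None
assert settings.database.max_connections == 50  # parsed via parse=int
assert settings.app_namespace == "demo"
# Unset optional fields fall back to the dataclass defaults (e.g. min_connections=5).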
cocoindex/setup.py
ADDED
@@ -0,0 +1,92 @@
"""
This module provides APIs to manage the setup of flows.
"""

from . import setting
from . import _engine  # type: ignore
from .runtime import execution_context


class SetupChangeBundle:
    """
    This class represents a bundle of setup changes.
    """

    _engine_bundle: _engine.SetupChangeBundle

    def __init__(self, _engine_bundle: _engine.SetupChangeBundle):
        self._engine_bundle = _engine_bundle

    def __str__(self) -> str:
        desc, _ = execution_context.run(self._engine_bundle.describe_async())
        return desc  # type: ignore

    def __repr__(self) -> str:
        return self.__str__()

    def apply(self, report_to_stdout: bool = False) -> None:
        """
        Apply the setup changes.
        """
        execution_context.run(self.apply_async(report_to_stdout=report_to_stdout))

    async def apply_async(self, report_to_stdout: bool = False) -> None:
        """
        Apply the setup changes. Async version of `apply`.
        """
        await self._engine_bundle.apply_async(report_to_stdout=report_to_stdout)

    def describe(self) -> tuple[str, bool]:
        """
        Describe the setup changes.
        """
        return execution_context.run(self.describe_async())  # type: ignore

    async def describe_async(self) -> tuple[str, bool]:
        """
        Describe the setup changes. Async version of `describe`.
        """
        return await self._engine_bundle.describe_async()  # type: ignore

    def describe_and_apply(self, report_to_stdout: bool = False) -> None:
        """
        Describe the setup changes and apply them if `report_to_stdout` is True.
        Silently apply setup changes otherwise.
        """
        execution_context.run(
            self.describe_and_apply_async(report_to_stdout=report_to_stdout)
        )

    async def describe_and_apply_async(self, *, report_to_stdout: bool = False) -> None:
        """
        Describe the setup changes and apply them if `report_to_stdout` is True.
        Silently apply setup changes otherwise. Async version of `describe_and_apply`.
        """
        if report_to_stdout:
            desc, is_up_to_date = await self.describe_async()
            print("Setup status:\n")
            print(desc)
            if is_up_to_date:
                print("No setup changes to apply.")
                return
        await self.apply_async(report_to_stdout=report_to_stdout)


def flow_names_with_setup() -> list[str]:
    """
    Get the names of all flows that have been setup.
    """
    return execution_context.run(flow_names_with_setup_async())  # type: ignore


async def flow_names_with_setup_async() -> list[str]:
    """
    Get the names of all flows that have been setup. Async version of `flow_names_with_setup`.
    """
    result = []
    all_flow_names = await _engine.flow_names_with_setup_async()
    for name in all_flow_names:
        app_namespace, name = setting.split_app_namespace(name, ".")
        if app_namespace == setting.get_app_namespace():
            result.append(name)
    return result
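A hedged usage sketch for this module (not part of the diff): listing flows with existing setup and reviewing a `SetupChangeBundle`. How a bundle is obtained (e.g. from the flow APIs in flow.py) and how the library is initialized are assumptions here, not shown in this file:

from cocoindex.setup import SetupChangeBundle, flow_names_with_setup

# Flows already set up under the current app namespace
# (filtered as in flow_names_with_setup_async above).
print(flow_names_with_setup())

def review_and_apply(bundle: SetupChangeBundle) -> None:
    # describe() returns (description, is_up_to_date).
    desc, is_up_to_date = bundle.describe()
    print(desc)
    if not is_up_to_date:
        bundle.apply(report_to_stdout=True)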
cocoindex/sources.py
ADDED
@@ -0,0 +1,102 @@
"""All builtin sources."""

from . import op
from .auth_registry import TransientAuthEntryReference
from .setting import DatabaseConnectionSpec
from dataclasses import dataclass
import datetime


class LocalFile(op.SourceSpec):
    """Import data from local file system."""

    _op_category = op.OpCategory.SOURCE

    path: str
    binary: bool = False

    # If provided, only files matching these patterns will be included.
    # See https://docs.rs/globset/latest/globset/index.html#syntax for the syntax of the patterns.
    included_patterns: list[str] | None = None

    # If provided, files matching these patterns will be excluded.
    # See https://docs.rs/globset/latest/globset/index.html#syntax for the syntax of the patterns.
    excluded_patterns: list[str] | None = None


class GoogleDrive(op.SourceSpec):
    """Import data from Google Drive."""

    _op_category = op.OpCategory.SOURCE

    service_account_credential_path: str
    root_folder_ids: list[str]
    binary: bool = False
    recent_changes_poll_interval: datetime.timedelta | None = None


class AmazonS3(op.SourceSpec):
    """Import data from an Amazon S3 bucket. Supports optional prefix and file filtering by glob patterns."""

    _op_category = op.OpCategory.SOURCE

    bucket_name: str
    prefix: str | None = None
    binary: bool = False
    included_patterns: list[str] | None = None
    excluded_patterns: list[str] | None = None
    sqs_queue_url: str | None = None


class AzureBlob(op.SourceSpec):
    """
    Import data from an Azure Blob Storage container. Supports optional prefix and file filtering by glob patterns.

    Authentication mechanisms taken in the following order:
    - SAS token (if provided)
    - Account access key (if provided)
    - Default Azure credential
    """

    _op_category = op.OpCategory.SOURCE

    account_name: str
    container_name: str
    prefix: str | None = None
    binary: bool = False
    included_patterns: list[str] | None = None
    excluded_patterns: list[str] | None = None

    sas_token: TransientAuthEntryReference[str] | None = None
    account_access_key: TransientAuthEntryReference[str] | None = None


@dataclass
class PostgresNotification:
    """Notification for a PostgreSQL table."""

    # Optional: name of the PostgreSQL channel to use.
    # If not provided, will generate a default channel name.
    channel_name: str | None = None


class Postgres(op.SourceSpec):
    """Import data from a PostgreSQL table."""

    _op_category = op.OpCategory.SOURCE

    # Table name to read from (required)
    table_name: str

    # Database connection reference (optional - uses default if not provided)
    database: TransientAuthEntryReference[DatabaseConnectionSpec] | None = None

    # Optional: specific columns to include (if None, includes all columns)
    included_columns: list[str] | None = None

    # Optional: column name to use for ordinal tracking (for incremental updates)
    # Should be a timestamp, serial, or other incrementing column
    ordinal_column: str | None = None

    # Optional: when set, supports change capture from PostgreSQL notification.
    notification: PostgresNotification | None = None
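For illustration only, a sketch of declaring source specs from this module. Attaching them to a flow (e.g. via the flow-building APIs in flow.py) happens outside this file, and keyword construction of the declared fields is assumed from the `op.SourceSpec` base:

from cocoindex import sources

# Spec objects are plain declarations; fields mirror the class attributes above.
local_docs = sources.LocalFile(
    path="docs",
    included_patterns=["*.md", "*.mdx"],  # globset syntax, per the comments above
    excluded_patterns=["**/node_modules/**"],
)

pg_table = sources.Postgres(
    table_name="source_messages",
    ordinal_column="updated_at",  # enables incremental updates
    notification=sources.PostgresNotification(),  # change capture via notifications
)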
cocoindex/subprocess_exec.py
ADDED
@@ -0,0 +1,279 @@
"""
Lightweight subprocess-backed executor stub.

- Uses a single global ProcessPoolExecutor (max_workers=1), created lazily.
- In the subprocess, maintains a registry of executor instances keyed by
  (executor_factory, pickled spec) to enable reuse.
- Caches analyze() and prepare() results per key to avoid repeated calls
  even if key collision happens.
"""

from __future__ import annotations

from concurrent.futures import ProcessPoolExecutor
from concurrent.futures.process import BrokenProcessPool
from dataclasses import dataclass, field
from typing import Any, Callable
import pickle
import threading
import asyncio
import os
import time
import atexit
from .user_app_loader import load_user_app
from .runtime import execution_context
import logging
import multiprocessing as mp

WATCHDOG_INTERVAL_SECONDS = 10.0

# ---------------------------------------------
# Main process: single, lazily-created pool
# ---------------------------------------------
_pool_lock = threading.Lock()
_pool: ProcessPoolExecutor | None = None
_pool_cleanup_registered = False
_user_apps: list[str] = []
_logger = logging.getLogger(__name__)


def shutdown_pool_at_exit() -> None:
    """Best-effort shutdown of the global ProcessPoolExecutor on interpreter exit."""
    global _pool, _pool_cleanup_registered  # pylint: disable=global-statement
    with _pool_lock:
        if _pool is not None:
            try:
                _pool.shutdown(wait=True, cancel_futures=True)
            except Exception as e:
                _logger.error(
                    "Error during ProcessPoolExecutor shutdown at exit: %s",
                    e,
                    exc_info=True,
                )
            finally:
                _pool = None
                _pool_cleanup_registered = False


def _get_pool() -> ProcessPoolExecutor:
    global _pool, _pool_cleanup_registered  # pylint: disable=global-statement
    with _pool_lock:
        if _pool is None:
            if not _pool_cleanup_registered:
                # Register the shutdown at exit at creation time (rather than at import time)
                # to make sure it's executed earlier in the shutdown sequence.
                atexit.register(shutdown_pool_at_exit)
                _pool_cleanup_registered = True

            # Single worker process as requested
            _pool = ProcessPoolExecutor(
                max_workers=1,
                initializer=_subprocess_init,
                initargs=(_user_apps, os.getpid()),
                mp_context=mp.get_context("spawn"),
            )
        return _pool


def add_user_app(app_target: str) -> None:
    with _pool_lock:
        _user_apps.append(app_target)


def _restart_pool(old_pool: ProcessPoolExecutor | None = None) -> None:
    """Safely restart the global ProcessPoolExecutor.

    Thread-safe via `_pool_lock`. Shuts down the old pool and re-creates a new
    one with the same initializer/args.
    """
    global _pool
    with _pool_lock:
        # If another thread already swapped the pool, skip restart
        if old_pool is not None and _pool is not old_pool:
            return
        _logger.error("Detected dead subprocess pool; restarting and retrying.")
        prev_pool = _pool
        _pool = ProcessPoolExecutor(
            max_workers=1,
            initializer=_subprocess_init,
            initargs=(_user_apps, os.getpid()),
            mp_context=mp.get_context("spawn"),
        )
        if prev_pool is not None:
            # Best-effort shutdown of previous pool; letting exceptions bubble up
            # is acceptable here and signals irrecoverable executor state.
            prev_pool.shutdown(cancel_futures=True)


async def _submit_with_restart(fn: Callable[..., Any], *args: Any) -> Any:
    """Submit and await work, restarting the subprocess until it succeeds.

    Retries on BrokenProcessPool or pool-shutdown RuntimeError; re-raises other
    exceptions.
    """
    while True:
        pool = _get_pool()
        try:
            fut = pool.submit(fn, *args)
            return await asyncio.wrap_future(fut)
        except BrokenProcessPool:
            _restart_pool(old_pool=pool)
            # loop and retry


# ---------------------------------------------
# Subprocess: executor registry and helpers
# ---------------------------------------------


def _start_parent_watchdog(
    parent_pid: int, interval_seconds: float = WATCHDOG_INTERVAL_SECONDS
) -> None:
    """Terminate this process if the parent process exits or PPID changes.

    This runs in a background daemon thread so it never blocks pool work.
    """

    def _watch() -> None:
        while True:
            # If PPID changed (parent died and we were reparented), exit.
            if os.getppid() != parent_pid:
                os._exit(1)

            # Best-effort liveness probe in case PPID was reused.
            try:
                os.kill(parent_pid, 0)
            except OSError:
                os._exit(1)

            time.sleep(interval_seconds)

    threading.Thread(target=_watch, name="parent-watchdog", daemon=True).start()


def _subprocess_init(user_apps: list[str], parent_pid: int) -> None:
    _start_parent_watchdog(parent_pid)

    # In case any user app is already in this subprocess, e.g. the subprocess is forked, we need to avoid loading it again.
    with _pool_lock:
        already_loaded_apps = set(_user_apps)

    loaded_apps = []
    for app_target in user_apps:
        if app_target not in already_loaded_apps:
            load_user_app(app_target)
            loaded_apps.append(app_target)

    with _pool_lock:
        _user_apps.extend(loaded_apps)


class _OnceResult:
    _result: Any = None
    _done: bool = False

    def run_once(self, method: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
        if self._done:
            return self._result
        self._result = _call_method(method, *args, **kwargs)
        self._done = True
        return self._result


@dataclass
class _ExecutorEntry:
    executor: Any
    prepare: _OnceResult = field(default_factory=_OnceResult)
    analyze: _OnceResult = field(default_factory=_OnceResult)
    ready_to_call: bool = False


_SUBPROC_EXECUTORS: dict[bytes, _ExecutorEntry] = {}


def _call_method(method: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
    """Run an awaitable/coroutine to completion synchronously, otherwise return as-is."""
    if asyncio.iscoroutinefunction(method):
        return asyncio.run(method(*args, **kwargs))
    else:
        return method(*args, **kwargs)


def _get_or_create_entry(key_bytes: bytes) -> _ExecutorEntry:
    entry = _SUBPROC_EXECUTORS.get(key_bytes)
    if entry is None:
        executor_factory, spec = pickle.loads(key_bytes)
        inst = executor_factory()
        inst.spec = spec
        entry = _ExecutorEntry(executor=inst)
        _SUBPROC_EXECUTORS[key_bytes] = entry
    return entry


def _sp_analyze(key_bytes: bytes) -> Any:
    entry = _get_or_create_entry(key_bytes)
    return entry.analyze.run_once(entry.executor.analyze)


def _sp_prepare(key_bytes: bytes) -> Any:
    entry = _get_or_create_entry(key_bytes)
    return entry.prepare.run_once(entry.executor.prepare)


def _sp_call(key_bytes: bytes, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any:
    entry = _get_or_create_entry(key_bytes)
    # There's a chance that the subprocess crashes and restarts in the middle.
    # So we want to always make sure the executor is ready before each call.
    if not entry.ready_to_call:
        if analyze_fn := getattr(entry.executor, "analyze", None):
            entry.analyze.run_once(analyze_fn)
        if prepare_fn := getattr(entry.executor, "prepare", None):
            entry.prepare.run_once(prepare_fn)
        entry.ready_to_call = True
    return _call_method(entry.executor.__call__, *args, **kwargs)


# ---------------------------------------------
# Public stub
# ---------------------------------------------


class _ExecutorStub:
    _key_bytes: bytes

    def __init__(self, executor_factory: type[Any], spec: Any) -> None:
        self._key_bytes = pickle.dumps(
            (executor_factory, spec), protocol=pickle.HIGHEST_PROTOCOL
        )

        # Conditionally expose analyze if underlying class has it
        if hasattr(executor_factory, "analyze"):
            # Bind as attribute so getattr(..., "analyze", None) works upstream
            def analyze() -> Any:
                return execution_context.run(
                    _submit_with_restart(_sp_analyze, self._key_bytes)
                )

            # Attach method
            setattr(self, "analyze", analyze)

        if hasattr(executor_factory, "prepare"):

            async def prepare() -> Any:
                return await _submit_with_restart(_sp_prepare, self._key_bytes)

            setattr(self, "prepare", prepare)

    async def __call__(self, *args: Any, **kwargs: Any) -> Any:
        return await _submit_with_restart(_sp_call, self._key_bytes, args, kwargs)


def executor_stub(executor_factory: type[Any], spec: Any) -> Any:
    """
    Create a subprocess-backed stub for the given executor class/spec.

    - Lazily initializes a singleton ProcessPoolExecutor (max_workers=1).
    - Returns a stub object exposing async __call__ and async prepare; analyze is
      exposed if present on the original class.
    """
    return _ExecutorStub(executor_factory, spec)
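A rough usage sketch (not part of the diff): wrapping a hypothetical executor class with `executor_stub`. The class must be importable in the spawned worker (e.g. defined in a user app registered via `add_user_app`), and both the factory and spec must be picklable since they form the registry key:

from cocoindex.subprocess_exec import executor_stub

# Hypothetical executor class; any class with a sync/async __call__ (and
# optionally analyze()/prepare()) fits the registry pattern above.
class MyExecutor:
    spec: dict  # assigned by _get_or_create_entry in the worker process

    def prepare(self) -> None:
        # e.g. load a model once per worker process
        ...

    def __call__(self, text: str) -> int:
        return len(text)

stub = executor_stub(MyExecutor, spec={"model": "small"})
# From an async context:
#   await stub.prepare()        # runs _sp_prepare in the worker process
#   result = await stub("abc")  # runs _sp_call in the worker process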