python-plugin 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyplugin/__init__.py +54 -0
- pyplugin/_generated/__init__.py +0 -0
- pyplugin/_generated/grpc_broker_grpc.py +40 -0
- pyplugin/_generated/grpc_broker_pb2.py +41 -0
- pyplugin/_generated/grpc_controller_grpc.py +40 -0
- pyplugin/_generated/grpc_controller_pb2.py +39 -0
- pyplugin/_generated/grpc_stdio_grpc.py +41 -0
- pyplugin/_generated/grpc_stdio_pb2.py +42 -0
- pyplugin/broker.py +225 -0
- pyplugin/client.py +399 -0
- pyplugin/controller.py +22 -0
- pyplugin/cookie.py +43 -0
- pyplugin/errors.py +39 -0
- pyplugin/handshake.py +121 -0
- pyplugin/health.py +38 -0
- pyplugin/logging_bridge.py +70 -0
- pyplugin/mtls.py +169 -0
- pyplugin/plugin.py +38 -0
- pyplugin/process.py +66 -0
- pyplugin/proto/grpc_broker.proto +21 -0
- pyplugin/proto/grpc_controller.proto +12 -0
- pyplugin/proto/grpc_stdio.proto +22 -0
- pyplugin/reattach.py +27 -0
- pyplugin/server.py +204 -0
- pyplugin/stdio.py +36 -0
- pyplugin/transport.py +103 -0
- python_plugin-0.1.0.dist-info/METADATA +254 -0
- python_plugin-0.1.0.dist-info/RECORD +30 -0
- python_plugin-0.1.0.dist-info/WHEEL +4 -0
- python_plugin-0.1.0.dist-info/licenses/LICENSE +21 -0
pyplugin/client.py
ADDED
|
@@ -0,0 +1,399 @@
|
|
|
1
|
+
"""Host-side ``Client`` — spawns a plugin, performs handshake, dispenses stubs (grpclib async).
|
|
2
|
+
|
|
3
|
+
API: an ``async`` Client. Use as ``async with Client(config) as c:`` and
|
|
4
|
+
``stub = c.dispense('name'); await stub.SomeMethod(req)`` — the dispensed
|
|
5
|
+
stubs are grpclib async stubs.
|
|
6
|
+
"""
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
import asyncio
|
|
10
|
+
import logging
|
|
11
|
+
import os
|
|
12
|
+
import ssl
|
|
13
|
+
import subprocess
|
|
14
|
+
from dataclasses import dataclass, field
|
|
15
|
+
from typing import Any, Mapping, Optional, Sequence, Union
|
|
16
|
+
|
|
17
|
+
from grpclib.client import Channel
|
|
18
|
+
from grpclib.config import Configuration
|
|
19
|
+
from grpclib.health.v1 import health_grpc, health_pb2
|
|
20
|
+
|
|
21
|
+
from . import logging_bridge, mtls, process, transport
|
|
22
|
+
from ._generated import grpc_controller_grpc, grpc_controller_pb2
|
|
23
|
+
from .broker import GRPCBroker, TLSMaterial, make_client_side_broker
|
|
24
|
+
from .errors import (
|
|
25
|
+
AppProtocolMismatch,
|
|
26
|
+
HandshakeError,
|
|
27
|
+
ProcessExitedError,
|
|
28
|
+
StartTimeout,
|
|
29
|
+
UnsupportedProtocol,
|
|
30
|
+
)
|
|
31
|
+
from .handshake import (
|
|
32
|
+
HandshakeConfig,
|
|
33
|
+
HandshakeLine,
|
|
34
|
+
PROTOCOL_GRPC,
|
|
35
|
+
parse_line,
|
|
36
|
+
)
|
|
37
|
+
from .plugin import Plugin, PluginSet, VersionedPlugins
|
|
38
|
+
from .reattach import ReattachConfig
|
|
39
|
+
from .server import ENV_CLIENT_CERT, ENV_PROTOCOL_VERSIONS, GRPC_HEALTH_SERVICE_NAME
|
|
40
|
+
from .transport import ENV_MAX_PORT, ENV_MIN_PORT
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
@dataclass
class ClientConfig:
    """Settings consumed by :class:`Client`.

    Exactly one of ``cmd`` and ``reattach`` must be given; that rule is
    checked by ``Client.__init__``, not by this dataclass.
    """

    # Magic cookie exported to the plugin's environment; its protocol_version
    # is the expected app version when ``plugins`` is not versioned.
    handshake_config: HandshakeConfig
    # Either a flat plugin set or a mapping of int app version -> plugin set.
    plugins: Union[PluginSet, VersionedPlugins]
    # argv used to spawn a new plugin subprocess.
    cmd: Sequence[str] | None = None
    # Connection details for an already-running plugin.
    reattach: ReattachConfig | None = None
    # When True, a host certificate is generated and handed to the plugin.
    auto_mtls: bool = True
    # Seconds to wait for the handshake line; also caps health-check polling
    # (at most 30 s).
    start_timeout: float = 60.0
    # Seconds allowed for each waiting stage of Client.kill().
    kill_timeout: float = 2.0
    # Defaults to logging.getLogger("pyplugin.client").
    logger: logging.Logger | None = None
    # Receives forwarded plugin stderr; defaults to the "stderr" child logger.
    stderr_logger: logging.Logger | None = None
    # Extra environment variables layered over the inherited environment.
    env: Mapping[str, str] | None = None
    # Working directory for the spawned subprocess.
    cwd: str | None = None
    # When True, os.environ is not inherited by the subprocess.
    skip_host_env: bool = False
    # Port range exported to the plugin via environment variables.
    min_port: int = 10000
    max_port: int = 25000
    # Not read by Client in this module.
    grpc_options: list = field(default_factory=list)
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def _is_versioned(p: Union[PluginSet, VersionedPlugins]) -> bool:
|
|
63
|
+
if not p:
|
|
64
|
+
return False
|
|
65
|
+
return all(isinstance(k, int) for k in p.keys())
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
class Client:
    """A handle on a running plugin subprocess (or reattached process).

    Build it from a :class:`ClientConfig` in which exactly one of ``cmd``
    (spawn a new plugin) or ``reattach`` (connect to a running one) is set.
    ``async with Client(config) as client:`` calls :meth:`start` on entry and
    :meth:`kill` on exit; ``client.dispense(name)`` then returns a stub for a
    registered plugin.
    """

    def __init__(self, config: ClientConfig) -> None:
        """Store *config* and initialise empty connection state.

        Raises:
            ValueError: if both or neither of ``config.cmd`` and
                ``config.reattach`` are set.
        """
        if (config.cmd is None) == (config.reattach is None):
            raise ValueError("exactly one of `cmd` or `reattach` must be set")
        self._cfg = config
        self._logger = config.logger or logging.getLogger("pyplugin.client")
        self._stderr_logger = config.stderr_logger or self._logger.getChild("stderr")
        # Only set when this client spawned the plugin (never on reattach).
        self._proc: Optional[subprocess.Popen] = None
        self._handshake: Optional[HandshakeLine] = None
        self._channel: Optional[Channel] = None
        self._broker: Optional[GRPCBroker] = None
        self._broker_task: Optional[asyncio.Task] = None
        self._stderr_task: Optional[asyncio.Task] = None
        self._negotiated_version: int = 0
        self._plugin_set: PluginSet = {}
        # Host TLS material with keys "cert_pem", "key_pem", "cert_der".
        self._tls: Optional[dict[str, bytes]] = None
        self._client_ssl: Optional[ssl.SSLContext] = None
        self._killed = False
        # Serialises concurrent start() calls.
        self._lock = asyncio.Lock()

    # ----- public API -----

    async def start(self) -> None:
        """Spawn (or reattach to) the plugin, handshake, and open the channel.

        This is a no-op while a channel is already open. If any step fails,
        everything acquired so far is released before the exception
        propagates: the broker, the channel, and any subprocess this client
        spawned. ``async with`` does not call ``__aexit__`` when
        ``__aenter__`` raises, so nothing else would clean these up.
        """
        async with self._lock:
            if self._channel is not None:
                return
            try:
                if self._cfg.reattach is not None:
                    self._reattach()
                else:
                    await self._spawn_and_handshake()
                await self._dial()
            except BaseException:
                await self._abort_start()
                raise
            # A fresh connection must be killable even if kill() ran earlier.
            self._killed = False

    def dispense(self, name: str) -> Any:
        """Return the stub built by the plugin registered under *name*.

        Raises:
            RuntimeError: if :meth:`start` has not completed.
            KeyError: if *name* is not in the negotiated plugin set.
        """
        if self._channel is None:
            raise RuntimeError("Client.start() must be awaited before dispense()")
        if name not in self._plugin_set:
            raise KeyError(f"unknown plugin: {name!r}")
        plug = self._plugin_set[name]
        return plug.stub(self._broker, self._channel)  # type: ignore[arg-type]

    @property
    def broker(self) -> GRPCBroker:
        """The client-side GRPCBroker; raises RuntimeError before start()."""
        if self._broker is None:
            raise RuntimeError("Client.start() must be awaited before broker access")
        return self._broker

    @property
    def negotiated_version(self) -> int:
        """App protocol version agreed with the plugin (0 before start())."""
        return self._negotiated_version

    @property
    def pid(self) -> int | None:
        """PID of the spawned subprocess, else the reattach PID, else None."""
        if self._proc is not None:
            return self._proc.pid
        if self._cfg.reattach is not None:
            return self._cfg.reattach.pid
        return None

    def reattach_config(self) -> ReattachConfig | None:
        """Describe the current connection so another host can reattach.

        Returns None until a handshake (spawned or reattached) is known.
        """
        if self._handshake is None:
            return None
        cert_b64 = self._handshake.server_cert
        server_cert_pem = (
            mtls.der_to_pem(mtls.decode_handshake_cert(cert_b64)) if cert_b64 else None
        )
        client_cert = self._tls["cert_pem"] if self._tls else None
        client_key = self._tls["key_pem"] if self._tls else None
        return ReattachConfig(
            pid=self.pid or 0,
            addr=self._handshake.address,
            network=self._handshake.network,
            protocol=self._handshake.protocol,
            protocol_version=self._negotiated_version,
            server_cert_pem=server_cert_pem,
            client_cert_pem=client_cert,
            client_key_pem=client_key,
        )

    async def kill(self) -> None:
        """Walk the shutdown ladder: GRPCController.Shutdown → SIGTERM → SIGKILL.

        A Shutdown RPC is attempted whenever a channel is open. The broker
        and channel are then closed.

        Signals go only to a subprocess this client spawned, and never when
        ``reattach.test`` is set. Each waiting stage lasts up to
        ``config.kill_timeout`` seconds; the final reap after SIGKILL waits
        up to 2 s.

        Calling kill() again is a no-op until the next successful
        :meth:`start`.
        """
        if self._killed:
            return
        self._killed = True

        graceful = await self._request_shutdown()
        await self._close_transport()

        if self._cfg.reattach is not None and self._cfg.reattach.test:
            return
        proc = self._proc
        if proc is None:
            return

        if graceful:
            await self._wait_for_exit(proc, self._cfg.kill_timeout)

        if proc.poll() is None:
            self._logger.warning("plugin failed to exit gracefully — sending SIGTERM")
            process.terminate(proc)
            await self._wait_for_exit(proc, self._cfg.kill_timeout)

        if proc.poll() is None:
            self._logger.error("plugin still alive after SIGTERM — sending SIGKILL")
            process.kill(proc)
            await self._reap(proc, timeout=2.0)

        if self._stderr_task is not None:
            self._stderr_task.cancel()

    async def __aenter__(self) -> "Client":
        await self.start()
        return self

    async def __aexit__(self, exc_type, exc, tb) -> None:
        await self.kill()

    # ----- internals -----

    async def _request_shutdown(self) -> bool:
        """Ask the plugin to exit via GRPCController.Shutdown.

        Returns True if the RPC completed within 2 s. Any failure is logged
        at debug level and reported as False, so kill() falls back to signals.
        """
        if self._channel is None:
            return False
        try:
            stub = grpc_controller_grpc.GRPCControllerStub(self._channel)
            await asyncio.wait_for(
                stub.Shutdown(grpc_controller_pb2.Empty()),
                timeout=2.0,
            )
        except Exception as e:  # noqa: BLE001 - TimeoutError included
            self._logger.debug("controller.Shutdown failed: %s", e)
            return False
        return True

    async def _close_transport(self) -> None:
        """Close the broker, cancel its task, and close the gRPC channel.

        Errors from closing are logged at debug level and not raised.
        """
        if self._broker is not None:
            try:
                await self._broker.close()
            except Exception as e:  # noqa: BLE001 - best-effort teardown
                self._logger.debug("broker.close failed: %s", e)
            self._broker = None
        if self._broker_task is not None:
            self._broker_task.cancel()
            self._broker_task = None
        if self._channel is not None:
            try:
                self._channel.close()
            except Exception as e:  # noqa: BLE001 - best-effort teardown
                self._logger.debug("channel.close failed: %s", e)
            self._channel = None

    async def _abort_start(self) -> None:
        """Best-effort cleanup after a failed :meth:`start`.

        Closes the broker and channel, then force-kills and reaps a
        subprocess this client spawned (a reattached process is left alone).
        Finally it stops stderr forwarding. Cleanup errors are logged rather
        than raised, so the original start() failure is what propagates.
        """
        await self._close_transport()
        proc = self._proc
        if proc is not None and proc.poll() is None:
            try:
                process.kill(proc)
                await self._reap(proc, timeout=2.0)
            except Exception as e:  # noqa: BLE001 - keep the original error
                self._logger.debug("failed to kill plugin after start error: %s", e)
        if self._stderr_task is not None:
            self._stderr_task.cancel()
            self._stderr_task = None

    @staticmethod
    async def _wait_for_exit(proc: subprocess.Popen, timeout: float) -> None:
        """Poll *proc* every 50 ms until it exits or *timeout* seconds pass."""
        loop = asyncio.get_running_loop()
        deadline = loop.time() + timeout
        while loop.time() < deadline and proc.poll() is None:
            await asyncio.sleep(0.05)

    @staticmethod
    async def _reap(proc: subprocess.Popen, timeout: float) -> None:
        """Wait (in a worker thread) for *proc* to exit, for at most *timeout* s."""
        loop = asyncio.get_running_loop()
        try:
            await loop.run_in_executor(None, lambda: proc.wait(timeout=timeout))
        except subprocess.TimeoutExpired:
            pass

    @staticmethod
    def _split_host_port(address: str) -> tuple[str, int]:
        """Split a ``host:port`` handshake address into ``(host, port)``.

        Splits on the last colon, so bracketed IPv6 literals such as
        ``[::1]:1234`` work; the brackets are removed from the host.

        Raises:
            HandshakeError: if there is no colon or the port is not numeric.
        """
        host, sep, port_text = address.rpartition(":")
        if not sep or not port_text.isdecimal():
            raise HandshakeError(f"invalid tcp address in handshake: {address!r}")
        if host.startswith("[") and host.endswith("]"):
            host = host[1:-1]
        return host, int(port_text)

    def _build_env(self) -> dict[str, str]:
        """Build the environment for the plugin subprocess.

        Layers: inherited os.environ (unless ``skip_host_env``), then
        ``config.env``, then the magic cookie, protocol versions and port
        range. Side effect: when ``auto_mtls`` is on, a host certificate is
        generated and stored in ``self._tls``.
        """
        env: dict[str, str] = {}
        if not self._cfg.skip_host_env:
            env.update(os.environ)
        if self._cfg.env:
            env.update(self._cfg.env)

        cookie = self._cfg.handshake_config
        env[cookie.magic_cookie_key] = cookie.magic_cookie_value

        if _is_versioned(self._cfg.plugins):
            versions = sorted(self._cfg.plugins.keys())  # type: ignore[arg-type]
        else:
            versions = [cookie.protocol_version]
        env[ENV_PROTOCOL_VERSIONS] = ",".join(str(v) for v in versions)

        env[ENV_MIN_PORT] = str(self._cfg.min_port)
        env[ENV_MAX_PORT] = str(self._cfg.max_port)

        if self._cfg.auto_mtls:
            host_cert = mtls.generate()
            self._tls = {
                "cert_pem": host_cert.cert_pem,
                "key_pem": host_cert.key_pem,
                "cert_der": host_cert.cert_der,
            }
            env[ENV_CLIENT_CERT] = host_cert.cert_pem.decode()

        return env

    async def _spawn_and_handshake(self) -> None:
        """Spawn ``config.cmd``, start stderr forwarding, then read and validate the handshake."""
        env = self._build_env()
        assert self._cfg.cmd is not None
        self._proc = process.spawn(self._cfg.cmd, env=env, cwd=self._cfg.cwd)

        # Async stderr forwarding.
        loop = asyncio.get_running_loop()
        stderr = self._proc.stderr
        assert stderr is not None
        self._stderr_task = loop.create_task(self._forward_stderr(stderr))

        line = await self._read_handshake_line()
        self._handshake = parse_line(line)
        self._validate_handshake(self._handshake)

    async def _read_handshake_line(self) -> str:
        """Read the plugin's first stdout line, bounded by ``start_timeout``.

        Raises:
            StartTimeout: no line arrived within ``start_timeout`` seconds
                (the plugin is killed first).
            ProcessExitedError: stdout hit EOF before any line.
        """
        assert self._proc is not None
        stdout = self._proc.stdout
        assert stdout is not None
        loop = asyncio.get_running_loop()
        try:
            raw = await asyncio.wait_for(
                # readline() blocks, so run it in the default thread pool.
                loop.run_in_executor(None, stdout.readline),
                timeout=self._cfg.start_timeout,
            )
        except asyncio.TimeoutError:
            process.kill(self._proc)
            raise StartTimeout(
                f"plugin did not emit a handshake within {self._cfg.start_timeout}s"
            ) from None
        if not raw:
            raise ProcessExitedError("plugin exited before sending handshake")
        return raw.decode("utf-8", errors="replace").strip()

    def _validate_handshake(self, h: HandshakeLine) -> None:
        """Select the plugin set for the advertised app version and check the wire protocol.

        Raises:
            AppProtocolMismatch: the app version is not one the client accepts.
            UnsupportedProtocol: the plugin does not speak gRPC.
        """
        if _is_versioned(self._cfg.plugins):
            versioned: dict[int, PluginSet] = self._cfg.plugins  # type: ignore[assignment]
            if h.app_protocol_version not in versioned:
                raise AppProtocolMismatch(
                    f"plugin advertised version {h.app_protocol_version}; "
                    f"client supports {sorted(versioned.keys())}"
                )
            self._plugin_set = versioned[h.app_protocol_version]
        else:
            cfg_v = self._cfg.handshake_config.protocol_version
            if h.app_protocol_version != cfg_v:
                raise AppProtocolMismatch(
                    f"plugin advertised version {h.app_protocol_version}; "
                    f"client expects {cfg_v}"
                )
            self._plugin_set = self._cfg.plugins  # type: ignore[assignment]

        if h.protocol != PROTOCOL_GRPC:
            raise UnsupportedProtocol(
                f"plugin advertised protocol {h.protocol!r}; pyplugin only supports 'grpc'")

        self._negotiated_version = h.app_protocol_version

    def _reattach(self) -> None:
        """Rebuild handshake, plugin-set and TLS state from ``config.reattach``.

        Nothing is spawned. The server certificate PEM is re-encoded into the
        handshake's base64 DER form. For versioned plugin sets, an unknown
        version yields an empty plugin set, so dispense() raises KeyError.
        """
        r = self._cfg.reattach
        assert r is not None
        cert_b64 = ""
        if r.server_cert_pem is not None:
            from cryptography import x509
            from cryptography.hazmat.primitives import serialization
            cert = x509.load_pem_x509_certificate(r.server_cert_pem)
            cert_b64 = mtls.encode_handshake_cert(cert.public_bytes(serialization.Encoding.DER))
        self._handshake = HandshakeLine(
            core_protocol_version=1,
            app_protocol_version=r.protocol_version,
            network=r.network,
            address=r.addr,
            protocol=r.protocol,
            server_cert=cert_b64,
        )
        self._negotiated_version = r.protocol_version
        if _is_versioned(self._cfg.plugins):
            versioned: dict[int, PluginSet] = self._cfg.plugins  # type: ignore[assignment]
            self._plugin_set = versioned.get(r.protocol_version, {})
        else:
            self._plugin_set = self._cfg.plugins  # type: ignore[assignment]

        if r.client_cert_pem and r.client_key_pem:
            self._tls = {
                "cert_pem": r.client_cert_pem,
                "key_pem": r.client_key_pem,
                "cert_der": b"",
            }

    async def _dial(self) -> None:
        """Open the gRPC channel, wait for a SERVING health check, and start the broker.

        Raises:
            HandshakeError: in any of these cases:
                * the plugin advertised a certificate but the client has no
                  TLS material;
                * the TCP address is malformed;
                * the health check does not report SERVING before
                  ``min(start_timeout, 30)`` seconds.
        """
        h = self._handshake
        assert h is not None

        # Reset so a retried start() never reuses a context from an older handshake.
        self._client_ssl = None
        peer_cert_pem = None
        if h.server_cert:
            if self._tls is None:
                raise HandshakeError(
                    "plugin advertised AutoMTLS but client wasn't configured for it")
            peer_cert_pem = mtls.der_to_pem(mtls.decode_handshake_cert(h.server_cert))
            self._client_ssl = mtls.client_ssl_context(
                cert_pem=self._tls["cert_pem"],
                key_pem=self._tls["key_pem"],
                peer_cert_pem=peer_cert_pem,
            )

        # NOTE(review): presumably the plugin's certificate is issued for
        # "localhost"; confirm in mtls.
        cfg = Configuration(ssl_target_name_override="localhost") if self._client_ssl else None
        if h.network == "unix":
            self._channel = Channel(path=h.address, ssl=self._client_ssl, config=cfg)
        else:
            host, port = self._split_host_port(h.address)
            self._channel = Channel(host=host, port=port, ssl=self._client_ssl, config=cfg)

        # Health check — match go-plugin's ping (Service: "plugin").
        # At least one attempt is always made, even if the deadline is already past.
        loop = asyncio.get_running_loop()
        deadline = loop.time() + min(self._cfg.start_timeout, 30.0)
        hstub = health_grpc.HealthStub(self._channel)
        request = health_pb2.HealthCheckRequest(service=GRPC_HEALTH_SERVICE_NAME)
        last_err: Exception | None = None
        while True:
            try:
                resp = await asyncio.wait_for(hstub.Check(request), timeout=2.0)
            except Exception as e:  # noqa: BLE001 - retried until the deadline
                last_err = e
            else:
                if resp.status == health_pb2.HealthCheckResponse.SERVING:
                    break
                last_err = HandshakeError(f"plugin health = {resp.status}")
            if loop.time() >= deadline:
                raise HandshakeError(f"plugin health check failed: {last_err}") from last_err
            await asyncio.sleep(0.05)

        # Start broker stream; it shares the mTLS material when TLS is in use.
        broker_tls: TLSMaterial | None = None
        if self._client_ssl is not None and self._tls is not None and peer_cert_pem is not None:
            broker_tls = TLSMaterial(
                cert_pem=self._tls["cert_pem"],
                key_pem=self._tls["key_pem"],
                peer_cert_pem=peer_cert_pem,
            )
        self._broker, self._broker_task = make_client_side_broker(self._channel, broker_tls)

    async def _forward_stderr(self, stream) -> None:
        """Forward each non-empty stderr line to the stderr logger until EOF."""
        loop = asyncio.get_running_loop()
        while True:
            raw = await loop.run_in_executor(None, stream.readline)
            if not raw:
                return
            line = raw.decode("utf-8", errors="replace").rstrip()
            if line:
                logging_bridge.emit(self._stderr_logger, line)
|
pyplugin/controller.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
"""GRPCController.Shutdown servicer (grpclib async)."""
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
import asyncio
|
|
5
|
+
|
|
6
|
+
from ._generated import grpc_controller_grpc, grpc_controller_pb2
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class GRPCControllerServicer(grpc_controller_grpc.GRPCControllerBase):
    """Server side of the GRPCController service.

    ``Shutdown`` replies with ``Empty`` and only then sets
    :attr:`shutdown_event`; it does not stop anything itself. This mirrors
    go-plugin's grpcControllerServer: the serving loop is expected to wait on
    the event and then close the gRPC server.
    """

    def __init__(self) -> None:
        # Signalled by Shutdown(); stays unset until the host asks us to exit.
        self.shutdown_event: asyncio.Event = asyncio.Event()

    async def Shutdown(self, stream) -> None:  # noqa: N802 - gRPC method name
        _ = await stream.recv_message()  # Empty request: read and discarded
        reply = grpc_controller_pb2.Empty()
        await stream.send_message(reply)
        self.shutdown_event.set()
|
pyplugin/cookie.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
"""Magic-cookie validation. UX feature, not a security boundary.
|
|
2
|
+
|
|
3
|
+
If the cookie doesn't match, we print the same human-friendly message
|
|
4
|
+
go-plugin uses (so users who run a plugin binary directly get a useful hint)
|
|
5
|
+
and exit 1.
|
|
6
|
+
"""
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
import os
|
|
10
|
+
import sys
|
|
11
|
+
|
|
12
|
+
from .handshake import HandshakeConfig
|
|
13
|
+
|
|
14
|
+
_NOT_A_CLI_MESSAGE = (
|
|
15
|
+
"This binary is a plugin. These are not meant to be executed directly.\n"
|
|
16
|
+
"Please execute the program that consumes these plugins, which will\n"
|
|
17
|
+
"load any plugins automatically\n"
|
|
18
|
+
)
|
|
19
|
+
|
|
20
|
+
_MISCONFIG_MESSAGE = (
|
|
21
|
+
"Misconfigured ServeConfig given to serve this plugin: no magic cookie\n"
|
|
22
|
+
"key or value was set. Please notify the plugin author and report\n"
|
|
23
|
+
"this as a bug.\n"
|
|
24
|
+
)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def validate_or_exit(config: HandshakeConfig, env: os._Environ[str] | None = None) -> None:
    """Exit the process unless the magic-cookie environment variable matches.

    A plugin's ``serve()`` runs this first, before importing heavy
    dependencies, so a user who launches the binary directly sees only a
    short, friendly message.

    Args:
        config: Supplies ``magic_cookie_key`` and ``magic_cookie_value``.
        env: Environment to inspect; ``os.environ`` when omitted.

    Raises:
        SystemExit: with code 1, after writing a message to stderr. This
            happens if the config's cookie key or value is empty, or if the
            environment value does not equal the expected one.
    """
    environment = os.environ if env is None else env
    key, expected = config.magic_cookie_key, config.magic_cookie_value

    if not (key and expected):
        sys.stderr.write(_MISCONFIG_MESSAGE)
        sys.exit(1)

    if environment.get(key) != expected:
        sys.stderr.write(_NOT_A_CLI_MESSAGE)
        sys.exit(1)
|
pyplugin/errors.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
"""Public exception hierarchy for pyplugin."""
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class PyPluginError(Exception):
    """Common base class for pyplugin-specific exceptions."""
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class HandshakeError(PyPluginError):
    """Handshake or connection setup with the plugin failed.

    Examples: a malformed stdout handshake line, an AutoMTLS mismatch, or a
    health check that never reported SERVING.
    """
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class CoreProtocolMismatch(HandshakeError):
    """The handshake's CORE version differs from the one this library speaks."""
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class AppProtocolMismatch(HandshakeError):
    """The plugin's app protocol version is not one the client accepts.

    For versioned plugin sets this means the version is not a key; otherwise
    it differs from HandshakeConfig.protocol_version.
    """
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class UnsupportedProtocol(HandshakeError):
    """The handshake named a wire protocol other than 'grpc' (e.g. 'netrpc')."""
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class MagicCookieMismatch(PyPluginError):
    """The plugin's magic-cookie environment variable did not match.

    Usually either a plugin-author bug or a user running the plugin binary
    directly.
    """
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class ProcessExitedError(PyPluginError):
    """The plugin subprocess exited unexpectedly, e.g. before its handshake."""
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class TLSError(PyPluginError):
    """Setting up or verifying AutoMTLS failed."""
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class StartTimeout(PyPluginError):
    """No handshake line arrived within the configured ``start_timeout``."""
|
pyplugin/handshake.py
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
1
|
+
"""go-plugin stdout handshake protocol.
|
|
2
|
+
|
|
3
|
+
Wire format (single line, ``\\n`` terminated, written by plugin to stdout)::
|
|
4
|
+
|
|
5
|
+
CORE | APP | NETWORK | ADDRESS | PROTOCOL | SERVER-CERT [ | MULTIPLEX ]
|
|
6
|
+
|
|
7
|
+
* CORE: int. Always ``1`` (CoreProtocolVersion). Anything else is a hard reject.
|
|
8
|
+
* APP: int. Application-defined protocol version.
|
|
9
|
+
* NETWORK: ``unix`` or ``tcp``.
|
|
10
|
+
* ADDRESS: socket path or ``ip:port``.
|
|
11
|
+
* PROTOCOL: ``grpc`` (we don't implement Go's net/rpc).
|
|
12
|
+
* SERVER-CERT: empty when AutoMTLS is off, else base64.RawStdEncoding of the
|
|
13
|
+
server's leaf cert in raw DER. Length > 50 distinguishes a real cert from
|
|
14
|
+
legacy "extra" data — that's how the Go client sniffs it.
|
|
15
|
+
* MULTIPLEX: optional 7th field, ``true`` when GRPCBrokerMultiplex is supported.
|
|
16
|
+
|
|
17
|
+
Format mirrors `hashicorp/go-plugin server.go::Serve` exactly.
|
|
18
|
+
"""
|
|
19
|
+
from __future__ import annotations
|
|
20
|
+
|
|
21
|
+
from dataclasses import dataclass
|
|
22
|
+
|
|
23
|
+
from .errors import CoreProtocolMismatch, HandshakeError
|
|
24
|
+
|
|
25
|
+
# Only CORE value accepted by parse_line; any other is a hard reject.
CORE_PROTOCOL_VERSION: int = 1

# Values of the PROTOCOL handshake field.
PROTOCOL_GRPC: str = "grpc"
PROTOCOL_NETRPC: str = "netrpc"  # Go net/rpc: recognised, not implemented

# Values of the NETWORK handshake field.
NETWORK_UNIX: str = "unix"
NETWORK_TCP: str = "tcp"
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
@dataclass(frozen=True)
class HandshakeConfig:
    """Settings the host and plugin must share; mirrors go-plugin's HandshakeConfig.

    If the two sides disagree on these values, the plugin will not load.
    """

    # App protocol version expected by both sides.
    protocol_version: int
    # Name of the environment variable carrying the magic cookie.
    magic_cookie_key: str
    # Value that environment variable must hold.
    magic_cookie_value: str
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
@dataclass(frozen=True)
class HandshakeLine:
    """One parsed handshake line; fields follow the wire order."""

    core_protocol_version: int   # CORE segment
    app_protocol_version: int    # APP segment
    network: str                 # "unix" or "tcp"
    address: str                 # socket path or host:port
    protocol: str                # e.g. "grpc"
    server_cert: str = ""        # base64 DER leaf cert; "" when AutoMTLS is off
    multiplex_supported: bool = False  # 7th segment said "true"
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def format_line(
    *,
    app_protocol_version: int,
    network: str,
    address: str,
    protocol: str = PROTOCOL_GRPC,
    server_cert: str = "",
    multiplex_supported: bool | None = None,
) -> str:
    """Render the handshake line a plugin writes to stdout.

    The six go-plugin segments are always present. A seventh (``true`` or
    ``false``) is added only when *multiplex_supported* is not None. No
    trailing newline is included.
    """
    segments = [
        CORE_PROTOCOL_VERSION,
        app_protocol_version,
        network,
        address,
        protocol,
        server_cert,
    ]
    if multiplex_supported is not None:
        segments.append("true" if multiplex_supported else "false")
    # format() is exactly what an f-string replacement field applies.
    return "|".join(format(segment) for segment in segments)
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def parse_line(raw: str) -> HandshakeLine:
    """Parse a plugin's stdout handshake line.

    Mirrors go-plugin's client:

    * at least 4 ``|``-separated segments are required;
    * PROTOCOL defaults to ``netrpc`` when absent;
    * SERVER-CERT is kept only when longer than 50 characters;
    * MULTIPLEX is True only for a case-insensitive ``"true"``;
    * segments past the 7th are ignored.

    Raises:
        HandshakeError: fewer than 4 segments, a non-integer CORE/APP value,
            or a NETWORK other than ``unix``/``tcp``.
        CoreProtocolMismatch: CORE differs from CORE_PROTOCOL_VERSION.
    """
    line = raw.strip()
    parts = line.split("|")
    if len(parts) < 4:
        raise HandshakeError(f"unrecognized remote plugin message: {raw!r}")

    try:
        core = int(parts[0])
    except ValueError as e:
        raise HandshakeError(f"error parsing core protocol version: {e}") from e
    if core != CORE_PROTOCOL_VERSION:
        raise CoreProtocolMismatch(
            f"incompatible core API version with plugin. "
            f"Plugin version: {parts[0]}, Core version: {CORE_PROTOCOL_VERSION}"
        )

    try:
        app = int(parts[1])
    except ValueError as e:
        raise HandshakeError(f"error parsing app protocol version: {e}") from e

    network = parts[2]
    # go-plugin only resolves "tcp" and "unix" addresses. Anything else would
    # otherwise be dialled as host:port by Client._dial.
    if network not in (NETWORK_UNIX, NETWORK_TCP):
        raise HandshakeError(f"unknown address type: {network!r}")
    address = parts[3]
    # Default to netrpc for backward compat with very old Go plugins, mirroring client.go.
    protocol = parts[4] if len(parts) >= 5 else PROTOCOL_NETRPC

    # The Go client uses len(parts[5]) > 50 to decide if a real cert is present
    # (older plugins emit some unrelated "extra" data here).
    server_cert = ""
    if len(parts) >= 6 and len(parts[5]) > 50:
        server_cert = parts[5]

    multiplex = False
    if len(parts) >= 7:
        multiplex = parts[6].lower() == "true"

    return HandshakeLine(
        core_protocol_version=core,
        app_protocol_version=app,
        network=network,
        address=address,
        protocol=protocol,
        server_cert=server_cert,
        multiplex_supported=multiplex,
    )
|
pyplugin/health.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
"""Minimal grpc.health.v1 servicer that returns SERVING for go-plugin's "plugin" name.
|
|
2
|
+
|
|
3
|
+
go-plugin's host pings ``Check(service="plugin")`` and expects ``SERVING``.
|
|
4
|
+
grpclib ships a more elaborate Health service that derives names from registered
|
|
5
|
+
``IServable`` objects' method mappings — we just need a single static name
|
|
6
|
+
matching go-plugin's wire convention.
|
|
7
|
+
"""
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
from grpclib.const import Status
|
|
11
|
+
from grpclib.health.v1 import health_grpc, health_pb2
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class StaticHealth(health_grpc.HealthBase):
    """Returns SERVING for any registered service name; NOT_FOUND otherwise."""

    def __init__(self, services: list[str] | None = None) -> None:
        # "" is the "overall" service name; "plugin" is what go-plugin pings.
        # Only None selects these defaults: an explicit empty list means that
        # no service reports SERVING.
        self._services = set(["", "plugin"] if services is None else services)

    async def _respond(self, stream) -> None:
        """Answer one request: SERVING if its service is registered, else NOT_FOUND."""
        request = await stream.recv_message()
        if request.service in self._services:
            await stream.send_message(health_pb2.HealthCheckResponse(
                status=health_pb2.HealthCheckResponse.SERVING,
            ))
        else:
            await stream.send_trailing_metadata(status=Status.NOT_FOUND)

    async def Check(self, stream) -> None:  # noqa: N802
        await self._respond(stream)

    async def Watch(self, stream) -> None:  # noqa: N802
        # Not used by go-plugin; emit a single response and close.
        await self._respond(stream)
|