arbiter-server 0.9.1.dev1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- arbiter_server/__init__.py +1 -0
- arbiter_server/__main__.py +5 -0
- arbiter_server/app.py +44 -0
- arbiter_server/artifacts.py +344 -0
- arbiter_server/cli_errors.py +31 -0
- arbiter_server/config.py +231 -0
- arbiter_server/deploy/docker/arbiter-docker +4477 -0
- arbiter_server/deploy/docker/compose.yaml +101 -0
- arbiter_server/file_protection/__init__.py +20 -0
- arbiter_server/file_protection/posix.py +70 -0
- arbiter_server/file_protection/windows.py +379 -0
- arbiter_server/main.py +2843 -0
- arbiter_server/plugins/__init__.py +36 -0
- arbiter_server/py.typed +1 -0
- arbiter_server/services.py +706 -0
- arbiter_server/storage.py +60 -0
- arbiter_server/version.py +135 -0
- arbiter_server-0.9.1.dev1.dist-info/METADATA +26 -0
- arbiter_server-0.9.1.dev1.dist-info/RECORD +22 -0
- arbiter_server-0.9.1.dev1.dist-info/WHEEL +5 -0
- arbiter_server-0.9.1.dev1.dist-info/entry_points.txt +2 -0
- arbiter_server-0.9.1.dev1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""Arbiter server package."""
|
arbiter_server/app.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Protocol, cast
|
|
4
|
+
|
|
5
|
+
from .services import RuntimeRegistry
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
# Tool names exposed by the server; ArbiterApp.tool_names() returns them in
# exactly this order.
SERVER_TOOL_NAMES = (
    "info", "version_info",
    "list_caps", "describe_caps", "describe_cap", "describe_op",
    "run_op",
)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class AccountSummariesRuntime(Protocol):
    """Structural type for service runtimes exposing ``account_summaries()``.

    ``ArbiterApp.list_accounts`` narrows runtimes to this protocol after an
    ``hasattr`` check.
    """

    def account_summaries(self) -> dict[str, object]:
        """Return this runtime's account summaries as a string-keyed dict."""
        ...
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class ArbiterApp:
    """Server facade over entry-point supplied service runtimes.

    Holds a ``RuntimeRegistry`` and offers server-level helpers built on it.
    """

    def __init__(self, runtime_registry: RuntimeRegistry) -> None:
        self.runtime_registry = runtime_registry

    def tool_names(self) -> list[str]:
        """Return a new list holding the server's tool names, in order."""
        return [*SERVER_TOOL_NAMES]

    def list_accounts(self) -> dict[str, object]:
        """Gather ``account_summaries()`` from every runtime that offers it.

        Registry items are visited in sorted order. Runtimes lacking an
        ``account_summaries`` attribute are skipped. The result maps service
        name to that runtime's summaries.
        """
        result: dict[str, object] = {}
        for name, runtime in sorted(self.runtime_registry.items()):
            if hasattr(runtime, "account_summaries"):
                capable = cast(AccountSummariesRuntime, runtime)
                result[name] = capable.account_summaries()
        return result
|
|
@@ -0,0 +1,344 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from datetime import datetime, timezone
|
|
5
|
+
import hashlib
|
|
6
|
+
import json
|
|
7
|
+
import os
|
|
8
|
+
from pathlib import Path
|
|
9
|
+
import re
|
|
10
|
+
import secrets
|
|
11
|
+
import shutil
|
|
12
|
+
import time
|
|
13
|
+
from urllib.parse import quote
|
|
14
|
+
|
|
15
|
+
from .storage import ensure_private_dir, plugin_data_dir
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
# Default idle timeout: an artifact not accessed for 10 minutes is expired.
DEFAULT_IDLE_TTL_SECONDS = 600
# Default hard lifetime, measured from creation regardless of access (24 hours).
DEFAULT_RETENTION_SECONDS = 86_400
# Artifact IDs are restricted to URL-safe base64 characters; this also keeps an
# ID from being anything other than a single plain path component.
ARTIFACT_ID_PATTERN = re.compile(r"^[A-Za-z0-9_-]+$")
# File created inside an artifact directory once it has been read.
CONSUMED_MARKER = "consumed"
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class ArtifactError(RuntimeError):
    """Base class for every error raised by the artifact store."""
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class ArtifactNotFound(ArtifactError):
    """Raised when an ID/nonce pair does not resolve to a readable artifact."""
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class ArtifactExpired(ArtifactError):
    """Raised when an artifact outlived its idle TTL or retention period."""
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class ArtifactConsumed(ArtifactError):
    """Raised when a one-time artifact has already been read."""
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
@dataclass(frozen=True)
class ArtifactDescriptor:
    """Immutable description of a newly created artifact, handed to callers."""

    id: str
    url: str
    filename: str | None
    content_type: str
    size: int
    sha256: str
    # ISO-8601 string (ArtifactStore.create renders it in UTC).
    created_at: str
    expires_after_idle_seconds: int
    one_time: bool

    def to_dict(self) -> dict[str, object]:
        """Return the descriptor as a plain dict, keys in field order."""
        return dict(
            id=self.id,
            url=self.url,
            filename=self.filename,
            content_type=self.content_type,
            size=self.size,
            sha256=self.sha256,
            created_at=self.created_at,
            expires_after_idle_seconds=self.expires_after_idle_seconds,
            one_time=self.one_time,
        )
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
@dataclass(frozen=True)
class ArtifactRead:
    """Location and metadata of an artifact payload that a caller may read."""

    # Path of the stored payload file.
    path: Path
    filename: str | None
    content_type: str
    size: int
    sha256: str
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
class PluginArtifactStore:
    """View of an ``ArtifactStore`` bound to a single plugin name.

    ``create`` forwards to the underlying store with ``plugin`` filled in.
    """

    def __init__(self, *, plugin: str, store: ArtifactStore) -> None:
        self._store = store
        self._plugin = plugin

    def create(
        self,
        *,
        content: bytes,
        filename: str | None,
        content_type: str,
        source: dict[str, object],
    ) -> ArtifactDescriptor:
        """Create an artifact owned by this view's plugin."""
        forwarded = dict(
            content=content,
            filename=filename,
            content_type=content_type,
            source=source,
        )
        return self._store.create(plugin=self._plugin, **forwarded)
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
class ArtifactStore:
    """Filesystem store for one-time, nonce-protected artifacts.

    Each artifact lives in ``plugin_data_dir(root, plugin)/artifacts/<id>/``.
    That directory holds a ``payload`` file and a ``metadata.json`` file. Once
    the artifact has been read through :meth:`open_once`, it also holds a
    ``consumed`` marker file. Only the SHA-256 of the access nonce is persisted;
    the nonce itself is returned to the creator inside the artifact URL.

    An artifact expires when it has been idle longer than its idle TTL, or when
    it is older than its retention period. Expired artifacts are purged lazily
    at the start of :meth:`create`, :meth:`open_once` and :meth:`inspect`.

    NOTE(review): lookups (:meth:`_artifact_dir`, :meth:`purge_expired`) scan
    ``root/*/artifacts/...``. They therefore assume
    ``plugin_data_dir(root, plugin)`` is a direct child of ``root``; confirm
    against ``storage.plugin_data_dir``.
    """

    def __init__(
        self,
        *,
        root: Path,
        base_url: str,
        idle_ttl_seconds: int = DEFAULT_IDLE_TTL_SECONDS,
        retention_seconds: int = DEFAULT_RETENTION_SECONDS,
    ) -> None:
        """Create a store rooted at *root* that builds URLs under *base_url*.

        Raises:
            ValueError: if ``idle_ttl_seconds`` or ``retention_seconds`` is < 1.
        """
        if idle_ttl_seconds < 1:
            raise ValueError("artifact idle_ttl_seconds must be at least 1")
        if retention_seconds < 1:
            raise ValueError("artifact retention_seconds must be at least 1")
        self._root = root
        # Drop trailing slashes so URLs can be built as f"{base}/{id}".
        self._base_url = base_url.rstrip("/")
        self._idle_ttl_seconds = idle_ttl_seconds
        self._retention_seconds = retention_seconds

    @property
    def idle_ttl_seconds(self) -> int:
        """Idle timeout (seconds) recorded on newly created artifacts."""
        return self._idle_ttl_seconds

    def for_plugin(self, plugin: str) -> PluginArtifactStore:
        """Return a view of this store that creates artifacts for *plugin*."""
        return PluginArtifactStore(plugin=plugin, store=self)

    def create(
        self,
        *,
        plugin: str,
        content: bytes,
        filename: str | None,
        content_type: str,
        source: dict[str, object],
    ) -> ArtifactDescriptor:
        """Persist *content* as a new one-time artifact owned by *plugin*.

        Returns:
            A descriptor whose ``url`` carries the artifact ID and the access
            nonce. The nonce is not stored anywhere else in plain form.

        Raises:
            TypeError: if *source* is not JSON-serializable. This is detected
                before anything is written to disk.
        """
        self.purge_expired()
        artifact_id = secrets.token_urlsafe(24)
        nonce = secrets.token_urlsafe(32)
        created_at = time.time()
        payload_hash = hashlib.sha256(content).hexdigest()
        artifact_dir = plugin_data_dir(self._root, plugin) / "artifacts" / artifact_id
        metadata: dict[str, object] = {
            "id": artifact_id,
            "plugin": plugin,
            "filename": filename,
            "content_type": content_type,
            "size": len(content),
            "sha256": payload_hash,
            "source": source,
            "nonce_sha256": hashlib.sha256(nonce.encode("utf-8")).hexdigest(),
            "created_at": created_at,
            "created_at_iso": _iso_timestamp(created_at),
            "last_accessed_at": None,
            "last_accessed_at_iso": None,
            "access_count": 0,
            "consumed": False,
            "one_time": True,
            "idle_ttl_seconds": self._idle_ttl_seconds,
            "retention_seconds": self._retention_seconds,
        }
        # Serialize before touching the filesystem, so a bad ``source`` cannot
        # leave a payload behind without metadata.
        metadata_bytes = json.dumps(metadata, sort_keys=True).encode("utf-8")
        ensure_private_dir(artifact_dir)
        try:
            _write_private_file(artifact_dir / "payload", content)
            _write_private_file(artifact_dir / "metadata.json", metadata_bytes)
        except BaseException:
            # purge_expired() only discovers artifacts via metadata.json, so a
            # half-written directory would otherwise never be reclaimed.
            shutil.rmtree(artifact_dir, ignore_errors=True)
            raise
        encoded_id = quote(artifact_id, safe="")
        encoded_nonce = quote(nonce, safe="")
        return ArtifactDescriptor(
            id=artifact_id,
            url=f"{self._base_url}/{encoded_id}?nonce={encoded_nonce}",
            filename=filename,
            content_type=content_type,
            size=len(content),
            sha256=payload_hash,
            created_at=str(metadata["created_at_iso"]),
            expires_after_idle_seconds=self._idle_ttl_seconds,
            one_time=True,
        )

    def open_once(self, artifact_id: str, nonce: str) -> ArtifactRead:
        """Return the artifact for its single permitted read and mark it consumed.

        Raises:
            ArtifactNotFound: unknown ID, bad nonce or unreadable metadata.
            ArtifactExpired: the artifact expired (it is deleted).
            ArtifactConsumed: the artifact was already read.
        """
        self.purge_expired()
        artifact_dir, metadata = self._validated_artifact(artifact_id, nonce)
        self._claim_artifact(artifact_dir, artifact_id)
        self._record_access(artifact_dir, metadata, consume=True)
        return _artifact_read(artifact_dir, metadata)

    def inspect(self, artifact_id: str, nonce: str) -> ArtifactRead:
        """Return artifact details without consuming it.

        This still records an access, which refreshes the idle timer. Raises
        the same errors as :meth:`open_once`.
        """
        self.purge_expired()
        artifact_dir, metadata = self._validated_artifact(artifact_id, nonce)
        self._record_access(artifact_dir, metadata, consume=False)
        return _artifact_read(artifact_dir, metadata)

    def _validated_artifact(
        self,
        artifact_id: str,
        nonce: str,
    ) -> tuple[Path, dict[str, object]]:
        """Locate an artifact and check consumption, expiry and nonce, in that order."""
        artifact_dir = self._artifact_dir(artifact_id)
        metadata = self._read_metadata(artifact_dir)
        # The marker file is checked as well as the flag: the marker is created
        # atomically, while metadata can be rewritten by a concurrent inspect().
        if (
            metadata.get("consumed") is True
            or (artifact_dir / CONSUMED_MARKER).exists()
        ):
            raise ArtifactConsumed(f"artifact already consumed: {artifact_id}")
        if self._is_expired(metadata, time.time()):
            shutil.rmtree(artifact_dir, ignore_errors=True)
            raise ArtifactExpired(f"artifact expired: {artifact_id}")
        expected_hash = str(metadata.get("nonce_sha256", ""))
        actual_hash = hashlib.sha256(nonce.encode("utf-8")).hexdigest()
        # Constant-time comparison of the nonce digests.
        if not secrets.compare_digest(actual_hash, expected_hash):
            raise ArtifactNotFound(f"artifact not found: {artifact_id}")
        return artifact_dir, metadata

    def _claim_artifact(self, artifact_dir: Path, artifact_id: str) -> None:
        """Atomically create the consumed marker; only one caller can succeed.

        Raises:
            ArtifactConsumed: the marker already exists.
            ArtifactNotFound: the artifact directory no longer exists.
        """
        marker_path = artifact_dir / CONSUMED_MARKER
        # O_EXCL makes creation fail if the marker exists, so concurrent
        # open_once() calls cannot both claim the artifact.
        flags = os.O_WRONLY | os.O_CREAT | os.O_EXCL
        try:
            with os.fdopen(os.open(marker_path, flags, 0o600), "wb") as handle:
                handle.write(_iso_timestamp(time.time()).encode("utf-8"))
        except FileExistsError as exc:
            raise ArtifactConsumed(f"artifact already consumed: {artifact_id}") from exc
        except FileNotFoundError as exc:
            # The directory vanished after validation (e.g. purged concurrently).
            raise ArtifactNotFound(f"artifact not found: {artifact_id}") from exc
        if os.name != "nt":
            os.chmod(marker_path, 0o600)

    def _record_access(
        self,
        artifact_dir: Path,
        metadata: dict[str, object],
        *,
        consume: bool,
    ) -> None:
        """Update access bookkeeping in *metadata* and rewrite metadata.json.

        When *consume* is true, the artifact is flagged as consumed and its
        nonce hash is cleared.
        """
        accessed_at = time.time()
        metadata["last_accessed_at"] = accessed_at
        metadata["last_accessed_at_iso"] = _iso_timestamp(accessed_at)
        metadata["access_count"] = _int_or_default(metadata.get("access_count"), 0) + 1
        if consume:
            metadata["consumed"] = True
            metadata["nonce_sha256"] = None
        _write_private_file(
            artifact_dir / "metadata.json",
            json.dumps(metadata, sort_keys=True).encode("utf-8"),
        )

    def purge_expired(self) -> int:
        """Delete expired or unreadable artifacts and return how many were removed.

        Directories whose metadata.json cannot be read, is not valid JSON or
        is not a JSON object are treated as corrupt and removed. Entries that
        disappear while the scan is running are skipped.
        """
        if not self._root.exists():
            return 0
        now = time.time()
        removed = 0
        for metadata_path in self._root.glob("*/artifacts/*/metadata.json"):
            artifact_dir = metadata_path.parent
            try:
                metadata = json.loads(metadata_path.read_text(encoding="utf-8"))
            except FileNotFoundError:
                # Removed concurrently between glob() and read; nothing to do.
                continue
            except (OSError, json.JSONDecodeError):
                metadata = None
            # A non-dict payload would crash _is_expired() and, since purge runs
            # before every public operation, block the whole store.
            if not isinstance(metadata, dict) or self._is_expired(metadata, now):
                shutil.rmtree(artifact_dir, ignore_errors=True)
                removed += 1
        return removed

    def _artifact_dir(self, artifact_id: str) -> Path:
        """Resolve *artifact_id* to its directory under exactly one plugin.

        The ID is checked against ARTIFACT_ID_PATTERN before being used as a
        path component.

        Raises:
            ArtifactNotFound: the ID is malformed, missing or ambiguous.
        """
        if ARTIFACT_ID_PATTERN.fullmatch(artifact_id) is None:
            raise ArtifactNotFound(f"artifact not found: {artifact_id}")
        if not self._root.exists():
            raise ArtifactNotFound(f"artifact not found: {artifact_id}")
        matches = [
            plugin_dir / "artifacts" / artifact_id
            for plugin_dir in self._root.iterdir()
            if (plugin_dir / "artifacts" / artifact_id / "metadata.json").is_file()
        ]
        if len(matches) != 1:
            raise ArtifactNotFound(f"artifact not found: {artifact_id}")
        return matches[0]

    def _read_metadata(self, artifact_dir: Path) -> dict[str, object]:
        """Load metadata.json from *artifact_dir*.

        Raises:
            ArtifactNotFound: the file is unreadable, is not JSON, or is not
                a JSON object.
        """
        try:
            raw = (artifact_dir / "metadata.json").read_text(encoding="utf-8")
            metadata = json.loads(raw)
        except (OSError, json.JSONDecodeError) as exc:
            raise ArtifactNotFound("artifact metadata is unavailable") from exc
        if not isinstance(metadata, dict):
            raise ArtifactNotFound("artifact metadata is invalid")
        return metadata

    def _is_expired(self, metadata: dict[str, object], now: float) -> bool:
        """Return True if the artifact is past retention or idle for too long.

        The idle clock starts at the last access, or at creation if the
        artifact was never accessed. Per-artifact TTLs stored in metadata
        override the store defaults. A missing or non-positive ``created_at``
        counts as expired.
        """
        created_at = _float_or_zero(metadata.get("created_at"))
        last_accessed_at = metadata.get("last_accessed_at")
        reference = (
            _float_or_zero(last_accessed_at)
            if isinstance(last_accessed_at, int | float)
            else created_at
        )
        retention_seconds = _int_or_default(
            metadata.get("retention_seconds"),
            self._retention_seconds,
        )
        idle_ttl_seconds = _int_or_default(
            metadata.get("idle_ttl_seconds"),
            self._idle_ttl_seconds,
        )
        return (
            created_at <= 0
            or now - created_at > retention_seconds
            or now - reference > idle_ttl_seconds
        )
|
|
305
|
+
|
|
306
|
+
|
|
307
|
+
def _write_private_file(path: Path, content: bytes) -> None:
    """Write *content* to *path* so that only the owner can read it.

    On POSIX the bytes go to a private temporary file in the same directory,
    which is then renamed over *path* with ``os.replace``. Concurrent readers
    (such as ``ArtifactStore.purge_expired``, which deletes artifacts whose
    metadata fails to parse) therefore never see a truncated, half-written
    file. The final file mode is forced to 0o600.

    On Windows the file is written in place, as before, because replacing a
    file that another handle holds open fails there.
    """
    if os.name == "nt":
        flags = os.O_WRONLY | os.O_CREAT | os.O_TRUNC
        with os.fdopen(os.open(path, flags, 0o600), "wb") as handle:
            handle.write(content)
        return

    import tempfile

    # mkstemp creates the file exclusively with owner-only permissions.
    fd, tmp_name = tempfile.mkstemp(
        dir=path.parent,
        prefix=f".{path.name}.",
        suffix=".tmp",
    )
    try:
        with os.fdopen(fd, "wb") as handle:
            handle.write(content)
        os.chmod(tmp_name, 0o600)
        os.replace(tmp_name, path)
    except BaseException:
        try:
            os.unlink(tmp_name)
        except OSError:
            pass
        raise
|
|
313
|
+
|
|
314
|
+
|
|
315
|
+
def _iso_timestamp(value: float) -> str:
    """Render a POSIX timestamp as a timezone-aware ISO-8601 string in UTC."""
    moment = datetime.fromtimestamp(value, tz=timezone.utc)
    return moment.isoformat()
|
|
317
|
+
|
|
318
|
+
|
|
319
|
+
def _float_or_zero(value: object) -> float:
|
|
320
|
+
if isinstance(value, int | float):
|
|
321
|
+
return float(value)
|
|
322
|
+
return 0.0
|
|
323
|
+
|
|
324
|
+
|
|
325
|
+
def _int_or_default(value: object, default: int) -> int:
    """Return *value* if it is a genuine ``int``, otherwise *default*.

    ``bool`` is excluded even though it subclasses ``int``. Otherwise a JSON
    ``true`` in metadata would be read as ``1`` (for example, a 1-second idle
    TTL).
    """
    if isinstance(value, int) and not isinstance(value, bool):
        return value
    return default
|
|
329
|
+
|
|
330
|
+
|
|
331
|
+
def _artifact_read(artifact_dir: Path, metadata: dict[str, object]) -> ArtifactRead:
    """Build an ArtifactRead for the payload stored in *artifact_dir*.

    Missing fields fall back to defaults: content type
    ``application/octet-stream``, size 0 and an empty sha256. A filename that
    is not a str becomes None, and a size that is not an int becomes 0.
    """
    get = metadata.get
    payload_path = artifact_dir / "payload"
    return ArtifactRead(
        path=payload_path,
        filename=_string_or_none(get("filename")),
        content_type=str(get("content_type", "application/octet-stream")),
        size=_int_or_default(get("size"), 0),
        sha256=str(get("sha256", "")),
    )
|
|
339
|
+
|
|
340
|
+
|
|
341
|
+
def _string_or_none(value: object) -> str | None:
    """Pass strings through unchanged; map every other value to ``None``."""
    return value if isinstance(value, str) else None
|
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import sys
|
|
4
|
+
from collections.abc import Iterable
|
|
5
|
+
from typing import TextIO
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def format_cli_error(
    message: str,
    *,
    area: str | None = None,
    details: Iterable[str] = (),
) -> str:
    """Format an Arbiter CLI error as a multi-line string.

    The first line reads ``Arbiter[ <area>] error: <first message line>``.
    Any further message lines follow, each indented by one space. Each detail
    is then appended the same way. Multi-line details are split too, so their
    continuation lines keep the indent instead of starting at column 0. An
    empty message or detail still produces its (empty) line.

    Args:
        message: Main error text; may contain newlines.
        area: Optional subsystem name inserted after "Arbiter".
        details: Extra lines appended after the message.

    Returns:
        The formatted text, without a trailing newline.
    """
    area_text = f" {area}" if area else ""
    first, *rest = message.splitlines() or [""]
    lines = [f"Arbiter{area_text} error: {first}"]
    lines.extend(f" {line}" for line in rest)
    for detail in details:
        lines.extend(f" {line}" for line in detail.splitlines() or [""])
    return "\n".join(lines)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def print_cli_error(
    message: str,
    *,
    area: str | None = None,
    details: Iterable[str] = (),
    file: TextIO | None = None,
) -> None:
    """Print a formatted Arbiter CLI error to *file* (``sys.stderr`` when None).

    ``sys.stderr`` is looked up at call time, so redirections are honoured.
    """
    target = sys.stderr if file is None else file
    text = format_cli_error(message, area=area, details=details)
    print(text, file=target)
|
arbiter_server/config.py
ADDED
|
@@ -0,0 +1,231 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from collections.abc import Mapping
|
|
4
|
+
from dataclasses import dataclass, field
|
|
5
|
+
from enum import Enum
|
|
6
|
+
from importlib.metadata import entry_points
|
|
7
|
+
import logging
|
|
8
|
+
from pathlib import Path
|
|
9
|
+
from typing import Any, cast
|
|
10
|
+
|
|
11
|
+
from hydra.core.config_store import ConfigStore
|
|
12
|
+
from omegaconf import II, OmegaConf
|
|
13
|
+
|
|
14
|
+
from .services import SERVICE_PLUGIN_ENTRY_POINT_GROUP, ServicePluginFactory
|
|
15
|
+
from .services import validate_service_plugin_compatibility
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
# Module logger; used to report service-plugin entry points that fail to import.
LOGGER: logging.Logger = logging.getLogger(__name__)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@dataclass
class HTTPServerConfig:
    """HTTP endpoint settings (scheme, host, port, path and base URL)."""

    scheme: str = "http"
    host: str = "127.0.0.1"
    port: int = 8000
    path: str = "/mcp"
    # OmegaConf relative interpolation: resolved from this node's own
    # scheme/host/port when the config is resolved.
    base_url: str = "${.scheme}://${.host}:${.port}"
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def _bind_http_server_config() -> HTTPServerConfig:
    """Default factory for ``FastMCPConfig.bind``: an all-defaults HTTPServerConfig."""
    config = HTTPServerConfig()
    return config
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def _public_http_server_config() -> HTTPServerConfig:
    """Default factory for ``FastMCPConfig.public``.

    ``port`` and ``path`` are interpolations of ``arbiter.server.bind.port``
    and ``arbiter.server.bind.path``, so they follow the bind values unless
    overridden. ``host`` is fixed to ``127.0.0.1``.
    """
    bind_port = II("arbiter.server.bind.port")
    bind_path = II("arbiter.server.bind.path")
    return HTTPServerConfig(host="127.0.0.1", port=bind_port, path=bind_path)
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
@dataclass
class FastMCPConfig:
    """Server settings stored under ``arbiter.server``.

    ``transport`` defaults to ``streamable-http``; ``register_configs`` also
    registers ``stdio`` and ``sse`` presets.
    """

    name: str = "arbiter"
    transport: str = "streamable-http"
    bind: HTTPServerConfig = field(default_factory=_bind_http_server_config)
    # Public endpoint; its port/path interpolate the bind settings by default.
    public: HTTPServerConfig = field(default_factory=_public_http_server_config)
    stateless_http: bool = True
    json_response: bool = True
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
@dataclass
class DiscoveryConfig:
    """Discovery preview limits; each must be at least 1."""

    max_account_preview_limit: int = 25
    max_operation_preview_limit: int = 25

    def __post_init__(self) -> None:
        """Validate the limits.

        Raises:
            ValueError: if a limit is below 1. The account limit is checked
                first.
        """
        account_limit = self.max_account_preview_limit
        if account_limit < 1:
            raise ValueError("max_account_preview_limit must be >= 1")
        operation_limit = self.max_operation_preview_limit
        if operation_limit < 1:
            raise ValueError("max_operation_preview_limit must be >= 1")
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
@dataclass
class StorageConfig:
    """Storage settings; ``plugin_data_dir`` is unset (None) by default."""

    plugin_data_dir: str | None = None
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
class DeploymentScope(str, Enum):
    """Deployment scope; members compare equal to their string values."""

    unknown = "unknown"
    staged = "staged"
    installed = "installed"
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
@dataclass
class Policy:
    """Empty policy dataclass (no fields are defined here)."""
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
@dataclass
class ArbiterConfig:
    """Top-level ``arbiter`` config node.

    ``account`` and ``policy`` map service names to per-service sub-configs
    (read by ``service_accounts_for`` and ``service_policies_for``). ``etc``
    is a free-form mapping.
    """

    env_file: str | None = None
    server: FastMCPConfig = field(default_factory=FastMCPConfig)
    deployment_scope: DeploymentScope = DeploymentScope.unknown
    discovery: DiscoveryConfig = field(default_factory=DiscoveryConfig)
    storage: StorageConfig = field(default_factory=StorageConfig)
    account: dict[str, Any] = field(default_factory=dict)
    policy: dict[str, Any] = field(default_factory=dict)
    etc: dict[str, Any] = field(default_factory=dict)
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
@dataclass
class AppConfig:
    """Root schema: everything lives under the ``arbiter`` key."""

    arbiter: ArbiterConfig = field(default_factory=ArbiterConfig)
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def _config_mapping_items(config: Any) -> list[tuple[str, object]]:
    """List ``(name, value)`` pairs of *config*, skipping ``_``-prefixed names.

    Mappings are read with ``items()`` and their keys are coerced to ``str``.
    Any other object is read through ``vars()``.
    """
    if isinstance(config, Mapping):
        pairs = ((str(key), value) for key, value in config.items())
    else:
        pairs = iter(vars(config).items())
    return [(key, value) for key, value in pairs if not key.startswith("_")]
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
def _service_config_mapping(config: Any, service_name: str) -> Mapping[str, object]:
    """Return the sub-mapping of *config* stored under *service_name*.

    *config* may be a mapping or an attribute-style object. A missing entry,
    or one explicitly set to ``None``, yields an empty mapping. This matches
    ``configured_service_names``, which treats falsy entries as unconfigured.

    Raises:
        TypeError: if the entry exists but is neither ``None`` nor a mapping.
    """
    if isinstance(config, Mapping):
        value = config.get(service_name, {})
    else:
        value = getattr(config, service_name, {})
    if value is None:
        return {}
    if isinstance(value, Mapping):
        return value
    raise TypeError(f"service config must be a mapping: {service_name}")
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
def configured_service_names(accounts: Any) -> list[str]:
    """Return the names of services whose account config is non-empty (truthy).

    ``_``-prefixed entries are ignored.
    """
    names: list[str] = []
    for name, entry in _config_mapping_items(accounts):
        if entry:
            names.append(name)
    return names
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
def service_accounts_for(
    config: Any,
    service_name: str,
) -> Mapping[str, object] | None:
    """Return ``config.arbiter.account[service_name]``, or None if it is empty."""
    return _service_config_mapping(config.arbiter.account, service_name) or None
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
def service_policies_for(
    config: Any,
    service_name: str,
) -> Mapping[str, object]:
    """Return the policy mapping for *service_name* from ``config.arbiter.policy``."""
    policies = config.arbiter.policy
    return _service_config_mapping(policies, service_name)
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
# Names under which the root AppConfig schema is stored in Hydra's ConfigStore.
_CONFIG_SCHEMA_NAMES = ("arbiter_app_config_schema",)
# Process-wide "already done" flags for register_configs / _register_resolvers.
_CONFIG_REGISTERED = False
_RESOLVERS_REGISTERED = False
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
def _read_secret_file(path: str) -> str:
    """Implementation of the ``secret_file`` resolver: read a secret from disk.

    ``~`` in *path* is expanded. The file is decoded as UTF-8 and trailing
    CR/LF characters are stripped.

    Raises:
        ValueError: wrapping any OSError, naming the resolved path.
    """
    secret_path = Path(path).expanduser()
    try:
        contents = secret_path.read_text(encoding="utf-8")
    except OSError as exc:
        raise ValueError(f"failed to read secret file: {secret_path}") from exc
    return contents.rstrip("\r\n")
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
def _register_resolvers() -> None:
    """Register Arbiter's OmegaConf resolvers, at most once per process.

    The ``secret_file`` resolver is added with ``use_cache=False``, and only
    if no resolver of that name is already registered.
    """
    global _RESOLVERS_REGISTERED
    if not _RESOLVERS_REGISTERED:
        if not OmegaConf.has_resolver("secret_file"):
            OmegaConf.register_new_resolver(
                "secret_file", _read_secret_file, use_cache=False
            )
        _RESOLVERS_REGISTERED = True
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
def _register_server_configs(config_store: ConfigStore) -> None:
    """Store the root AppConfig schema and the ``arbiter/server`` group options.

    Group options: ``schema`` (the FastMCPConfig class), ``streamable-http``
    (defaults), ``stdio`` and ``sse``. All are packaged at ``arbiter.server``.
    """
    for schema_name in _CONFIG_SCHEMA_NAMES:
        config_store.store(name=schema_name, node=AppConfig)

    server_options = (
        ("schema", FastMCPConfig),
        ("streamable-http", FastMCPConfig()),
        ("stdio", FastMCPConfig(transport="stdio")),
        ("sse", FastMCPConfig(transport="sse")),
    )
    for option_name, node in server_options:
        config_store.store(
            group="arbiter/server",
            name=option_name,
            node=node,
            package="arbiter.server",
            provider="arbiter-server",
        )
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
def _register_service_plugin_configs(config_store: ConfigStore) -> None:
    """Let every installed service plugin register its configs with *config_store*.

    Entry points whose import fails with ModuleNotFoundError are logged and
    skipped. Other load errors propagate. Each loaded factory is called, and
    the result is checked with ``validate_service_plugin_compatibility``
    before its ``register_configs`` runs.
    """
    plugin_entry_points = entry_points().select(group=SERVICE_PLUGIN_ENTRY_POINT_GROUP)
    for ep in plugin_entry_points:
        try:
            factory = cast(ServicePluginFactory, ep.load())
        except ModuleNotFoundError as exc:
            LOGGER.warning(
                "Skipping unavailable service plugin config entry point %s=%s: %s",
                ep.name,
                ep.value,
                exc,
            )
            continue
        plugin = factory()
        validate_service_plugin_compatibility(plugin)
        plugin.register_configs(config_store)
|
|
221
|
+
|
|
222
|
+
|
|
223
|
+
def register_configs() -> None:
    """Register resolvers, server schemas and plugin configs with Hydra.

    Resolver registration runs on every call (it has its own guard). The
    ConfigStore registrations run only until they succeed once.
    """
    global _CONFIG_REGISTERED
    _register_resolvers()
    if not _CONFIG_REGISTERED:
        store = ConfigStore.instance()
        _register_server_configs(store)
        _register_service_plugin_configs(store)
        _CONFIG_REGISTERED = True
|