tactus 0.31.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tactus/__init__.py +49 -0
- tactus/adapters/__init__.py +9 -0
- tactus/adapters/broker_log.py +76 -0
- tactus/adapters/cli_hitl.py +189 -0
- tactus/adapters/cli_log.py +223 -0
- tactus/adapters/cost_collector_log.py +56 -0
- tactus/adapters/file_storage.py +367 -0
- tactus/adapters/http_callback_log.py +109 -0
- tactus/adapters/ide_log.py +71 -0
- tactus/adapters/lua_tools.py +336 -0
- tactus/adapters/mcp.py +289 -0
- tactus/adapters/mcp_manager.py +196 -0
- tactus/adapters/memory.py +53 -0
- tactus/adapters/plugins.py +419 -0
- tactus/backends/http_backend.py +58 -0
- tactus/backends/model_backend.py +35 -0
- tactus/backends/pytorch_backend.py +110 -0
- tactus/broker/__init__.py +12 -0
- tactus/broker/client.py +247 -0
- tactus/broker/protocol.py +183 -0
- tactus/broker/server.py +1123 -0
- tactus/broker/stdio.py +12 -0
- tactus/cli/__init__.py +7 -0
- tactus/cli/app.py +2245 -0
- tactus/cli/commands/__init__.py +0 -0
- tactus/core/__init__.py +32 -0
- tactus/core/config_manager.py +790 -0
- tactus/core/dependencies/__init__.py +14 -0
- tactus/core/dependencies/registry.py +180 -0
- tactus/core/dsl_stubs.py +2117 -0
- tactus/core/exceptions.py +66 -0
- tactus/core/execution_context.py +480 -0
- tactus/core/lua_sandbox.py +508 -0
- tactus/core/message_history_manager.py +236 -0
- tactus/core/mocking.py +286 -0
- tactus/core/output_validator.py +291 -0
- tactus/core/registry.py +499 -0
- tactus/core/runtime.py +2907 -0
- tactus/core/template_resolver.py +142 -0
- tactus/core/yaml_parser.py +301 -0
- tactus/docker/Dockerfile +61 -0
- tactus/docker/entrypoint.sh +69 -0
- tactus/dspy/__init__.py +39 -0
- tactus/dspy/agent.py +1144 -0
- tactus/dspy/broker_lm.py +181 -0
- tactus/dspy/config.py +212 -0
- tactus/dspy/history.py +196 -0
- tactus/dspy/module.py +405 -0
- tactus/dspy/prediction.py +318 -0
- tactus/dspy/signature.py +185 -0
- tactus/formatting/__init__.py +7 -0
- tactus/formatting/formatter.py +437 -0
- tactus/ide/__init__.py +9 -0
- tactus/ide/coding_assistant.py +343 -0
- tactus/ide/server.py +2223 -0
- tactus/primitives/__init__.py +49 -0
- tactus/primitives/control.py +168 -0
- tactus/primitives/file.py +229 -0
- tactus/primitives/handles.py +378 -0
- tactus/primitives/host.py +94 -0
- tactus/primitives/human.py +342 -0
- tactus/primitives/json.py +189 -0
- tactus/primitives/log.py +187 -0
- tactus/primitives/message_history.py +157 -0
- tactus/primitives/model.py +163 -0
- tactus/primitives/procedure.py +564 -0
- tactus/primitives/procedure_callable.py +318 -0
- tactus/primitives/retry.py +155 -0
- tactus/primitives/session.py +152 -0
- tactus/primitives/state.py +182 -0
- tactus/primitives/step.py +209 -0
- tactus/primitives/system.py +93 -0
- tactus/primitives/tool.py +375 -0
- tactus/primitives/tool_handle.py +279 -0
- tactus/primitives/toolset.py +229 -0
- tactus/protocols/__init__.py +38 -0
- tactus/protocols/chat_recorder.py +81 -0
- tactus/protocols/config.py +97 -0
- tactus/protocols/cost.py +31 -0
- tactus/protocols/hitl.py +71 -0
- tactus/protocols/log_handler.py +27 -0
- tactus/protocols/models.py +355 -0
- tactus/protocols/result.py +33 -0
- tactus/protocols/storage.py +90 -0
- tactus/providers/__init__.py +13 -0
- tactus/providers/base.py +92 -0
- tactus/providers/bedrock.py +117 -0
- tactus/providers/google.py +105 -0
- tactus/providers/openai.py +98 -0
- tactus/sandbox/__init__.py +63 -0
- tactus/sandbox/config.py +171 -0
- tactus/sandbox/container_runner.py +1099 -0
- tactus/sandbox/docker_manager.py +433 -0
- tactus/sandbox/entrypoint.py +227 -0
- tactus/sandbox/protocol.py +213 -0
- tactus/stdlib/__init__.py +10 -0
- tactus/stdlib/io/__init__.py +13 -0
- tactus/stdlib/io/csv.py +88 -0
- tactus/stdlib/io/excel.py +136 -0
- tactus/stdlib/io/file.py +90 -0
- tactus/stdlib/io/fs.py +154 -0
- tactus/stdlib/io/hdf5.py +121 -0
- tactus/stdlib/io/json.py +109 -0
- tactus/stdlib/io/parquet.py +83 -0
- tactus/stdlib/io/tsv.py +88 -0
- tactus/stdlib/loader.py +274 -0
- tactus/stdlib/tac/tactus/tools/done.tac +33 -0
- tactus/stdlib/tac/tactus/tools/log.tac +50 -0
- tactus/testing/README.md +273 -0
- tactus/testing/__init__.py +61 -0
- tactus/testing/behave_integration.py +380 -0
- tactus/testing/context.py +486 -0
- tactus/testing/eval_models.py +114 -0
- tactus/testing/evaluation_runner.py +222 -0
- tactus/testing/evaluators.py +634 -0
- tactus/testing/events.py +94 -0
- tactus/testing/gherkin_parser.py +134 -0
- tactus/testing/mock_agent.py +315 -0
- tactus/testing/mock_dependencies.py +234 -0
- tactus/testing/mock_hitl.py +171 -0
- tactus/testing/mock_registry.py +168 -0
- tactus/testing/mock_tools.py +133 -0
- tactus/testing/models.py +115 -0
- tactus/testing/pydantic_eval_runner.py +508 -0
- tactus/testing/steps/__init__.py +13 -0
- tactus/testing/steps/builtin.py +902 -0
- tactus/testing/steps/custom.py +69 -0
- tactus/testing/steps/registry.py +68 -0
- tactus/testing/test_runner.py +489 -0
- tactus/tracing/__init__.py +5 -0
- tactus/tracing/trace_manager.py +417 -0
- tactus/utils/__init__.py +1 -0
- tactus/utils/cost_calculator.py +72 -0
- tactus/utils/model_pricing.py +132 -0
- tactus/utils/safe_file_library.py +502 -0
- tactus/utils/safe_libraries.py +234 -0
- tactus/validation/LuaLexerBase.py +66 -0
- tactus/validation/LuaParserBase.py +23 -0
- tactus/validation/README.md +224 -0
- tactus/validation/__init__.py +7 -0
- tactus/validation/error_listener.py +21 -0
- tactus/validation/generated/LuaLexer.interp +231 -0
- tactus/validation/generated/LuaLexer.py +5548 -0
- tactus/validation/generated/LuaLexer.tokens +124 -0
- tactus/validation/generated/LuaLexerBase.py +66 -0
- tactus/validation/generated/LuaParser.interp +173 -0
- tactus/validation/generated/LuaParser.py +6439 -0
- tactus/validation/generated/LuaParser.tokens +124 -0
- tactus/validation/generated/LuaParserBase.py +23 -0
- tactus/validation/generated/LuaParserVisitor.py +118 -0
- tactus/validation/generated/__init__.py +7 -0
- tactus/validation/grammar/LuaLexer.g4 +123 -0
- tactus/validation/grammar/LuaParser.g4 +178 -0
- tactus/validation/semantic_visitor.py +817 -0
- tactus/validation/validator.py +157 -0
- tactus-0.31.0.dist-info/METADATA +1809 -0
- tactus-0.31.0.dist-info/RECORD +160 -0
- tactus-0.31.0.dist-info/WHEEL +4 -0
- tactus-0.31.0.dist-info/entry_points.txt +2 -0
- tactus-0.31.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,58 @@
|
|
|
1
|
+
"""
|
|
2
|
+
HTTP model backend for REST endpoint inference.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import logging
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
import httpx
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
logger = logging.getLogger(__name__)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class HTTPModelBackend:
|
|
15
|
+
"""Model backend that calls HTTP REST endpoints."""
|
|
16
|
+
|
|
17
|
+
def __init__(self, endpoint: str, timeout: float = 30.0, headers: dict | None = None):
|
|
18
|
+
"""
|
|
19
|
+
Initialize HTTP model backend.
|
|
20
|
+
|
|
21
|
+
Args:
|
|
22
|
+
endpoint: URL of the inference endpoint
|
|
23
|
+
timeout: Request timeout in seconds
|
|
24
|
+
headers: Optional HTTP headers to include
|
|
25
|
+
"""
|
|
26
|
+
self.endpoint = endpoint
|
|
27
|
+
self.timeout = timeout
|
|
28
|
+
self.headers = headers or {}
|
|
29
|
+
|
|
30
|
+
async def predict(self, input_data: Any) -> Any:
|
|
31
|
+
"""
|
|
32
|
+
Call HTTP endpoint with input data.
|
|
33
|
+
|
|
34
|
+
Args:
|
|
35
|
+
input_data: Data to send to endpoint (will be JSON serialized)
|
|
36
|
+
|
|
37
|
+
Returns:
|
|
38
|
+
Response JSON from endpoint
|
|
39
|
+
"""
|
|
40
|
+
async with httpx.AsyncClient(timeout=self.timeout) as client:
|
|
41
|
+
response = await client.post(self.endpoint, json=input_data, headers=self.headers)
|
|
42
|
+
response.raise_for_status()
|
|
43
|
+
return response.json()
|
|
44
|
+
|
|
45
|
+
def predict_sync(self, input_data: Any) -> Any:
|
|
46
|
+
"""
|
|
47
|
+
Synchronous version of predict.
|
|
48
|
+
|
|
49
|
+
Args:
|
|
50
|
+
input_data: Data to send to endpoint (will be JSON serialized)
|
|
51
|
+
|
|
52
|
+
Returns:
|
|
53
|
+
Response JSON from endpoint
|
|
54
|
+
"""
|
|
55
|
+
with httpx.Client(timeout=self.timeout) as client:
|
|
56
|
+
response = client.post(self.endpoint, json=input_data, headers=self.headers)
|
|
57
|
+
response.raise_for_status()
|
|
58
|
+
return response.json()
|
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Model backend protocol for ML inference.
|
|
3
|
+
|
|
4
|
+
Defines the interface that all model backends must implement.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from typing import Any, Protocol
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class ModelBackend(Protocol):
    """Structural interface for model inference backends.

    Any object exposing both an async ``predict`` and a blocking
    ``predict_sync`` with these signatures satisfies the protocol.
    """

    async def predict(self, input_data: Any) -> Any:
        """
        Run inference on input data.

        Args:
            input_data: Input to the model (format depends on backend)

        Returns:
            Model prediction result
        """
        ...

    def predict_sync(self, input_data: Any) -> Any:
        """
        Synchronous version of predict.

        Args:
            input_data: Input to the model (format depends on backend)

        Returns:
            Model prediction result
        """
        ...
|
@@ -0,0 +1,110 @@
|
|
|
1
|
+
"""
|
|
2
|
+
PyTorch model backend for .pt file inference.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import logging
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from typing import Any
|
|
8
|
+
|
|
9
|
+
logger = logging.getLogger(__name__)
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class PyTorchModelBackend:
|
|
13
|
+
"""Model backend that loads and runs PyTorch models."""
|
|
14
|
+
|
|
15
|
+
def __init__(self, path: str, device: str = "cpu", labels: list[str] | None = None):
|
|
16
|
+
"""
|
|
17
|
+
Initialize PyTorch model backend.
|
|
18
|
+
|
|
19
|
+
Args:
|
|
20
|
+
path: Path to .pt model file
|
|
21
|
+
device: Device to run on ('cpu', 'cuda', 'mps')
|
|
22
|
+
labels: Optional list of class labels for classification
|
|
23
|
+
"""
|
|
24
|
+
self.path = Path(path)
|
|
25
|
+
self.device = device
|
|
26
|
+
self.labels = labels
|
|
27
|
+
self.model = None
|
|
28
|
+
|
|
29
|
+
def _load_model(self):
|
|
30
|
+
"""Lazy load the model on first use."""
|
|
31
|
+
if self.model is not None:
|
|
32
|
+
return
|
|
33
|
+
|
|
34
|
+
try:
|
|
35
|
+
import torch
|
|
36
|
+
except ImportError:
|
|
37
|
+
raise ImportError("PyTorch not installed. Install with: pip install torch")
|
|
38
|
+
|
|
39
|
+
if not self.path.exists():
|
|
40
|
+
raise FileNotFoundError(f"Model file not found: {self.path}")
|
|
41
|
+
|
|
42
|
+
self.model = torch.load(self.path, map_location=self.device)
|
|
43
|
+
self.model.eval()
|
|
44
|
+
logger.info(f"Loaded PyTorch model from {self.path}")
|
|
45
|
+
|
|
46
|
+
async def predict(self, input_data: Any) -> Any:
|
|
47
|
+
"""
|
|
48
|
+
Run PyTorch model inference.
|
|
49
|
+
|
|
50
|
+
Args:
|
|
51
|
+
input_data: Input tensor or data convertible to tensor
|
|
52
|
+
|
|
53
|
+
Returns:
|
|
54
|
+
Model output (tensor or label if labels provided)
|
|
55
|
+
"""
|
|
56
|
+
return self.predict_sync(input_data)
|
|
57
|
+
|
|
58
|
+
def predict_sync(self, input_data: Any) -> Any:
|
|
59
|
+
"""
|
|
60
|
+
Synchronous PyTorch inference.
|
|
61
|
+
|
|
62
|
+
Args:
|
|
63
|
+
input_data: Input tensor or data convertible to tensor
|
|
64
|
+
|
|
65
|
+
Returns:
|
|
66
|
+
Model output (tensor or label if labels provided)
|
|
67
|
+
"""
|
|
68
|
+
self._load_model()
|
|
69
|
+
|
|
70
|
+
try:
|
|
71
|
+
import torch
|
|
72
|
+
except ImportError:
|
|
73
|
+
raise ImportError("PyTorch not installed. Install with: pip install torch")
|
|
74
|
+
|
|
75
|
+
# Convert input to tensor if needed
|
|
76
|
+
if not isinstance(input_data, torch.Tensor):
|
|
77
|
+
if isinstance(input_data, (list, tuple)):
|
|
78
|
+
input_tensor = torch.tensor(input_data)
|
|
79
|
+
else:
|
|
80
|
+
input_tensor = torch.tensor([input_data])
|
|
81
|
+
else:
|
|
82
|
+
input_tensor = input_data
|
|
83
|
+
|
|
84
|
+
# Move to device
|
|
85
|
+
input_tensor = input_tensor.to(self.device)
|
|
86
|
+
|
|
87
|
+
# Run inference
|
|
88
|
+
with torch.no_grad():
|
|
89
|
+
output = self.model(input_tensor)
|
|
90
|
+
|
|
91
|
+
# If labels provided, return class label
|
|
92
|
+
if self.labels:
|
|
93
|
+
if output.dim() > 1:
|
|
94
|
+
# Classification - get argmax
|
|
95
|
+
predicted_idx = output.argmax(dim=-1).item()
|
|
96
|
+
else:
|
|
97
|
+
# Single value - round to nearest index
|
|
98
|
+
predicted_idx = int(round(output.item()))
|
|
99
|
+
|
|
100
|
+
if 0 <= predicted_idx < len(self.labels):
|
|
101
|
+
return self.labels[predicted_idx]
|
|
102
|
+
else:
|
|
103
|
+
logger.warning(f"Predicted index {predicted_idx} out of range for labels")
|
|
104
|
+
return predicted_idx
|
|
105
|
+
|
|
106
|
+
# Return raw output
|
|
107
|
+
if output.numel() == 1:
|
|
108
|
+
return output.item()
|
|
109
|
+
else:
|
|
110
|
+
return output.cpu().numpy().tolist()
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Brokered capabilities for Tactus.
|
|
3
|
+
|
|
4
|
+
The broker is a trusted host-side process that holds credentials and performs
|
|
5
|
+
privileged operations (e.g., LLM API calls) on behalf of a secretless, networkless
|
|
6
|
+
runtime container.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from tactus.broker.client import BrokerClient
|
|
10
|
+
from tactus.broker.server import BrokerServer, TcpBrokerServer
|
|
11
|
+
|
|
12
|
+
__all__ = ["BrokerClient", "BrokerServer", "TcpBrokerServer"]
|
tactus/broker/client.py
ADDED
|
@@ -0,0 +1,247 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Broker client for use inside the runtime container.
|
|
3
|
+
|
|
4
|
+
Uses a broker transport selected at runtime:
|
|
5
|
+
- `stdio` (recommended for Docker Desktop): requests are written to stderr with a marker and
|
|
6
|
+
responses are read from stdin as NDJSON.
|
|
7
|
+
- Unix domain sockets (UDS): retained for non-Docker/host testing.
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
import asyncio
|
|
11
|
+
import json
|
|
12
|
+
import os
|
|
13
|
+
import ssl
|
|
14
|
+
import sys
|
|
15
|
+
import threading
|
|
16
|
+
import uuid
|
|
17
|
+
from pathlib import Path
|
|
18
|
+
from typing import Any, AsyncIterator, Optional
|
|
19
|
+
|
|
20
|
+
from tactus.broker.protocol import read_message, write_message
|
|
21
|
+
from tactus.broker.stdio import STDIO_REQUEST_PREFIX, STDIO_TRANSPORT_VALUE
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def _json_dumps(obj: Any) -> str:
|
|
25
|
+
return json.dumps(obj, ensure_ascii=False, separators=(",", ":"))
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class _StdioBrokerTransport:
    """Broker transport speaking NDJSON over the process's own stdio.

    Requests are written to stderr prefixed with a marker so the host can
    distinguish them from ordinary log output; responses arrive on stdin as
    newline-delimited JSON and are routed to the awaiting request by id.
    """

    def __init__(self):
        # Serializes stderr writes so concurrent requests don't interleave.
        self._write_lock = threading.Lock()
        # req_id -> (event loop that issued the request, queue receiving its events).
        self._pending: dict[
            str, tuple[asyncio.AbstractEventLoop, asyncio.Queue[dict[str, Any]]]
        ] = {}
        self._pending_lock = threading.Lock()
        # Background stdin reader; started lazily on first request.
        self._reader_thread: Optional[threading.Thread] = None
        self._stop = threading.Event()

    def _ensure_reader_thread(self) -> None:
        """Start the stdin reader thread if it is not already running."""
        if self._reader_thread is not None and self._reader_thread.is_alive():
            return

        self._reader_thread = threading.Thread(
            target=self._read_loop,
            name="tactus-broker-stdio-reader",
            daemon=True,
        )
        self._reader_thread.start()

    def _read_loop(self) -> None:
        """Blocking loop: read NDJSON events from stdin and dispatch by request id."""
        while not self._stop.is_set():
            line = sys.stdin.buffer.readline()
            if not line:
                # EOF: the host closed stdin; stop reading.
                return
            try:
                event = json.loads(line.decode("utf-8"))
            except json.JSONDecodeError:
                # Skip malformed lines rather than killing the reader thread.
                continue

            req_id = event.get("id")
            if not isinstance(req_id, str):
                continue

            with self._pending_lock:
                pending = self._pending.get(req_id)
            if pending is None:
                # No waiter registered for this id (late or unknown response).
                continue

            loop, queue = pending
            try:
                # Hand the event to the requesting coroutine on its own loop.
                loop.call_soon_threadsafe(queue.put_nowait, event)
            except RuntimeError:
                # Loop is closed or unavailable; ignore.
                continue

    async def aclose(self) -> None:
        """Signal the reader thread to stop and briefly wait for it to exit."""
        self._stop.set()
        thread = self._reader_thread
        if thread is None or not thread.is_alive():
            return
        try:
            # Join off the event loop; capped because readline may still block.
            await asyncio.to_thread(thread.join, 0.5)
        except Exception:
            return

    async def request(
        self, req_id: str, method: str, params: dict[str, Any]
    ) -> AsyncIterator[dict[str, Any]]:
        """Send one request and yield its response events until 'done' or 'error'."""
        self._ensure_reader_thread()
        loop = asyncio.get_running_loop()
        queue: asyncio.Queue[dict[str, Any]] = asyncio.Queue()
        with self._pending_lock:
            self._pending[req_id] = (loop, queue)

        try:
            payload = _json_dumps({"id": req_id, "method": method, "params": params})
            with self._write_lock:
                # Marker prefix lets the host tell broker requests apart from logs.
                sys.stderr.write(f"{STDIO_REQUEST_PREFIX}{payload}\n")
                sys.stderr.flush()

            while True:
                event = await queue.get()
                yield event
                if event.get("event") in ("done", "error"):
                    return
        finally:
            # Always deregister so the reader doesn't route to a dead queue.
            with self._pending_lock:
                self._pending.pop(req_id, None)
|
108
|
+
|
|
109
|
+
|
|
110
|
+
# Module-wide singleton: all BrokerClient instances share one stdio transport.
_STDIO_TRANSPORT = _StdioBrokerTransport()


async def close_stdio_transport() -> None:
    """Shut down the shared stdio transport's reader thread (best effort)."""
    await _STDIO_TRANSPORT.aclose()
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
class BrokerClient:
    """Client used inside the runtime container to talk to the host broker.

    The transport is selected from `socket_path`: the special stdio marker,
    a `tcp://host:port` / `tls://host:port` endpoint, or (default) a Unix
    domain socket path.
    """

    def __init__(self, socket_path: str | Path):
        # Stored as a string; may be a UDS path, tcp://..., tls://..., or the stdio marker.
        self.socket_path = str(socket_path)

    @classmethod
    def from_environment(cls) -> Optional["BrokerClient"]:
        """Build a client from TACTUS_BROKER_SOCKET, or return None if unset."""
        socket_path = os.environ.get("TACTUS_BROKER_SOCKET")
        if not socket_path:
            return None
        return cls(socket_path)

    async def _request(self, method: str, params: dict[str, Any]) -> AsyncIterator[dict[str, Any]]:
        """Send one request over the configured transport and yield its events.

        Generates a fresh request id, then yields only events correlated to it,
        stopping after a terminal 'done' or 'error' event. For socket transports
        the connection is opened per request and closed when the generator ends.
        """
        req_id = uuid.uuid4().hex

        if self.socket_path == STDIO_TRANSPORT_VALUE:
            async for event in _STDIO_TRANSPORT.request(req_id, method, params):
                # Responses are already correlated by req_id; add a defensive filter anyway.
                if event.get("id") == req_id:
                    yield event
            return

        if self.socket_path.startswith(("tcp://", "tls://")):
            use_tls = self.socket_path.startswith("tls://")
            host_port = self.socket_path.split("://", 1)[1]
            if "/" in host_port:
                # Drop any path component after host:port.
                host_port = host_port.split("/", 1)[0]
            if ":" not in host_port:
                raise ValueError(
                    f"Invalid broker endpoint: {self.socket_path}. Expected tcp://host:port or tls://host:port"
                )
            # rsplit: keep everything before the LAST colon as the host.
            host, port_str = host_port.rsplit(":", 1)
            try:
                port = int(port_str)
            except ValueError as e:
                raise ValueError(f"Invalid broker port in endpoint: {self.socket_path}") from e

            ssl_ctx: ssl.SSLContext | None = None
            if use_tls:
                ssl_ctx = ssl.create_default_context()
                cafile = os.environ.get("TACTUS_BROKER_TLS_CA_FILE")
                if cafile:
                    ssl_ctx.load_verify_locations(cafile=cafile)

                if os.environ.get("TACTUS_BROKER_TLS_INSECURE") in ("1", "true", "yes"):
                    # Explicit opt-out of certificate verification (testing only).
                    ssl_ctx.check_hostname = False
                    ssl_ctx.verify_mode = ssl.CERT_NONE

            reader, writer = await asyncio.open_connection(host, port, ssl=ssl_ctx)
            await write_message(writer, {"id": req_id, "method": method, "params": params})

            try:
                while True:
                    event = await read_message(reader)
                    if event.get("id") != req_id:
                        continue
                    yield event
                    if event.get("event") in ("done", "error"):
                        return
            finally:
                # Best-effort close; the peer may already have dropped the connection.
                try:
                    writer.close()
                    await writer.wait_closed()
                except Exception:
                    pass

        # Default transport: Unix domain socket.
        reader, writer = await asyncio.open_unix_connection(self.socket_path)
        await write_message(writer, {"id": req_id, "method": method, "params": params})

        try:
            while True:
                event = await read_message(reader)
                # Ignore unrelated messages (defensive; current server is 1-req/conn).
                if event.get("id") != req_id:
                    continue
                yield event
                if event.get("event") in ("done", "error"):
                    return
        finally:
            try:
                writer.close()
                await writer.wait_closed()
            except Exception:
                pass

    def llm_chat(
        self,
        *,
        provider: str,
        model: str,
        messages: list[dict[str, Any]],
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        stream: bool,
    ) -> AsyncIterator[dict[str, Any]]:
        """Request a (possibly streaming) chat completion via the broker.

        Returns the async iterator of response events from `_request`;
        optional sampling parameters are only sent when provided.
        """
        params: dict[str, Any] = {
            "provider": provider,
            "model": model,
            "messages": messages,
            "stream": stream,
        }
        if temperature is not None:
            params["temperature"] = temperature
        if max_tokens is not None:
            params["max_tokens"] = max_tokens
        return self._request("llm.chat", params)

    async def call_tool(self, *, name: str, args: dict[str, Any]) -> Any:
        """
        Call an allowlisted host tool via the broker.

        Returns the decoded `result` payload from the broker.
        """
        if not isinstance(name, str) or not name:
            raise ValueError("tool name must be a non-empty string")
        if not isinstance(args, dict):
            raise ValueError("tool args must be an object")

        async for event in self._request("tool.call", {"name": name, "args": args}):
            event_type = event.get("event")
            if event_type == "done":
                data = event.get("data") or {}
                return data.get("result")
            if event_type == "error":
                err = event.get("error") or {}
                raise RuntimeError(err.get("message") or "Broker tool error")

        raise RuntimeError("Broker tool call ended without a response")

    async def emit_event(self, event: dict[str, Any]) -> None:
        """Forward an event to the broker, draining the reply stream."""
        async for _ in self._request("events.emit", {"event": event}):
            pass
|
@@ -0,0 +1,183 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Length-prefixed JSON protocol for broker communication.
|
|
3
|
+
|
|
4
|
+
This module provides utilities for sending and receiving JSON messages
|
|
5
|
+
with a length prefix, avoiding the buffer size limitations of newline-delimited JSON.
|
|
6
|
+
|
|
7
|
+
Wire format:
|
|
8
|
+
<10-digit-decimal-length>\n<json-payload>
|
|
9
|
+
|
|
10
|
+
Example:
|
|
11
|
+
0000000123
|
|
12
|
+
{"id":"abc","method":"llm.chat","params":{...}}
|
|
13
|
+
"""
|
|
14
|
+
|
|
15
|
+
import json
|
|
16
|
+
import asyncio
|
|
17
|
+
import logging
|
|
18
|
+
from typing import Any, Dict, AsyncIterator
|
|
19
|
+
|
|
20
|
+
import anyio
|
|
21
|
+
from anyio.streams.buffered import BufferedByteReceiveStream
|
|
22
|
+
|
|
23
|
+
logger = logging.getLogger(__name__)
|
|
24
|
+
|
|
25
|
+
# Length prefix is exactly 10 decimal digits + newline
LENGTH_PREFIX_SIZE = 11  # "0000000123\n"
MAX_MESSAGE_SIZE = 100 * 1024 * 1024  # 100MB safety limit


async def write_message(writer: asyncio.StreamWriter, message: Dict[str, Any]) -> None:
    """
    Write a JSON message with length prefix.

    Args:
        writer: asyncio StreamWriter
        message: Dictionary to encode as JSON

    Raises:
        ValueError: If message is too large
    """
    payload = json.dumps(message).encode("utf-8")
    size = len(payload)

    if size > MAX_MESSAGE_SIZE:
        raise ValueError(f"Message size {size} exceeds maximum {MAX_MESSAGE_SIZE}")

    # 10-digit zero-padded decimal length, a newline, then the payload.
    writer.write(f"{size:010d}\n".encode("ascii"))
    writer.write(payload)
    await writer.drain()
|
52
|
+
|
|
53
|
+
|
|
54
|
+
async def read_message(reader: asyncio.StreamReader) -> Dict[str, Any]:
    """
    Read a JSON message with length prefix.

    Args:
        reader: asyncio StreamReader

    Returns:
        Parsed JSON message as dictionary

    Raises:
        EOFError: If connection closed
        ValueError: If message is invalid or too large
    """
    # The prefix is exactly 11 bytes: 10 decimal digits plus a newline.
    header = await reader.readexactly(LENGTH_PREFIX_SIZE)
    if not header:
        raise EOFError("Connection closed")

    try:
        size = int(header[:10].decode("ascii"))
    except (ValueError, UnicodeDecodeError) as e:
        raise ValueError(f"Invalid length prefix: {header!r}") from e

    if size > MAX_MESSAGE_SIZE:
        raise ValueError(f"Message size {size} exceeds maximum {MAX_MESSAGE_SIZE}")
    if size == 0:
        raise ValueError("Zero-length message not allowed")

    # Read exactly the declared number of payload bytes.
    body = await reader.readexactly(size)
    try:
        return json.loads(body.decode("utf-8"))
    except (json.JSONDecodeError, UnicodeDecodeError) as e:
        raise ValueError("Invalid JSON payload") from e
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
async def read_messages(reader: asyncio.StreamReader) -> AsyncIterator[Dict[str, Any]]:
    """
    Read a stream of length-prefixed JSON messages.

    Args:
        reader: asyncio StreamReader

    Yields:
        Parsed JSON messages as dictionaries

    Stops when connection is closed or error occurs.
    """
    try:
        while True:
            yield await read_message(reader)
    except (EOFError, asyncio.IncompleteReadError):
        # Peer closed the connection (possibly mid-message); end the stream.
        return
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
# AnyIO-compatible versions for broker server
async def write_message_anyio(stream: anyio.abc.ByteStream, message: Dict[str, Any]) -> None:
    """
    Write a JSON message with length prefix using AnyIO streams.

    Args:
        stream: anyio ByteStream
        message: Dictionary to encode as JSON

    Raises:
        ValueError: If message is too large
    """
    payload = json.dumps(message).encode("utf-8")
    size = len(payload)

    if size > MAX_MESSAGE_SIZE:
        raise ValueError(f"Message size {size} exceeds maximum {MAX_MESSAGE_SIZE}")

    # 10-digit zero-padded decimal length, a newline, then the payload.
    await stream.send(f"{size:010d}\n".encode("ascii"))
    await stream.send(payload)
|
141
|
+
|
|
142
|
+
|
|
143
|
+
async def read_message_anyio(stream: BufferedByteReceiveStream) -> Dict[str, Any]:
    """
    Read a JSON message with length prefix using AnyIO streams.

    Args:
        stream: anyio BufferedByteReceiveStream

    Returns:
        Parsed JSON message as dictionary

    Raises:
        EOFError: If connection closed
        ValueError: If message is invalid or too large
    """
    # The prefix is exactly 11 bytes: 10 decimal digits plus a newline.
    header = await stream.receive_exactly(LENGTH_PREFIX_SIZE)
    if not header:
        raise EOFError("Connection closed")

    try:
        size = int(header[:10].decode("ascii"))
    except (ValueError, UnicodeDecodeError) as e:
        raise ValueError(f"Invalid length prefix: {header!r}") from e

    if size > MAX_MESSAGE_SIZE:
        raise ValueError(f"Message size {size} exceeds maximum {MAX_MESSAGE_SIZE}")
    if size == 0:
        raise ValueError("Zero-length message not allowed")

    # Read exactly the declared number of payload bytes.
    body = await stream.receive_exactly(size)
    try:
        return json.loads(body.decode("utf-8"))
    except (json.JSONDecodeError, UnicodeDecodeError) as e:
        raise ValueError("Invalid JSON payload") from e