tactus 0.34.0__py3-none-any.whl → 0.35.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tactus/__init__.py +1 -1
- tactus/adapters/broker_log.py +17 -14
- tactus/adapters/channels/__init__.py +17 -15
- tactus/adapters/channels/base.py +16 -7
- tactus/adapters/channels/broker.py +43 -13
- tactus/adapters/channels/cli.py +19 -15
- tactus/adapters/channels/host.py +15 -6
- tactus/adapters/channels/ipc.py +82 -31
- tactus/adapters/channels/sse.py +41 -23
- tactus/adapters/cli_hitl.py +19 -19
- tactus/adapters/cli_log.py +4 -4
- tactus/adapters/control_loop.py +138 -99
- tactus/adapters/cost_collector_log.py +9 -9
- tactus/adapters/file_storage.py +56 -52
- tactus/adapters/http_callback_log.py +23 -13
- tactus/adapters/ide_log.py +17 -9
- tactus/adapters/lua_tools.py +4 -5
- tactus/adapters/mcp.py +16 -19
- tactus/adapters/mcp_manager.py +46 -30
- tactus/adapters/memory.py +9 -9
- tactus/adapters/plugins.py +42 -42
- tactus/broker/client.py +75 -78
- tactus/broker/protocol.py +57 -57
- tactus/broker/server.py +252 -197
- tactus/cli/app.py +3 -1
- tactus/cli/control.py +2 -2
- tactus/core/config_manager.py +181 -135
- tactus/core/dependencies/registry.py +66 -48
- tactus/core/dsl_stubs.py +222 -163
- tactus/core/exceptions.py +10 -1
- tactus/core/execution_context.py +152 -112
- tactus/core/lua_sandbox.py +72 -64
- tactus/core/message_history_manager.py +138 -43
- tactus/core/mocking.py +41 -27
- tactus/core/output_validator.py +49 -44
- tactus/core/registry.py +94 -80
- tactus/core/runtime.py +211 -176
- tactus/core/template_resolver.py +16 -16
- tactus/core/yaml_parser.py +55 -45
- tactus/docs/extractor.py +7 -6
- tactus/ide/server.py +119 -78
- tactus/primitives/control.py +10 -6
- tactus/primitives/file.py +48 -46
- tactus/primitives/handles.py +47 -35
- tactus/primitives/host.py +29 -27
- tactus/primitives/human.py +154 -137
- tactus/primitives/json.py +22 -23
- tactus/primitives/log.py +26 -26
- tactus/primitives/message_history.py +285 -31
- tactus/primitives/model.py +15 -9
- tactus/primitives/procedure.py +86 -64
- tactus/primitives/procedure_callable.py +58 -51
- tactus/primitives/retry.py +31 -29
- tactus/primitives/session.py +42 -29
- tactus/primitives/state.py +54 -43
- tactus/primitives/step.py +9 -13
- tactus/primitives/system.py +34 -21
- tactus/primitives/tool.py +44 -31
- tactus/primitives/tool_handle.py +76 -54
- tactus/primitives/toolset.py +25 -22
- tactus/sandbox/config.py +4 -4
- tactus/sandbox/container_runner.py +161 -107
- tactus/sandbox/docker_manager.py +20 -20
- tactus/sandbox/entrypoint.py +16 -14
- tactus/sandbox/protocol.py +15 -15
- tactus/stdlib/classify/llm.py +1 -3
- tactus/stdlib/core/validation.py +0 -3
- tactus/testing/pydantic_eval_runner.py +1 -1
- tactus/utils/asyncio_helpers.py +27 -0
- tactus/utils/cost_calculator.py +7 -7
- tactus/utils/model_pricing.py +11 -12
- tactus/utils/safe_file_library.py +156 -132
- tactus/utils/safe_libraries.py +27 -27
- tactus/validation/error_listener.py +18 -5
- tactus/validation/semantic_visitor.py +392 -333
- tactus/validation/validator.py +89 -49
- {tactus-0.34.0.dist-info → tactus-0.35.0.dist-info}/METADATA +12 -3
- {tactus-0.34.0.dist-info → tactus-0.35.0.dist-info}/RECORD +81 -80
- {tactus-0.34.0.dist-info → tactus-0.35.0.dist-info}/WHEEL +0 -0
- {tactus-0.34.0.dist-info → tactus-0.35.0.dist-info}/entry_points.txt +0 -0
- {tactus-0.34.0.dist-info → tactus-0.35.0.dist-info}/licenses/LICENSE +0 -0
tactus/adapters/plugins.py
CHANGED
|
@@ -85,25 +85,25 @@ class PluginLoader:
|
|
|
85
85
|
all_tools = []
|
|
86
86
|
|
|
87
87
|
for path_str in paths:
|
|
88
|
-
|
|
88
|
+
resolved_path = Path(path_str).resolve()
|
|
89
89
|
|
|
90
|
-
if not
|
|
91
|
-
logger.warning(f"Tool path does not exist: {
|
|
90
|
+
if not resolved_path.exists():
|
|
91
|
+
logger.warning(f"Tool path does not exist: {resolved_path}")
|
|
92
92
|
continue
|
|
93
93
|
|
|
94
|
-
if
|
|
94
|
+
if resolved_path.is_file():
|
|
95
95
|
# Load tools from single file
|
|
96
|
-
if
|
|
97
|
-
tools = self._load_tools_from_file(
|
|
96
|
+
if resolved_path.suffix == ".py":
|
|
97
|
+
tools = self._load_tools_from_file(resolved_path)
|
|
98
98
|
all_tools.extend(tools)
|
|
99
99
|
else:
|
|
100
|
-
logger.warning(f"Skipping non-Python file: {
|
|
101
|
-
elif
|
|
100
|
+
logger.warning(f"Skipping non-Python file: {resolved_path}")
|
|
101
|
+
elif resolved_path.is_dir():
|
|
102
102
|
# Scan directory for Python files
|
|
103
|
-
tools = self._load_tools_from_directory(
|
|
103
|
+
tools = self._load_tools_from_directory(resolved_path)
|
|
104
104
|
all_tools.extend(tools)
|
|
105
105
|
else:
|
|
106
|
-
logger.warning(f"Path is neither file nor directory: {
|
|
106
|
+
logger.warning(f"Path is neither file nor directory: {resolved_path}")
|
|
107
107
|
|
|
108
108
|
logger.info(f"Loaded {len(all_tools)} tools from {len(paths)} path(s)")
|
|
109
109
|
return all_tools
|
|
@@ -121,23 +121,23 @@ class PluginLoader:
|
|
|
121
121
|
all_functions = []
|
|
122
122
|
|
|
123
123
|
for path_str in paths:
|
|
124
|
-
|
|
124
|
+
resolved_path = Path(path_str).resolve()
|
|
125
125
|
|
|
126
|
-
if not
|
|
127
|
-
logger.warning(f"Tool path does not exist: {
|
|
126
|
+
if not resolved_path.exists():
|
|
127
|
+
logger.warning(f"Tool path does not exist: {resolved_path}")
|
|
128
128
|
continue
|
|
129
129
|
|
|
130
|
-
if
|
|
131
|
-
if
|
|
132
|
-
functions = self._load_functions_from_file(
|
|
130
|
+
if resolved_path.is_file():
|
|
131
|
+
if resolved_path.suffix == ".py":
|
|
132
|
+
functions = self._load_functions_from_file(resolved_path)
|
|
133
133
|
all_functions.extend(functions)
|
|
134
134
|
else:
|
|
135
|
-
logger.warning(f"Skipping non-Python file: {
|
|
136
|
-
elif
|
|
137
|
-
functions = self._load_functions_from_directory(
|
|
135
|
+
logger.warning(f"Skipping non-Python file: {resolved_path}")
|
|
136
|
+
elif resolved_path.is_dir():
|
|
137
|
+
functions = self._load_functions_from_directory(resolved_path)
|
|
138
138
|
all_functions.extend(functions)
|
|
139
139
|
else:
|
|
140
|
-
logger.warning(f"Path is neither file nor directory: {
|
|
140
|
+
logger.warning(f"Path is neither file nor directory: {resolved_path}")
|
|
141
141
|
|
|
142
142
|
logger.debug(f"Loaded {len(all_functions)} function(s) from {len(paths)} path(s)")
|
|
143
143
|
return all_functions
|
|
@@ -200,8 +200,8 @@ class PluginLoader:
|
|
|
200
200
|
functions.append(obj)
|
|
201
201
|
logger.debug(f"Found function '{name}' in {file_path.name}")
|
|
202
202
|
|
|
203
|
-
except Exception as
|
|
204
|
-
logger.error(f"Failed to load functions from {file_path}: {
|
|
203
|
+
except Exception as error:
|
|
204
|
+
logger.error(f"Failed to load functions from {file_path}: {error}", exc_info=True)
|
|
205
205
|
|
|
206
206
|
return functions
|
|
207
207
|
|
|
@@ -216,14 +216,14 @@ class PluginLoader:
|
|
|
216
216
|
Async callback function for process_tool_call
|
|
217
217
|
"""
|
|
218
218
|
|
|
219
|
-
async def trace_tool_call(
|
|
219
|
+
async def trace_tool_call(execution_context, invoke_next, tool_name, tool_args):
|
|
220
220
|
"""Middleware to record tool calls in Tactus ToolPrimitive."""
|
|
221
221
|
logger.debug(
|
|
222
222
|
f"Toolset '{toolset_name}' calling tool '{tool_name}' with args: {tool_args}"
|
|
223
223
|
)
|
|
224
224
|
|
|
225
225
|
try:
|
|
226
|
-
result = await
|
|
226
|
+
result = await invoke_next(tool_name, tool_args)
|
|
227
227
|
|
|
228
228
|
# Record in ToolPrimitive if available
|
|
229
229
|
if self.tool_primitive:
|
|
@@ -232,11 +232,11 @@ class PluginLoader:
|
|
|
232
232
|
|
|
233
233
|
logger.debug(f"Tool '{tool_name}' completed successfully")
|
|
234
234
|
return result
|
|
235
|
-
except Exception as
|
|
236
|
-
logger.error(f"Tool '{tool_name}' failed: {
|
|
235
|
+
except Exception as error:
|
|
236
|
+
logger.error(f"Tool '{tool_name}' failed: {error}", exc_info=True)
|
|
237
237
|
# Still record the failed call
|
|
238
238
|
if self.tool_primitive:
|
|
239
|
-
error_msg = f"Error: {str(
|
|
239
|
+
error_msg = f"Error: {str(error)}"
|
|
240
240
|
self.tool_primitive.record_call(tool_name, tool_args, error_msg)
|
|
241
241
|
raise
|
|
242
242
|
|
|
@@ -308,8 +308,8 @@ class PluginLoader:
|
|
|
308
308
|
tools.append(tool)
|
|
309
309
|
logger.info(f"Loaded tool '{name}' from {file_path.name}")
|
|
310
310
|
|
|
311
|
-
except Exception as
|
|
312
|
-
logger.error(f"Failed to load tools from {file_path}: {
|
|
311
|
+
except Exception as error:
|
|
312
|
+
logger.error(f"Failed to load tools from {file_path}: {error}", exc_info=True)
|
|
313
313
|
|
|
314
314
|
return tools
|
|
315
315
|
|
|
@@ -352,8 +352,8 @@ class PluginLoader:
|
|
|
352
352
|
"""
|
|
353
353
|
try:
|
|
354
354
|
# Get function signature and docstring
|
|
355
|
-
|
|
356
|
-
|
|
355
|
+
signature = inspect.signature(func)
|
|
356
|
+
docstring = inspect.getdoc(func) or f"Tool: {name}"
|
|
357
357
|
|
|
358
358
|
# Check if function is async
|
|
359
359
|
is_async = inspect.iscoroutinefunction(func)
|
|
@@ -371,9 +371,9 @@ class PluginLoader:
|
|
|
371
371
|
self.tool_primitive.record_call(name, kwargs, str(result))
|
|
372
372
|
|
|
373
373
|
return result
|
|
374
|
-
except Exception as
|
|
375
|
-
logger.error(f"Tool '{name}' execution failed: {
|
|
376
|
-
error_msg = f"Error executing tool '{name}': {str(
|
|
374
|
+
except Exception as error:
|
|
375
|
+
logger.error(f"Tool '{name}' execution failed: {error}", exc_info=True)
|
|
376
|
+
error_msg = f"Error executing tool '{name}': {str(error)}"
|
|
377
377
|
|
|
378
378
|
# Record failed call
|
|
379
379
|
if self.tool_primitive:
|
|
@@ -393,9 +393,9 @@ class PluginLoader:
|
|
|
393
393
|
self.tool_primitive.record_call(name, kwargs, str(result))
|
|
394
394
|
|
|
395
395
|
return result
|
|
396
|
-
except Exception as
|
|
397
|
-
logger.error(f"Tool '{name}' execution failed: {
|
|
398
|
-
error_msg = f"Error executing tool '{name}': {str(
|
|
396
|
+
except Exception as error:
|
|
397
|
+
logger.error(f"Tool '{name}' execution failed: {error}", exc_info=True)
|
|
398
|
+
error_msg = f"Error executing tool '{name}': {str(error)}"
|
|
399
399
|
|
|
400
400
|
# Record failed call
|
|
401
401
|
if self.tool_primitive:
|
|
@@ -404,16 +404,16 @@ class PluginLoader:
|
|
|
404
404
|
raise
|
|
405
405
|
|
|
406
406
|
# Copy signature and docstring to wrapper
|
|
407
|
-
tool_wrapper.__signature__ =
|
|
408
|
-
tool_wrapper.__doc__ =
|
|
407
|
+
tool_wrapper.__signature__ = signature
|
|
408
|
+
tool_wrapper.__doc__ = docstring
|
|
409
409
|
tool_wrapper.__name__ = name
|
|
410
410
|
tool_wrapper.__annotations__ = func.__annotations__
|
|
411
411
|
|
|
412
412
|
# Create Pydantic AI Tool
|
|
413
|
-
tool = Tool(tool_wrapper, name=name, description=
|
|
413
|
+
tool = Tool(tool_wrapper, name=name, description=docstring)
|
|
414
414
|
|
|
415
415
|
return tool
|
|
416
416
|
|
|
417
|
-
except Exception as
|
|
418
|
-
logger.error(f"Failed to create tool from function '{name}': {
|
|
417
|
+
except Exception as error:
|
|
418
|
+
logger.error(f"Failed to create tool from function '{name}': {error}", exc_info=True)
|
|
419
419
|
return None
|
tactus/broker/client.py
CHANGED
|
@@ -31,12 +31,16 @@ def _json_dumps(obj: Any) -> str:
|
|
|
31
31
|
class _StdioBrokerTransport:
|
|
32
32
|
def __init__(self):
|
|
33
33
|
self._write_lock = threading.Lock()
|
|
34
|
-
self.
|
|
34
|
+
self._pending_requests: dict[
|
|
35
35
|
str, tuple[asyncio.AbstractEventLoop, asyncio.Queue[dict[str, Any]]]
|
|
36
36
|
] = {}
|
|
37
|
+
# Backward-compatible alias used in tests and older code paths.
|
|
38
|
+
self._pending = self._pending_requests
|
|
37
39
|
self._pending_lock = threading.Lock()
|
|
38
40
|
self._reader_thread: Optional[threading.Thread] = None
|
|
39
|
-
self.
|
|
41
|
+
self._shutdown_event = threading.Event()
|
|
42
|
+
# Backward-compatible alias used in tests and older code paths.
|
|
43
|
+
self._stop = self._shutdown_event
|
|
40
44
|
|
|
41
45
|
def _ensure_reader_thread(self) -> None:
|
|
42
46
|
if self._reader_thread is not None and self._reader_thread.is_alive():
|
|
@@ -50,33 +54,33 @@ class _StdioBrokerTransport:
|
|
|
50
54
|
self._reader_thread.start()
|
|
51
55
|
|
|
52
56
|
def _read_loop(self) -> None:
|
|
53
|
-
while not self.
|
|
54
|
-
|
|
55
|
-
if not
|
|
57
|
+
while not self._shutdown_event.is_set():
|
|
58
|
+
input_line = sys.stdin.buffer.readline()
|
|
59
|
+
if not input_line:
|
|
56
60
|
return
|
|
57
61
|
try:
|
|
58
|
-
|
|
62
|
+
event_payload = json.loads(input_line.decode("utf-8"))
|
|
59
63
|
except json.JSONDecodeError:
|
|
60
64
|
continue
|
|
61
65
|
|
|
62
|
-
|
|
63
|
-
if not isinstance(
|
|
66
|
+
request_id_value = event_payload.get("id")
|
|
67
|
+
if not isinstance(request_id_value, str):
|
|
64
68
|
continue
|
|
65
69
|
|
|
66
70
|
with self._pending_lock:
|
|
67
|
-
|
|
68
|
-
if
|
|
71
|
+
pending_request = self._pending_requests.get(request_id_value)
|
|
72
|
+
if pending_request is None:
|
|
69
73
|
continue
|
|
70
74
|
|
|
71
|
-
|
|
75
|
+
event_loop, response_queue = pending_request
|
|
72
76
|
try:
|
|
73
|
-
|
|
77
|
+
event_loop.call_soon_threadsafe(response_queue.put_nowait, event_payload)
|
|
74
78
|
except RuntimeError:
|
|
75
79
|
# Loop is closed or unavailable; ignore.
|
|
76
80
|
continue
|
|
77
81
|
|
|
78
82
|
async def aclose(self) -> None:
|
|
79
|
-
self.
|
|
83
|
+
self._shutdown_event.set()
|
|
80
84
|
thread = self._reader_thread
|
|
81
85
|
if thread is None or not thread.is_alive():
|
|
82
86
|
return
|
|
@@ -86,28 +90,28 @@ class _StdioBrokerTransport:
|
|
|
86
90
|
return
|
|
87
91
|
|
|
88
92
|
async def request(
|
|
89
|
-
self,
|
|
93
|
+
self, request_id: str, method: str, params: dict[str, Any]
|
|
90
94
|
) -> AsyncIterator[dict[str, Any]]:
|
|
91
95
|
self._ensure_reader_thread()
|
|
92
|
-
|
|
93
|
-
|
|
96
|
+
event_loop = asyncio.get_running_loop()
|
|
97
|
+
response_queue: asyncio.Queue[dict[str, Any]] = asyncio.Queue()
|
|
94
98
|
with self._pending_lock:
|
|
95
|
-
self.
|
|
99
|
+
self._pending_requests[request_id] = (event_loop, response_queue)
|
|
96
100
|
|
|
97
101
|
try:
|
|
98
|
-
|
|
102
|
+
request_payload = _json_dumps({"id": request_id, "method": method, "params": params})
|
|
99
103
|
with self._write_lock:
|
|
100
|
-
sys.stderr.write(f"{STDIO_REQUEST_PREFIX}{
|
|
104
|
+
sys.stderr.write(f"{STDIO_REQUEST_PREFIX}{request_payload}\n")
|
|
101
105
|
sys.stderr.flush()
|
|
102
106
|
|
|
103
107
|
while True:
|
|
104
|
-
|
|
105
|
-
yield
|
|
106
|
-
if
|
|
108
|
+
event_payload = await response_queue.get()
|
|
109
|
+
yield event_payload
|
|
110
|
+
if event_payload.get("event") in ("done", "error"):
|
|
107
111
|
return
|
|
108
112
|
finally:
|
|
109
113
|
with self._pending_lock:
|
|
110
|
-
self.
|
|
114
|
+
self._pending_requests.pop(request_id, None)
|
|
111
115
|
|
|
112
116
|
|
|
113
117
|
_STDIO_TRANSPORT = _StdioBrokerTransport()
|
|
@@ -129,59 +133,61 @@ class BrokerClient:
|
|
|
129
133
|
return cls(socket_path)
|
|
130
134
|
|
|
131
135
|
async def _request(self, method: str, params: dict[str, Any]) -> AsyncIterator[dict[str, Any]]:
|
|
132
|
-
|
|
136
|
+
request_id = uuid.uuid4().hex
|
|
133
137
|
|
|
134
138
|
if self.socket_path == STDIO_TRANSPORT_VALUE:
|
|
135
|
-
async for
|
|
139
|
+
async for event_payload in _STDIO_TRANSPORT.request(request_id, method, params):
|
|
136
140
|
# Responses are already correlated by req_id; add a defensive filter anyway.
|
|
137
|
-
if
|
|
138
|
-
yield
|
|
141
|
+
if event_payload.get("id") == request_id:
|
|
142
|
+
yield event_payload
|
|
139
143
|
return
|
|
140
144
|
|
|
141
145
|
if self.socket_path.startswith(("tcp://", "tls://")):
|
|
142
146
|
use_tls = self.socket_path.startswith("tls://")
|
|
143
|
-
|
|
144
|
-
if "/" in
|
|
145
|
-
|
|
146
|
-
if ":" not in
|
|
147
|
+
host_and_port = self.socket_path.split("://", 1)[1]
|
|
148
|
+
if "/" in host_and_port:
|
|
149
|
+
host_and_port = host_and_port.split("/", 1)[0]
|
|
150
|
+
if ":" not in host_and_port:
|
|
147
151
|
raise ValueError(
|
|
148
|
-
|
|
152
|
+
"Invalid broker endpoint: "
|
|
153
|
+
f"{self.socket_path}. Expected tcp://host:port or tls://host:port"
|
|
149
154
|
)
|
|
150
|
-
host,
|
|
155
|
+
host, port_text = host_and_port.rsplit(":", 1)
|
|
151
156
|
try:
|
|
152
|
-
port = int(
|
|
153
|
-
except ValueError as
|
|
154
|
-
raise ValueError(f"Invalid broker port in endpoint: {self.socket_path}") from
|
|
157
|
+
port = int(port_text)
|
|
158
|
+
except ValueError as error:
|
|
159
|
+
raise ValueError(f"Invalid broker port in endpoint: {self.socket_path}") from error
|
|
155
160
|
|
|
156
|
-
|
|
161
|
+
ssl_context: ssl.SSLContext | None = None
|
|
157
162
|
if use_tls:
|
|
158
|
-
|
|
163
|
+
ssl_context = ssl.create_default_context()
|
|
159
164
|
cafile = os.environ.get("TACTUS_BROKER_TLS_CA_FILE")
|
|
160
165
|
if cafile:
|
|
161
|
-
|
|
166
|
+
ssl_context.load_verify_locations(cafile=cafile)
|
|
162
167
|
|
|
163
168
|
if os.environ.get("TACTUS_BROKER_TLS_INSECURE") in ("1", "true", "yes"):
|
|
164
|
-
|
|
165
|
-
|
|
169
|
+
ssl_context.check_hostname = False
|
|
170
|
+
ssl_context.verify_mode = ssl.CERT_NONE
|
|
166
171
|
|
|
167
|
-
reader, writer = await asyncio.open_connection(host, port, ssl=
|
|
172
|
+
reader, writer = await asyncio.open_connection(host, port, ssl=ssl_context)
|
|
168
173
|
logger.info(
|
|
169
|
-
|
|
174
|
+
"[BROKER_CLIENT] Writing message to broker, params keys: %s",
|
|
175
|
+
list(params.keys()),
|
|
170
176
|
)
|
|
171
177
|
try:
|
|
172
|
-
await write_message(writer, {"id":
|
|
173
|
-
except TypeError as
|
|
174
|
-
logger.error(
|
|
175
|
-
logger.error(
|
|
178
|
+
await write_message(writer, {"id": request_id, "method": method, "params": params})
|
|
179
|
+
except TypeError as error:
|
|
180
|
+
logger.error("[BROKER_CLIENT] JSON serialization error: %s", error)
|
|
181
|
+
logger.error("[BROKER_CLIENT] Params: %s", params)
|
|
176
182
|
raise
|
|
177
183
|
|
|
178
184
|
try:
|
|
179
185
|
while True:
|
|
180
|
-
|
|
181
|
-
if
|
|
186
|
+
event_payload = await read_message(reader)
|
|
187
|
+
if event_payload.get("id") != request_id:
|
|
182
188
|
continue
|
|
183
|
-
yield
|
|
184
|
-
if
|
|
189
|
+
yield event_payload
|
|
190
|
+
if event_payload.get("event") in ("done", "error"):
|
|
185
191
|
return
|
|
186
192
|
finally:
|
|
187
193
|
try:
|
|
@@ -191,16 +197,16 @@ class BrokerClient:
|
|
|
191
197
|
pass
|
|
192
198
|
|
|
193
199
|
reader, writer = await asyncio.open_unix_connection(self.socket_path)
|
|
194
|
-
await write_message(writer, {"id":
|
|
200
|
+
await write_message(writer, {"id": request_id, "method": method, "params": params})
|
|
195
201
|
|
|
196
202
|
try:
|
|
197
203
|
while True:
|
|
198
|
-
|
|
204
|
+
event_payload = await read_message(reader)
|
|
199
205
|
# Ignore unrelated messages (defensive; current server is 1-req/conn).
|
|
200
|
-
if
|
|
206
|
+
if event_payload.get("id") != request_id:
|
|
201
207
|
continue
|
|
202
|
-
yield
|
|
203
|
-
if
|
|
208
|
+
yield event_payload
|
|
209
|
+
if event_payload.get("event") in ("done", "error"):
|
|
204
210
|
return
|
|
205
211
|
finally:
|
|
206
212
|
try:
|
|
@@ -221,34 +227,25 @@ class BrokerClient:
|
|
|
221
227
|
tools: Optional[list[dict[str, Any]]] = None,
|
|
222
228
|
tool_choice: Optional[str] = None,
|
|
223
229
|
) -> AsyncIterator[dict[str, Any]]:
|
|
224
|
-
|
|
230
|
+
request_params: dict[str, Any] = {
|
|
225
231
|
"provider": provider,
|
|
226
232
|
"model": model,
|
|
227
233
|
"messages": messages,
|
|
228
234
|
"stream": stream,
|
|
229
235
|
}
|
|
230
236
|
if temperature is not None:
|
|
231
|
-
|
|
237
|
+
request_params["temperature"] = temperature
|
|
232
238
|
if max_tokens is not None:
|
|
233
|
-
|
|
239
|
+
request_params["max_tokens"] = max_tokens
|
|
234
240
|
if tools is not None:
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
logger = logging.getLogger(__name__)
|
|
239
|
-
logger.info(f"[BROKER_CLIENT] Adding {len(tools)} tools to params")
|
|
241
|
+
request_params["tools"] = tools
|
|
242
|
+
logger.info("[BROKER_CLIENT] Adding %s tools to params", len(tools))
|
|
240
243
|
else:
|
|
241
|
-
import logging
|
|
242
|
-
|
|
243
|
-
logger = logging.getLogger(__name__)
|
|
244
244
|
logger.warning("[BROKER_CLIENT] No tools to add to params")
|
|
245
245
|
if tool_choice is not None:
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
logger = logging.getLogger(__name__)
|
|
250
|
-
logger.info(f"[BROKER_CLIENT] Adding tool_choice={tool_choice} to params")
|
|
251
|
-
return self._request("llm.chat", params)
|
|
246
|
+
request_params["tool_choice"] = tool_choice
|
|
247
|
+
logger.info("[BROKER_CLIENT] Adding tool_choice=%s to params", tool_choice)
|
|
248
|
+
return self._request("llm.chat", request_params)
|
|
252
249
|
|
|
253
250
|
async def call_tool(self, *, name: str, args: dict[str, Any]) -> Any:
|
|
254
251
|
"""
|
|
@@ -261,14 +258,14 @@ class BrokerClient:
|
|
|
261
258
|
if not isinstance(args, dict):
|
|
262
259
|
raise ValueError("tool args must be an object")
|
|
263
260
|
|
|
264
|
-
async for
|
|
265
|
-
event_type =
|
|
261
|
+
async for event_payload in self._request("tool.call", {"name": name, "args": args}):
|
|
262
|
+
event_type = event_payload.get("event")
|
|
266
263
|
if event_type == "done":
|
|
267
|
-
data =
|
|
264
|
+
data = event_payload.get("data") or {}
|
|
268
265
|
return data.get("result")
|
|
269
266
|
if event_type == "error":
|
|
270
|
-
|
|
271
|
-
raise RuntimeError(
|
|
267
|
+
error_payload = event_payload.get("error") or {}
|
|
268
|
+
raise RuntimeError(error_payload.get("message") or "Broker tool error")
|
|
272
269
|
|
|
273
270
|
raise RuntimeError("Broker tool call ended without a response")
|
|
274
271
|
|
tactus/broker/protocol.py
CHANGED
|
@@ -12,10 +12,10 @@ Example:
|
|
|
12
12
|
{"id":"abc","method":"llm.chat","params":{...}}
|
|
13
13
|
"""
|
|
14
14
|
|
|
15
|
-
import json
|
|
16
15
|
import asyncio
|
|
17
16
|
import logging
|
|
18
|
-
|
|
17
|
+
import json
|
|
18
|
+
from typing import Any, AsyncIterator
|
|
19
19
|
|
|
20
20
|
import anyio
|
|
21
21
|
from anyio.streams.buffered import BufferedByteReceiveStream
|
|
@@ -27,7 +27,33 @@ LENGTH_PREFIX_SIZE = 11 # "0000000123\n"
|
|
|
27
27
|
MAX_MESSAGE_SIZE = 100 * 1024 * 1024 # 100MB safety limit
|
|
28
28
|
|
|
29
29
|
|
|
30
|
-
|
|
30
|
+
def _parse_length_prefix(length_prefix_bytes: bytes) -> int:
|
|
31
|
+
try:
|
|
32
|
+
length_text = length_prefix_bytes[:10].decode("ascii")
|
|
33
|
+
payload_length = int(length_text)
|
|
34
|
+
except (ValueError, UnicodeDecodeError) as error:
|
|
35
|
+
raise ValueError(f"Invalid length prefix: {length_prefix_bytes!r}") from error
|
|
36
|
+
|
|
37
|
+
if payload_length > MAX_MESSAGE_SIZE:
|
|
38
|
+
raise ValueError(f"Message size {payload_length} exceeds maximum {MAX_MESSAGE_SIZE}")
|
|
39
|
+
|
|
40
|
+
if payload_length == 0:
|
|
41
|
+
raise ValueError("Zero-length message not allowed")
|
|
42
|
+
|
|
43
|
+
return payload_length
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def _serialize_json_payload(message: dict[str, Any]) -> bytes:
|
|
47
|
+
json_payload_bytes = json.dumps(message).encode("utf-8")
|
|
48
|
+
payload_length = len(json_payload_bytes)
|
|
49
|
+
|
|
50
|
+
if payload_length > MAX_MESSAGE_SIZE:
|
|
51
|
+
raise ValueError(f"Message size {payload_length} exceeds maximum {MAX_MESSAGE_SIZE}")
|
|
52
|
+
|
|
53
|
+
return json_payload_bytes
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
async def write_message(writer: asyncio.StreamWriter, message: dict[str, Any]) -> None:
|
|
31
57
|
"""
|
|
32
58
|
Write a JSON message with length prefix.
|
|
33
59
|
|
|
@@ -38,20 +64,17 @@ async def write_message(writer: asyncio.StreamWriter, message: Dict[str, Any]) -
|
|
|
38
64
|
Raises:
|
|
39
65
|
ValueError: If message is too large
|
|
40
66
|
"""
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
if length > MAX_MESSAGE_SIZE:
|
|
45
|
-
raise ValueError(f"Message size {length} exceeds maximum {MAX_MESSAGE_SIZE}")
|
|
67
|
+
json_payload_bytes = _serialize_json_payload(message)
|
|
68
|
+
payload_length = len(json_payload_bytes)
|
|
46
69
|
|
|
47
70
|
# Write 10-digit length prefix + newline
|
|
48
|
-
length_prefix = f"{
|
|
71
|
+
length_prefix = f"{payload_length:010d}\n".encode("ascii")
|
|
49
72
|
writer.write(length_prefix)
|
|
50
|
-
writer.write(
|
|
73
|
+
writer.write(json_payload_bytes)
|
|
51
74
|
await writer.drain()
|
|
52
75
|
|
|
53
76
|
|
|
54
|
-
async def read_message(reader: asyncio.StreamReader) ->
|
|
77
|
+
async def read_message(reader: asyncio.StreamReader) -> dict[str, Any]:
|
|
55
78
|
"""
|
|
56
79
|
Read a JSON message with length prefix.
|
|
57
80
|
|
|
@@ -66,35 +89,25 @@ async def read_message(reader: asyncio.StreamReader) -> Dict[str, Any]:
|
|
|
66
89
|
ValueError: If message is invalid or too large
|
|
67
90
|
"""
|
|
68
91
|
# Read exactly 11 bytes for length prefix
|
|
69
|
-
|
|
92
|
+
length_prefix_bytes = await reader.readexactly(LENGTH_PREFIX_SIZE)
|
|
70
93
|
|
|
71
|
-
if not
|
|
94
|
+
if not length_prefix_bytes:
|
|
72
95
|
raise EOFError("Connection closed")
|
|
73
96
|
|
|
74
|
-
|
|
75
|
-
length_str = length_bytes[:10].decode("ascii")
|
|
76
|
-
length = int(length_str)
|
|
77
|
-
except (ValueError, UnicodeDecodeError) as e:
|
|
78
|
-
raise ValueError(f"Invalid length prefix: {length_bytes!r}") from e
|
|
79
|
-
|
|
80
|
-
if length > MAX_MESSAGE_SIZE:
|
|
81
|
-
raise ValueError(f"Message size {length} exceeds maximum {MAX_MESSAGE_SIZE}")
|
|
82
|
-
|
|
83
|
-
if length == 0:
|
|
84
|
-
raise ValueError("Zero-length message not allowed")
|
|
97
|
+
payload_length = _parse_length_prefix(length_prefix_bytes)
|
|
85
98
|
|
|
86
99
|
# Read exactly that many bytes for the JSON payload
|
|
87
|
-
|
|
100
|
+
json_payload_bytes = await reader.readexactly(payload_length)
|
|
88
101
|
|
|
89
102
|
try:
|
|
90
|
-
message = json.loads(
|
|
91
|
-
except (json.JSONDecodeError, UnicodeDecodeError) as
|
|
92
|
-
raise ValueError("Invalid JSON payload") from
|
|
103
|
+
message = json.loads(json_payload_bytes.decode("utf-8"))
|
|
104
|
+
except (json.JSONDecodeError, UnicodeDecodeError) as error:
|
|
105
|
+
raise ValueError("Invalid JSON payload") from error
|
|
93
106
|
|
|
94
107
|
return message
|
|
95
108
|
|
|
96
109
|
|
|
97
|
-
async def read_messages(reader: asyncio.StreamReader) -> AsyncIterator[
|
|
110
|
+
async def read_messages(reader: asyncio.StreamReader) -> AsyncIterator[dict[str, Any]]:
|
|
98
111
|
"""
|
|
99
112
|
Read a stream of length-prefixed JSON messages.
|
|
100
113
|
|
|
@@ -110,14 +123,14 @@ async def read_messages(reader: asyncio.StreamReader) -> AsyncIterator[Dict[str,
|
|
|
110
123
|
while True:
|
|
111
124
|
message = await read_message(reader)
|
|
112
125
|
yield message
|
|
113
|
-
except EOFError:
|
|
114
|
-
return
|
|
115
126
|
except asyncio.IncompleteReadError:
|
|
116
127
|
return
|
|
128
|
+
except EOFError:
|
|
129
|
+
return
|
|
117
130
|
|
|
118
131
|
|
|
119
132
|
# AnyIO-compatible versions for broker server
|
|
120
|
-
async def write_message_anyio(stream: anyio.abc.ByteStream, message:
|
|
133
|
+
async def write_message_anyio(stream: anyio.abc.ByteStream, message: dict[str, Any]) -> None:
|
|
121
134
|
"""
|
|
122
135
|
Write a JSON message with length prefix using AnyIO streams.
|
|
123
136
|
|
|
@@ -128,19 +141,16 @@ async def write_message_anyio(stream: anyio.abc.ByteStream, message: Dict[str, A
|
|
|
128
141
|
Raises:
|
|
129
142
|
ValueError: If message is too large
|
|
130
143
|
"""
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
if length > MAX_MESSAGE_SIZE:
|
|
135
|
-
raise ValueError(f"Message size {length} exceeds maximum {MAX_MESSAGE_SIZE}")
|
|
144
|
+
json_payload_bytes = _serialize_json_payload(message)
|
|
145
|
+
payload_length = len(json_payload_bytes)
|
|
136
146
|
|
|
137
147
|
# Write 10-digit length prefix + newline
|
|
138
|
-
length_prefix = f"{
|
|
148
|
+
length_prefix = f"{payload_length:010d}\n".encode("ascii")
|
|
139
149
|
await stream.send(length_prefix)
|
|
140
|
-
await stream.send(
|
|
150
|
+
await stream.send(json_payload_bytes)
|
|
141
151
|
|
|
142
152
|
|
|
143
|
-
async def read_message_anyio(stream: BufferedByteReceiveStream) ->
|
|
153
|
+
async def read_message_anyio(stream: BufferedByteReceiveStream) -> dict[str, Any]:
|
|
144
154
|
"""
|
|
145
155
|
Read a JSON message with length prefix using AnyIO streams.
|
|
146
156
|
|
|
@@ -155,29 +165,19 @@ async def read_message_anyio(stream: BufferedByteReceiveStream) -> Dict[str, Any
|
|
|
155
165
|
ValueError: If message is invalid or too large
|
|
156
166
|
"""
|
|
157
167
|
# Read exactly 11 bytes for length prefix
|
|
158
|
-
|
|
168
|
+
length_prefix_bytes = await stream.receive_exactly(LENGTH_PREFIX_SIZE)
|
|
159
169
|
|
|
160
|
-
if not
|
|
170
|
+
if not length_prefix_bytes:
|
|
161
171
|
raise EOFError("Connection closed")
|
|
162
172
|
|
|
163
|
-
|
|
164
|
-
length_str = length_bytes[:10].decode("ascii")
|
|
165
|
-
length = int(length_str)
|
|
166
|
-
except (ValueError, UnicodeDecodeError) as e:
|
|
167
|
-
raise ValueError(f"Invalid length prefix: {length_bytes!r}") from e
|
|
168
|
-
|
|
169
|
-
if length > MAX_MESSAGE_SIZE:
|
|
170
|
-
raise ValueError(f"Message size {length} exceeds maximum {MAX_MESSAGE_SIZE}")
|
|
171
|
-
|
|
172
|
-
if length == 0:
|
|
173
|
-
raise ValueError("Zero-length message not allowed")
|
|
173
|
+
payload_length = _parse_length_prefix(length_prefix_bytes)
|
|
174
174
|
|
|
175
175
|
# Read exactly that many bytes for the JSON payload
|
|
176
|
-
|
|
176
|
+
json_payload_bytes = await stream.receive_exactly(payload_length)
|
|
177
177
|
|
|
178
178
|
try:
|
|
179
|
-
message = json.loads(
|
|
180
|
-
except (json.JSONDecodeError, UnicodeDecodeError) as
|
|
181
|
-
raise ValueError("Invalid JSON payload") from
|
|
179
|
+
message = json.loads(json_payload_bytes.decode("utf-8"))
|
|
180
|
+
except (json.JSONDecodeError, UnicodeDecodeError) as error:
|
|
181
|
+
raise ValueError("Invalid JSON payload") from error
|
|
182
182
|
|
|
183
183
|
return message
|