tactus-0.34.1-py3-none-any.whl → tactus-0.35.0-py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
Files changed (81)
  1. tactus/__init__.py +1 -1
  2. tactus/adapters/broker_log.py +17 -14
  3. tactus/adapters/channels/__init__.py +17 -15
  4. tactus/adapters/channels/base.py +16 -7
  5. tactus/adapters/channels/broker.py +43 -13
  6. tactus/adapters/channels/cli.py +19 -15
  7. tactus/adapters/channels/host.py +15 -6
  8. tactus/adapters/channels/ipc.py +82 -31
  9. tactus/adapters/channels/sse.py +41 -23
  10. tactus/adapters/cli_hitl.py +19 -19
  11. tactus/adapters/cli_log.py +4 -4
  12. tactus/adapters/control_loop.py +138 -99
  13. tactus/adapters/cost_collector_log.py +9 -9
  14. tactus/adapters/file_storage.py +56 -52
  15. tactus/adapters/http_callback_log.py +23 -13
  16. tactus/adapters/ide_log.py +17 -9
  17. tactus/adapters/lua_tools.py +4 -5
  18. tactus/adapters/mcp.py +16 -19
  19. tactus/adapters/mcp_manager.py +46 -30
  20. tactus/adapters/memory.py +9 -9
  21. tactus/adapters/plugins.py +42 -42
  22. tactus/broker/client.py +75 -78
  23. tactus/broker/protocol.py +57 -57
  24. tactus/broker/server.py +252 -197
  25. tactus/cli/app.py +3 -1
  26. tactus/cli/control.py +2 -2
  27. tactus/core/config_manager.py +181 -135
  28. tactus/core/dependencies/registry.py +66 -48
  29. tactus/core/dsl_stubs.py +222 -163
  30. tactus/core/exceptions.py +10 -1
  31. tactus/core/execution_context.py +152 -112
  32. tactus/core/lua_sandbox.py +72 -64
  33. tactus/core/message_history_manager.py +138 -43
  34. tactus/core/mocking.py +41 -27
  35. tactus/core/output_validator.py +49 -44
  36. tactus/core/registry.py +94 -80
  37. tactus/core/runtime.py +211 -176
  38. tactus/core/template_resolver.py +16 -16
  39. tactus/core/yaml_parser.py +55 -45
  40. tactus/docs/extractor.py +7 -6
  41. tactus/ide/server.py +119 -78
  42. tactus/primitives/control.py +10 -6
  43. tactus/primitives/file.py +48 -46
  44. tactus/primitives/handles.py +47 -35
  45. tactus/primitives/host.py +29 -27
  46. tactus/primitives/human.py +154 -137
  47. tactus/primitives/json.py +22 -23
  48. tactus/primitives/log.py +26 -26
  49. tactus/primitives/message_history.py +285 -31
  50. tactus/primitives/model.py +15 -9
  51. tactus/primitives/procedure.py +86 -64
  52. tactus/primitives/procedure_callable.py +58 -51
  53. tactus/primitives/retry.py +31 -29
  54. tactus/primitives/session.py +42 -29
  55. tactus/primitives/state.py +54 -43
  56. tactus/primitives/step.py +9 -13
  57. tactus/primitives/system.py +34 -21
  58. tactus/primitives/tool.py +44 -31
  59. tactus/primitives/tool_handle.py +76 -54
  60. tactus/primitives/toolset.py +25 -22
  61. tactus/sandbox/config.py +4 -4
  62. tactus/sandbox/container_runner.py +161 -107
  63. tactus/sandbox/docker_manager.py +20 -20
  64. tactus/sandbox/entrypoint.py +16 -14
  65. tactus/sandbox/protocol.py +15 -15
  66. tactus/stdlib/classify/llm.py +1 -3
  67. tactus/stdlib/core/validation.py +0 -3
  68. tactus/testing/pydantic_eval_runner.py +1 -1
  69. tactus/utils/asyncio_helpers.py +27 -0
  70. tactus/utils/cost_calculator.py +7 -7
  71. tactus/utils/model_pricing.py +11 -12
  72. tactus/utils/safe_file_library.py +156 -132
  73. tactus/utils/safe_libraries.py +27 -27
  74. tactus/validation/error_listener.py +18 -5
  75. tactus/validation/semantic_visitor.py +392 -333
  76. tactus/validation/validator.py +89 -49
  77. {tactus-0.34.1.dist-info → tactus-0.35.0.dist-info}/METADATA +12 -3
  78. {tactus-0.34.1.dist-info → tactus-0.35.0.dist-info}/RECORD +81 -80
  79. {tactus-0.34.1.dist-info → tactus-0.35.0.dist-info}/WHEEL +0 -0
  80. {tactus-0.34.1.dist-info → tactus-0.35.0.dist-info}/entry_points.txt +0 -0
  81. {tactus-0.34.1.dist-info → tactus-0.35.0.dist-info}/licenses/LICENSE +0 -0
@@ -85,25 +85,25 @@ class PluginLoader:
85
85
  all_tools = []
86
86
 
87
87
  for path_str in paths:
88
- path = Path(path_str).resolve()
88
+ resolved_path = Path(path_str).resolve()
89
89
 
90
- if not path.exists():
91
- logger.warning(f"Tool path does not exist: {path}")
90
+ if not resolved_path.exists():
91
+ logger.warning(f"Tool path does not exist: {resolved_path}")
92
92
  continue
93
93
 
94
- if path.is_file():
94
+ if resolved_path.is_file():
95
95
  # Load tools from single file
96
- if path.suffix == ".py":
97
- tools = self._load_tools_from_file(path)
96
+ if resolved_path.suffix == ".py":
97
+ tools = self._load_tools_from_file(resolved_path)
98
98
  all_tools.extend(tools)
99
99
  else:
100
- logger.warning(f"Skipping non-Python file: {path}")
101
- elif path.is_dir():
100
+ logger.warning(f"Skipping non-Python file: {resolved_path}")
101
+ elif resolved_path.is_dir():
102
102
  # Scan directory for Python files
103
- tools = self._load_tools_from_directory(path)
103
+ tools = self._load_tools_from_directory(resolved_path)
104
104
  all_tools.extend(tools)
105
105
  else:
106
- logger.warning(f"Path is neither file nor directory: {path}")
106
+ logger.warning(f"Path is neither file nor directory: {resolved_path}")
107
107
 
108
108
  logger.info(f"Loaded {len(all_tools)} tools from {len(paths)} path(s)")
109
109
  return all_tools
@@ -121,23 +121,23 @@ class PluginLoader:
121
121
  all_functions = []
122
122
 
123
123
  for path_str in paths:
124
- path = Path(path_str).resolve()
124
+ resolved_path = Path(path_str).resolve()
125
125
 
126
- if not path.exists():
127
- logger.warning(f"Tool path does not exist: {path}")
126
+ if not resolved_path.exists():
127
+ logger.warning(f"Tool path does not exist: {resolved_path}")
128
128
  continue
129
129
 
130
- if path.is_file():
131
- if path.suffix == ".py":
132
- functions = self._load_functions_from_file(path)
130
+ if resolved_path.is_file():
131
+ if resolved_path.suffix == ".py":
132
+ functions = self._load_functions_from_file(resolved_path)
133
133
  all_functions.extend(functions)
134
134
  else:
135
- logger.warning(f"Skipping non-Python file: {path}")
136
- elif path.is_dir():
137
- functions = self._load_functions_from_directory(path)
135
+ logger.warning(f"Skipping non-Python file: {resolved_path}")
136
+ elif resolved_path.is_dir():
137
+ functions = self._load_functions_from_directory(resolved_path)
138
138
  all_functions.extend(functions)
139
139
  else:
140
- logger.warning(f"Path is neither file nor directory: {path}")
140
+ logger.warning(f"Path is neither file nor directory: {resolved_path}")
141
141
 
142
142
  logger.debug(f"Loaded {len(all_functions)} function(s) from {len(paths)} path(s)")
143
143
  return all_functions
@@ -200,8 +200,8 @@ class PluginLoader:
200
200
  functions.append(obj)
201
201
  logger.debug(f"Found function '{name}' in {file_path.name}")
202
202
 
203
- except Exception as e:
204
- logger.error(f"Failed to load functions from {file_path}: {e}", exc_info=True)
203
+ except Exception as error:
204
+ logger.error(f"Failed to load functions from {file_path}: {error}", exc_info=True)
205
205
 
206
206
  return functions
207
207
 
@@ -216,14 +216,14 @@ class PluginLoader:
216
216
  Async callback function for process_tool_call
217
217
  """
218
218
 
219
- async def trace_tool_call(ctx, next_call, tool_name, tool_args):
219
+ async def trace_tool_call(execution_context, invoke_next, tool_name, tool_args):
220
220
  """Middleware to record tool calls in Tactus ToolPrimitive."""
221
221
  logger.debug(
222
222
  f"Toolset '{toolset_name}' calling tool '{tool_name}' with args: {tool_args}"
223
223
  )
224
224
 
225
225
  try:
226
- result = await next_call(tool_name, tool_args)
226
+ result = await invoke_next(tool_name, tool_args)
227
227
 
228
228
  # Record in ToolPrimitive if available
229
229
  if self.tool_primitive:
@@ -232,11 +232,11 @@ class PluginLoader:
232
232
 
233
233
  logger.debug(f"Tool '{tool_name}' completed successfully")
234
234
  return result
235
- except Exception as e:
236
- logger.error(f"Tool '{tool_name}' failed: {e}", exc_info=True)
235
+ except Exception as error:
236
+ logger.error(f"Tool '{tool_name}' failed: {error}", exc_info=True)
237
237
  # Still record the failed call
238
238
  if self.tool_primitive:
239
- error_msg = f"Error: {str(e)}"
239
+ error_msg = f"Error: {str(error)}"
240
240
  self.tool_primitive.record_call(tool_name, tool_args, error_msg)
241
241
  raise
242
242
 
@@ -308,8 +308,8 @@ class PluginLoader:
308
308
  tools.append(tool)
309
309
  logger.info(f"Loaded tool '{name}' from {file_path.name}")
310
310
 
311
- except Exception as e:
312
- logger.error(f"Failed to load tools from {file_path}: {e}", exc_info=True)
311
+ except Exception as error:
312
+ logger.error(f"Failed to load tools from {file_path}: {error}", exc_info=True)
313
313
 
314
314
  return tools
315
315
 
@@ -352,8 +352,8 @@ class PluginLoader:
352
352
  """
353
353
  try:
354
354
  # Get function signature and docstring
355
- sig = inspect.signature(func)
356
- doc = inspect.getdoc(func) or f"Tool: {name}"
355
+ signature = inspect.signature(func)
356
+ docstring = inspect.getdoc(func) or f"Tool: {name}"
357
357
 
358
358
  # Check if function is async
359
359
  is_async = inspect.iscoroutinefunction(func)
@@ -371,9 +371,9 @@ class PluginLoader:
371
371
  self.tool_primitive.record_call(name, kwargs, str(result))
372
372
 
373
373
  return result
374
- except Exception as e:
375
- logger.error(f"Tool '{name}' execution failed: {e}", exc_info=True)
376
- error_msg = f"Error executing tool '{name}': {str(e)}"
374
+ except Exception as error:
375
+ logger.error(f"Tool '{name}' execution failed: {error}", exc_info=True)
376
+ error_msg = f"Error executing tool '{name}': {str(error)}"
377
377
 
378
378
  # Record failed call
379
379
  if self.tool_primitive:
@@ -393,9 +393,9 @@ class PluginLoader:
393
393
  self.tool_primitive.record_call(name, kwargs, str(result))
394
394
 
395
395
  return result
396
- except Exception as e:
397
- logger.error(f"Tool '{name}' execution failed: {e}", exc_info=True)
398
- error_msg = f"Error executing tool '{name}': {str(e)}"
396
+ except Exception as error:
397
+ logger.error(f"Tool '{name}' execution failed: {error}", exc_info=True)
398
+ error_msg = f"Error executing tool '{name}': {str(error)}"
399
399
 
400
400
  # Record failed call
401
401
  if self.tool_primitive:
@@ -404,16 +404,16 @@ class PluginLoader:
404
404
  raise
405
405
 
406
406
  # Copy signature and docstring to wrapper
407
- tool_wrapper.__signature__ = sig
408
- tool_wrapper.__doc__ = doc
407
+ tool_wrapper.__signature__ = signature
408
+ tool_wrapper.__doc__ = docstring
409
409
  tool_wrapper.__name__ = name
410
410
  tool_wrapper.__annotations__ = func.__annotations__
411
411
 
412
412
  # Create Pydantic AI Tool
413
- tool = Tool(tool_wrapper, name=name, description=doc)
413
+ tool = Tool(tool_wrapper, name=name, description=docstring)
414
414
 
415
415
  return tool
416
416
 
417
- except Exception as e:
418
- logger.error(f"Failed to create tool from function '{name}': {e}", exc_info=True)
417
+ except Exception as error:
418
+ logger.error(f"Failed to create tool from function '{name}': {error}", exc_info=True)
419
419
  return None
tactus/broker/client.py CHANGED
@@ -31,12 +31,16 @@ def _json_dumps(obj: Any) -> str:
31
31
  class _StdioBrokerTransport:
32
32
  def __init__(self):
33
33
  self._write_lock = threading.Lock()
34
- self._pending: dict[
34
+ self._pending_requests: dict[
35
35
  str, tuple[asyncio.AbstractEventLoop, asyncio.Queue[dict[str, Any]]]
36
36
  ] = {}
37
+ # Backward-compatible alias used in tests and older code paths.
38
+ self._pending = self._pending_requests
37
39
  self._pending_lock = threading.Lock()
38
40
  self._reader_thread: Optional[threading.Thread] = None
39
- self._stop = threading.Event()
41
+ self._shutdown_event = threading.Event()
42
+ # Backward-compatible alias used in tests and older code paths.
43
+ self._stop = self._shutdown_event
40
44
 
41
45
  def _ensure_reader_thread(self) -> None:
42
46
  if self._reader_thread is not None and self._reader_thread.is_alive():
@@ -50,33 +54,33 @@ class _StdioBrokerTransport:
50
54
  self._reader_thread.start()
51
55
 
52
56
  def _read_loop(self) -> None:
53
- while not self._stop.is_set():
54
- line = sys.stdin.buffer.readline()
55
- if not line:
57
+ while not self._shutdown_event.is_set():
58
+ input_line = sys.stdin.buffer.readline()
59
+ if not input_line:
56
60
  return
57
61
  try:
58
- event = json.loads(line.decode("utf-8"))
62
+ event_payload = json.loads(input_line.decode("utf-8"))
59
63
  except json.JSONDecodeError:
60
64
  continue
61
65
 
62
- req_id = event.get("id")
63
- if not isinstance(req_id, str):
66
+ request_id_value = event_payload.get("id")
67
+ if not isinstance(request_id_value, str):
64
68
  continue
65
69
 
66
70
  with self._pending_lock:
67
- pending = self._pending.get(req_id)
68
- if pending is None:
71
+ pending_request = self._pending_requests.get(request_id_value)
72
+ if pending_request is None:
69
73
  continue
70
74
 
71
- loop, queue = pending
75
+ event_loop, response_queue = pending_request
72
76
  try:
73
- loop.call_soon_threadsafe(queue.put_nowait, event)
77
+ event_loop.call_soon_threadsafe(response_queue.put_nowait, event_payload)
74
78
  except RuntimeError:
75
79
  # Loop is closed or unavailable; ignore.
76
80
  continue
77
81
 
78
82
  async def aclose(self) -> None:
79
- self._stop.set()
83
+ self._shutdown_event.set()
80
84
  thread = self._reader_thread
81
85
  if thread is None or not thread.is_alive():
82
86
  return
@@ -86,28 +90,28 @@ class _StdioBrokerTransport:
86
90
  return
87
91
 
88
92
  async def request(
89
- self, req_id: str, method: str, params: dict[str, Any]
93
+ self, request_id: str, method: str, params: dict[str, Any]
90
94
  ) -> AsyncIterator[dict[str, Any]]:
91
95
  self._ensure_reader_thread()
92
- loop = asyncio.get_running_loop()
93
- queue: asyncio.Queue[dict[str, Any]] = asyncio.Queue()
96
+ event_loop = asyncio.get_running_loop()
97
+ response_queue: asyncio.Queue[dict[str, Any]] = asyncio.Queue()
94
98
  with self._pending_lock:
95
- self._pending[req_id] = (loop, queue)
99
+ self._pending_requests[request_id] = (event_loop, response_queue)
96
100
 
97
101
  try:
98
- payload = _json_dumps({"id": req_id, "method": method, "params": params})
102
+ request_payload = _json_dumps({"id": request_id, "method": method, "params": params})
99
103
  with self._write_lock:
100
- sys.stderr.write(f"{STDIO_REQUEST_PREFIX}{payload}\n")
104
+ sys.stderr.write(f"{STDIO_REQUEST_PREFIX}{request_payload}\n")
101
105
  sys.stderr.flush()
102
106
 
103
107
  while True:
104
- event = await queue.get()
105
- yield event
106
- if event.get("event") in ("done", "error"):
108
+ event_payload = await response_queue.get()
109
+ yield event_payload
110
+ if event_payload.get("event") in ("done", "error"):
107
111
  return
108
112
  finally:
109
113
  with self._pending_lock:
110
- self._pending.pop(req_id, None)
114
+ self._pending_requests.pop(request_id, None)
111
115
 
112
116
 
113
117
  _STDIO_TRANSPORT = _StdioBrokerTransport()
@@ -129,59 +133,61 @@ class BrokerClient:
129
133
  return cls(socket_path)
130
134
 
131
135
  async def _request(self, method: str, params: dict[str, Any]) -> AsyncIterator[dict[str, Any]]:
132
- req_id = uuid.uuid4().hex
136
+ request_id = uuid.uuid4().hex
133
137
 
134
138
  if self.socket_path == STDIO_TRANSPORT_VALUE:
135
- async for event in _STDIO_TRANSPORT.request(req_id, method, params):
139
+ async for event_payload in _STDIO_TRANSPORT.request(request_id, method, params):
136
140
  # Responses are already correlated by req_id; add a defensive filter anyway.
137
- if event.get("id") == req_id:
138
- yield event
141
+ if event_payload.get("id") == request_id:
142
+ yield event_payload
139
143
  return
140
144
 
141
145
  if self.socket_path.startswith(("tcp://", "tls://")):
142
146
  use_tls = self.socket_path.startswith("tls://")
143
- host_port = self.socket_path.split("://", 1)[1]
144
- if "/" in host_port:
145
- host_port = host_port.split("/", 1)[0]
146
- if ":" not in host_port:
147
+ host_and_port = self.socket_path.split("://", 1)[1]
148
+ if "/" in host_and_port:
149
+ host_and_port = host_and_port.split("/", 1)[0]
150
+ if ":" not in host_and_port:
147
151
  raise ValueError(
148
- f"Invalid broker endpoint: {self.socket_path}. Expected tcp://host:port or tls://host:port"
152
+ "Invalid broker endpoint: "
153
+ f"{self.socket_path}. Expected tcp://host:port or tls://host:port"
149
154
  )
150
- host, port_str = host_port.rsplit(":", 1)
155
+ host, port_text = host_and_port.rsplit(":", 1)
151
156
  try:
152
- port = int(port_str)
153
- except ValueError as e:
154
- raise ValueError(f"Invalid broker port in endpoint: {self.socket_path}") from e
157
+ port = int(port_text)
158
+ except ValueError as error:
159
+ raise ValueError(f"Invalid broker port in endpoint: {self.socket_path}") from error
155
160
 
156
- ssl_ctx: ssl.SSLContext | None = None
161
+ ssl_context: ssl.SSLContext | None = None
157
162
  if use_tls:
158
- ssl_ctx = ssl.create_default_context()
163
+ ssl_context = ssl.create_default_context()
159
164
  cafile = os.environ.get("TACTUS_BROKER_TLS_CA_FILE")
160
165
  if cafile:
161
- ssl_ctx.load_verify_locations(cafile=cafile)
166
+ ssl_context.load_verify_locations(cafile=cafile)
162
167
 
163
168
  if os.environ.get("TACTUS_BROKER_TLS_INSECURE") in ("1", "true", "yes"):
164
- ssl_ctx.check_hostname = False
165
- ssl_ctx.verify_mode = ssl.CERT_NONE
169
+ ssl_context.check_hostname = False
170
+ ssl_context.verify_mode = ssl.CERT_NONE
166
171
 
167
- reader, writer = await asyncio.open_connection(host, port, ssl=ssl_ctx)
172
+ reader, writer = await asyncio.open_connection(host, port, ssl=ssl_context)
168
173
  logger.info(
169
- f"[BROKER_CLIENT] Writing message to broker, params keys: {list(params.keys())}"
174
+ "[BROKER_CLIENT] Writing message to broker, params keys: %s",
175
+ list(params.keys()),
170
176
  )
171
177
  try:
172
- await write_message(writer, {"id": req_id, "method": method, "params": params})
173
- except TypeError as e:
174
- logger.error(f"[BROKER_CLIENT] JSON serialization error: {e}")
175
- logger.error(f"[BROKER_CLIENT] Params: {params}")
178
+ await write_message(writer, {"id": request_id, "method": method, "params": params})
179
+ except TypeError as error:
180
+ logger.error("[BROKER_CLIENT] JSON serialization error: %s", error)
181
+ logger.error("[BROKER_CLIENT] Params: %s", params)
176
182
  raise
177
183
 
178
184
  try:
179
185
  while True:
180
- event = await read_message(reader)
181
- if event.get("id") != req_id:
186
+ event_payload = await read_message(reader)
187
+ if event_payload.get("id") != request_id:
182
188
  continue
183
- yield event
184
- if event.get("event") in ("done", "error"):
189
+ yield event_payload
190
+ if event_payload.get("event") in ("done", "error"):
185
191
  return
186
192
  finally:
187
193
  try:
@@ -191,16 +197,16 @@ class BrokerClient:
191
197
  pass
192
198
 
193
199
  reader, writer = await asyncio.open_unix_connection(self.socket_path)
194
- await write_message(writer, {"id": req_id, "method": method, "params": params})
200
+ await write_message(writer, {"id": request_id, "method": method, "params": params})
195
201
 
196
202
  try:
197
203
  while True:
198
- event = await read_message(reader)
204
+ event_payload = await read_message(reader)
199
205
  # Ignore unrelated messages (defensive; current server is 1-req/conn).
200
- if event.get("id") != req_id:
206
+ if event_payload.get("id") != request_id:
201
207
  continue
202
- yield event
203
- if event.get("event") in ("done", "error"):
208
+ yield event_payload
209
+ if event_payload.get("event") in ("done", "error"):
204
210
  return
205
211
  finally:
206
212
  try:
@@ -221,34 +227,25 @@ class BrokerClient:
221
227
  tools: Optional[list[dict[str, Any]]] = None,
222
228
  tool_choice: Optional[str] = None,
223
229
  ) -> AsyncIterator[dict[str, Any]]:
224
- params: dict[str, Any] = {
230
+ request_params: dict[str, Any] = {
225
231
  "provider": provider,
226
232
  "model": model,
227
233
  "messages": messages,
228
234
  "stream": stream,
229
235
  }
230
236
  if temperature is not None:
231
- params["temperature"] = temperature
237
+ request_params["temperature"] = temperature
232
238
  if max_tokens is not None:
233
- params["max_tokens"] = max_tokens
239
+ request_params["max_tokens"] = max_tokens
234
240
  if tools is not None:
235
- params["tools"] = tools
236
- import logging
237
-
238
- logger = logging.getLogger(__name__)
239
- logger.info(f"[BROKER_CLIENT] Adding {len(tools)} tools to params")
241
+ request_params["tools"] = tools
242
+ logger.info("[BROKER_CLIENT] Adding %s tools to params", len(tools))
240
243
  else:
241
- import logging
242
-
243
- logger = logging.getLogger(__name__)
244
244
  logger.warning("[BROKER_CLIENT] No tools to add to params")
245
245
  if tool_choice is not None:
246
- params["tool_choice"] = tool_choice
247
- import logging
248
-
249
- logger = logging.getLogger(__name__)
250
- logger.info(f"[BROKER_CLIENT] Adding tool_choice={tool_choice} to params")
251
- return self._request("llm.chat", params)
246
+ request_params["tool_choice"] = tool_choice
247
+ logger.info("[BROKER_CLIENT] Adding tool_choice=%s to params", tool_choice)
248
+ return self._request("llm.chat", request_params)
252
249
 
253
250
  async def call_tool(self, *, name: str, args: dict[str, Any]) -> Any:
254
251
  """
@@ -261,14 +258,14 @@ class BrokerClient:
261
258
  if not isinstance(args, dict):
262
259
  raise ValueError("tool args must be an object")
263
260
 
264
- async for event in self._request("tool.call", {"name": name, "args": args}):
265
- event_type = event.get("event")
261
+ async for event_payload in self._request("tool.call", {"name": name, "args": args}):
262
+ event_type = event_payload.get("event")
266
263
  if event_type == "done":
267
- data = event.get("data") or {}
264
+ data = event_payload.get("data") or {}
268
265
  return data.get("result")
269
266
  if event_type == "error":
270
- err = event.get("error") or {}
271
- raise RuntimeError(err.get("message") or "Broker tool error")
267
+ error_payload = event_payload.get("error") or {}
268
+ raise RuntimeError(error_payload.get("message") or "Broker tool error")
272
269
 
273
270
  raise RuntimeError("Broker tool call ended without a response")
274
271
 
tactus/broker/protocol.py CHANGED
@@ -12,10 +12,10 @@ Example:
12
12
  {"id":"abc","method":"llm.chat","params":{...}}
13
13
  """
14
14
 
15
- import json
16
15
  import asyncio
17
16
  import logging
18
- from typing import Any, Dict, AsyncIterator
17
+ import json
18
+ from typing import Any, AsyncIterator
19
19
 
20
20
  import anyio
21
21
  from anyio.streams.buffered import BufferedByteReceiveStream
@@ -27,7 +27,33 @@ LENGTH_PREFIX_SIZE = 11 # "0000000123\n"
27
27
  MAX_MESSAGE_SIZE = 100 * 1024 * 1024 # 100MB safety limit
28
28
 
29
29
 
30
- async def write_message(writer: asyncio.StreamWriter, message: Dict[str, Any]) -> None:
30
+ def _parse_length_prefix(length_prefix_bytes: bytes) -> int:
31
+ try:
32
+ length_text = length_prefix_bytes[:10].decode("ascii")
33
+ payload_length = int(length_text)
34
+ except (ValueError, UnicodeDecodeError) as error:
35
+ raise ValueError(f"Invalid length prefix: {length_prefix_bytes!r}") from error
36
+
37
+ if payload_length > MAX_MESSAGE_SIZE:
38
+ raise ValueError(f"Message size {payload_length} exceeds maximum {MAX_MESSAGE_SIZE}")
39
+
40
+ if payload_length == 0:
41
+ raise ValueError("Zero-length message not allowed")
42
+
43
+ return payload_length
44
+
45
+
46
+ def _serialize_json_payload(message: dict[str, Any]) -> bytes:
47
+ json_payload_bytes = json.dumps(message).encode("utf-8")
48
+ payload_length = len(json_payload_bytes)
49
+
50
+ if payload_length > MAX_MESSAGE_SIZE:
51
+ raise ValueError(f"Message size {payload_length} exceeds maximum {MAX_MESSAGE_SIZE}")
52
+
53
+ return json_payload_bytes
54
+
55
+
56
+ async def write_message(writer: asyncio.StreamWriter, message: dict[str, Any]) -> None:
31
57
  """
32
58
  Write a JSON message with length prefix.
33
59
 
@@ -38,20 +64,17 @@ async def write_message(writer: asyncio.StreamWriter, message: Dict[str, Any]) -
38
64
  Raises:
39
65
  ValueError: If message is too large
40
66
  """
41
- json_bytes = json.dumps(message).encode("utf-8")
42
- length = len(json_bytes)
43
-
44
- if length > MAX_MESSAGE_SIZE:
45
- raise ValueError(f"Message size {length} exceeds maximum {MAX_MESSAGE_SIZE}")
67
+ json_payload_bytes = _serialize_json_payload(message)
68
+ payload_length = len(json_payload_bytes)
46
69
 
47
70
  # Write 10-digit length prefix + newline
48
- length_prefix = f"{length:010d}\n".encode("ascii")
71
+ length_prefix = f"{payload_length:010d}\n".encode("ascii")
49
72
  writer.write(length_prefix)
50
- writer.write(json_bytes)
73
+ writer.write(json_payload_bytes)
51
74
  await writer.drain()
52
75
 
53
76
 
54
- async def read_message(reader: asyncio.StreamReader) -> Dict[str, Any]:
77
+ async def read_message(reader: asyncio.StreamReader) -> dict[str, Any]:
55
78
  """
56
79
  Read a JSON message with length prefix.
57
80
 
@@ -66,35 +89,25 @@ async def read_message(reader: asyncio.StreamReader) -> Dict[str, Any]:
66
89
  ValueError: If message is invalid or too large
67
90
  """
68
91
  # Read exactly 11 bytes for length prefix
69
- length_bytes = await reader.readexactly(LENGTH_PREFIX_SIZE)
92
+ length_prefix_bytes = await reader.readexactly(LENGTH_PREFIX_SIZE)
70
93
 
71
- if not length_bytes:
94
+ if not length_prefix_bytes:
72
95
  raise EOFError("Connection closed")
73
96
 
74
- try:
75
- length_str = length_bytes[:10].decode("ascii")
76
- length = int(length_str)
77
- except (ValueError, UnicodeDecodeError) as e:
78
- raise ValueError(f"Invalid length prefix: {length_bytes!r}") from e
79
-
80
- if length > MAX_MESSAGE_SIZE:
81
- raise ValueError(f"Message size {length} exceeds maximum {MAX_MESSAGE_SIZE}")
82
-
83
- if length == 0:
84
- raise ValueError("Zero-length message not allowed")
97
+ payload_length = _parse_length_prefix(length_prefix_bytes)
85
98
 
86
99
  # Read exactly that many bytes for the JSON payload
87
- json_bytes = await reader.readexactly(length)
100
+ json_payload_bytes = await reader.readexactly(payload_length)
88
101
 
89
102
  try:
90
- message = json.loads(json_bytes.decode("utf-8"))
91
- except (json.JSONDecodeError, UnicodeDecodeError) as e:
92
- raise ValueError("Invalid JSON payload") from e
103
+ message = json.loads(json_payload_bytes.decode("utf-8"))
104
+ except (json.JSONDecodeError, UnicodeDecodeError) as error:
105
+ raise ValueError("Invalid JSON payload") from error
93
106
 
94
107
  return message
95
108
 
96
109
 
97
- async def read_messages(reader: asyncio.StreamReader) -> AsyncIterator[Dict[str, Any]]:
110
+ async def read_messages(reader: asyncio.StreamReader) -> AsyncIterator[dict[str, Any]]:
98
111
  """
99
112
  Read a stream of length-prefixed JSON messages.
100
113
 
@@ -110,14 +123,14 @@ async def read_messages(reader: asyncio.StreamReader) -> AsyncIterator[Dict[str,
110
123
  while True:
111
124
  message = await read_message(reader)
112
125
  yield message
113
- except EOFError:
114
- return
115
126
  except asyncio.IncompleteReadError:
116
127
  return
128
+ except EOFError:
129
+ return
117
130
 
118
131
 
119
132
  # AnyIO-compatible versions for broker server
120
- async def write_message_anyio(stream: anyio.abc.ByteStream, message: Dict[str, Any]) -> None:
133
+ async def write_message_anyio(stream: anyio.abc.ByteStream, message: dict[str, Any]) -> None:
121
134
  """
122
135
  Write a JSON message with length prefix using AnyIO streams.
123
136
 
@@ -128,19 +141,16 @@ async def write_message_anyio(stream: anyio.abc.ByteStream, message: Dict[str, A
128
141
  Raises:
129
142
  ValueError: If message is too large
130
143
  """
131
- json_bytes = json.dumps(message).encode("utf-8")
132
- length = len(json_bytes)
133
-
134
- if length > MAX_MESSAGE_SIZE:
135
- raise ValueError(f"Message size {length} exceeds maximum {MAX_MESSAGE_SIZE}")
144
+ json_payload_bytes = _serialize_json_payload(message)
145
+ payload_length = len(json_payload_bytes)
136
146
 
137
147
  # Write 10-digit length prefix + newline
138
- length_prefix = f"{length:010d}\n".encode("ascii")
148
+ length_prefix = f"{payload_length:010d}\n".encode("ascii")
139
149
  await stream.send(length_prefix)
140
- await stream.send(json_bytes)
150
+ await stream.send(json_payload_bytes)
141
151
 
142
152
 
143
- async def read_message_anyio(stream: BufferedByteReceiveStream) -> Dict[str, Any]:
153
+ async def read_message_anyio(stream: BufferedByteReceiveStream) -> dict[str, Any]:
144
154
  """
145
155
  Read a JSON message with length prefix using AnyIO streams.
146
156
 
@@ -155,29 +165,19 @@ async def read_message_anyio(stream: BufferedByteReceiveStream) -> Dict[str, Any
155
165
  ValueError: If message is invalid or too large
156
166
  """
157
167
  # Read exactly 11 bytes for length prefix
158
- length_bytes = await stream.receive_exactly(LENGTH_PREFIX_SIZE)
168
+ length_prefix_bytes = await stream.receive_exactly(LENGTH_PREFIX_SIZE)
159
169
 
160
- if not length_bytes:
170
+ if not length_prefix_bytes:
161
171
  raise EOFError("Connection closed")
162
172
 
163
- try:
164
- length_str = length_bytes[:10].decode("ascii")
165
- length = int(length_str)
166
- except (ValueError, UnicodeDecodeError) as e:
167
- raise ValueError(f"Invalid length prefix: {length_bytes!r}") from e
168
-
169
- if length > MAX_MESSAGE_SIZE:
170
- raise ValueError(f"Message size {length} exceeds maximum {MAX_MESSAGE_SIZE}")
171
-
172
- if length == 0:
173
- raise ValueError("Zero-length message not allowed")
173
+ payload_length = _parse_length_prefix(length_prefix_bytes)
174
174
 
175
175
  # Read exactly that many bytes for the JSON payload
176
- json_bytes = await stream.receive_exactly(length)
176
+ json_payload_bytes = await stream.receive_exactly(payload_length)
177
177
 
178
178
  try:
179
- message = json.loads(json_bytes.decode("utf-8"))
180
- except (json.JSONDecodeError, UnicodeDecodeError) as e:
181
- raise ValueError("Invalid JSON payload") from e
179
+ message = json.loads(json_payload_bytes.decode("utf-8"))
180
+ except (json.JSONDecodeError, UnicodeDecodeError) as error:
181
+ raise ValueError("Invalid JSON payload") from error
182
182
 
183
183
  return message