tactus-0.36.0-py3-none-any.whl → tactus-0.37.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tactus/__init__.py +1 -1
- tactus/adapters/channels/base.py +20 -2
- tactus/adapters/channels/broker.py +1 -0
- tactus/adapters/channels/host.py +3 -1
- tactus/adapters/channels/ipc.py +18 -3
- tactus/adapters/channels/sse.py +2 -0
- tactus/adapters/mcp_manager.py +24 -7
- tactus/backends/http_backend.py +2 -2
- tactus/backends/pytorch_backend.py +2 -2
- tactus/broker/client.py +3 -3
- tactus/broker/server.py +17 -5
- tactus/core/dsl_stubs.py +3 -3
- tactus/core/execution_context.py +23 -23
- tactus/core/message_history_manager.py +2 -2
- tactus/core/output_validator.py +6 -6
- tactus/core/registry.py +29 -29
- tactus/core/runtime.py +44 -44
- tactus/dspy/broker_lm.py +13 -7
- tactus/dspy/config.py +7 -4
- tactus/primitives/model.py +1 -1
- tactus/primitives/procedure.py +1 -1
- tactus/primitives/state.py +2 -2
- tactus/sandbox/container_runner.py +11 -6
- tactus/testing/context.py +6 -6
- tactus/testing/evaluation_runner.py +5 -5
- tactus/testing/steps/builtin.py +2 -2
- tactus/testing/test_runner.py +6 -4
- tactus/utils/asyncio_helpers.py +2 -1
- {tactus-0.36.0.dist-info → tactus-0.37.0.dist-info}/METADATA +7 -5
- {tactus-0.36.0.dist-info → tactus-0.37.0.dist-info}/RECORD +33 -33
- {tactus-0.36.0.dist-info → tactus-0.37.0.dist-info}/WHEEL +0 -0
- {tactus-0.36.0.dist-info → tactus-0.37.0.dist-info}/entry_points.txt +0 -0
- {tactus-0.36.0.dist-info → tactus-0.37.0.dist-info}/licenses/LICENSE +0 -0
tactus/__init__.py
CHANGED
tactus/adapters/channels/base.py
CHANGED
@@ -9,7 +9,7 @@ requiring separate processes (e.g., Discord WebSocket gateway).
 import asyncio
 import logging
 from abc import ABC, abstractmethod
-from typing import AsyncIterator
+from typing import AsyncIterator, Optional

 from tactus.protocols.control import (
     ControlRequest,
@@ -54,7 +54,20 @@ class InProcessChannel(ABC):

     def __init__(self):
         """Initialize the channel with an internal response queue."""
-        self._response_queue: asyncio.Queue[ControlResponse] =
+        self._response_queue: Optional[asyncio.Queue[ControlResponse]] = None
+        self._shutdown_event: Optional[asyncio.Event] = None
+
+    def _ensure_asyncio_primitives(self) -> None:
+        if self._response_queue is not None and self._shutdown_event is not None:
+            return
+        try:
+            asyncio.get_running_loop()
+        except RuntimeError as error:
+            raise RuntimeError(
+                "InProcessChannel requires a running event loop before use. "
+                "Initialize it from within an async context."
+            ) from error
+        self._response_queue = asyncio.Queue()
         self._shutdown_event = asyncio.Event()

     @property
@@ -105,6 +118,7 @@ class InProcessChannel(ABC):
         Yields:
             ControlResponse as they are received
         """
+        self._ensure_asyncio_primitives()
         while not self._shutdown_event.is_set():
             try:
                 # Use wait_for with timeout to check shutdown periodically
@@ -149,6 +163,7 @@ class InProcessChannel(ABC):
         Override for additional cleanup (close connections, etc.).
         """
         logger.info("%s: shutting down", self.channel_id)
+        self._ensure_asyncio_primitives()
         self._shutdown_event.set()

     def push_response(self, response: ControlResponse) -> None:
@@ -164,6 +179,7 @@ class InProcessChannel(ABC):
             response: ControlResponse to add to queue
         """
         try:
+            self._ensure_asyncio_primitives()
             self._response_queue.put_nowait(response)
         except Exception as error:
             logger.error("%s: failed to queue response: %s", self.channel_id, error)
@@ -180,4 +196,6 @@ class InProcessChannel(ABC):
             response: ControlResponse to add to queue
             loop: The event loop to use for thread-safe call
         """
+        if self._response_queue is None:
+            loop.call_soon_threadsafe(self._ensure_asyncio_primitives)
         loop.call_soon_threadsafe(self._response_queue.put_nowait, response)
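Note: the base.py change above defers creation of the response queue and shutdown event to a `_ensure_asyncio_primitives()` helper that requires a running event loop. A minimal, self-contained sketch of that lazy-initialization pattern (illustrative names, not Tactus code):

```python
# Sketch of the lazy asyncio-primitive pattern adopted by InProcessChannel:
# nothing loop-dependent is built in __init__; the queue appears on first use
# inside async code, with a clear error if no loop is running.
import asyncio
from typing import Optional


class LazyChannel:
    def __init__(self) -> None:
        # Construction does not require an event loop; first use does.
        self._queue: Optional[asyncio.Queue] = None

    def _ensure_queue(self) -> asyncio.Queue:
        if self._queue is None:
            try:
                asyncio.get_running_loop()
            except RuntimeError as error:
                raise RuntimeError("use this channel from async code") from error
            self._queue = asyncio.Queue()
        return self._queue

    async def push(self, item: str) -> None:
        await self._ensure_queue().put(item)

    async def pop(self) -> str:
        return await self._ensure_queue().get()


async def main() -> None:
    channel = LazyChannel()
    await channel.push("hello")
    print(await channel.pop())


if __name__ == "__main__":
    asyncio.run(main())
```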
tactus/adapters/channels/host.py
CHANGED
@@ -90,12 +90,14 @@ class HostControlChannel(InProcessChannel):
             request.request_id,
         )

+        self._ensure_asyncio_primitives()
+
         # Store for background thread access
         self._current_request = request
         self._cancel_event.clear()

         # Capture event loop for thread-safe response pushing
-        self._event_loop = asyncio.
+        self._event_loop = asyncio.get_running_loop()

         # Display the request (synchronous, before starting thread)
         self._display_request(request)
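Note: host.py now captures the loop with `asyncio.get_running_loop()` before handing work to a background thread. A small sketch (hypothetical names) of why the captured loop matters when a worker thread must push results back into async code:

```python
# Sketch: capture the running loop inside async code, then let a plain thread
# schedule queue writes on it via call_soon_threadsafe.
import asyncio
import threading


def worker(loop: asyncio.AbstractEventLoop, queue: asyncio.Queue) -> None:
    # No event loop runs in this thread, so asyncio objects are touched only
    # through the captured loop.
    loop.call_soon_threadsafe(queue.put_nowait, "result-from-thread")


async def main() -> None:
    loop = asyncio.get_running_loop()  # only valid inside a running loop
    queue: asyncio.Queue = asyncio.Queue()
    threading.Thread(target=worker, args=(loop, queue)).start()
    print(await queue.get())


asyncio.run(main())
```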
tactus/adapters/channels/ipc.py
CHANGED
@@ -46,10 +46,22 @@ class IPCControlChannel:

         self._server: Optional[asyncio.Server] = None
         self._clients: dict[str, asyncio.StreamWriter] = {}  # client_id -> writer
-        self._response_queue: asyncio.Queue[ControlResponse] =
+        self._response_queue: Optional[asyncio.Queue[ControlResponse]] = None
         self._pending_requests: dict[str, ControlRequest] = {}  # request_id -> request
         self._initialized = False

+    def _ensure_response_queue(self) -> asyncio.Queue[ControlResponse]:
+        if self._response_queue is None:
+            try:
+                asyncio.get_running_loop()
+            except RuntimeError as error:
+                raise RuntimeError(
+                    "IPCControlChannel requires a running event loop before use. "
+                    "Initialize it from within an async context."
+                ) from error
+            self._response_queue = asyncio.Queue()
+        return self._response_queue
+
     @property
     def capabilities(self) -> ChannelCapabilities:
         """IPC supports all request types and can respond synchronously."""
@@ -68,6 +80,7 @@ class IPCControlChannel:
             return

         logger.info("%s: initializing...", self.channel_id)
+        self._ensure_response_queue()

         # Remove old socket file if it exists
         if os.path.exists(self.socket_path):
@@ -168,8 +181,9 @@ class IPCControlChannel:
         Yields:
             ControlResponse objects
         """
+        response_queue = self._ensure_response_queue()
         while True:
-            response = await
+            response = await response_queue.get()
             logger.info(
                 "%s: received response for %s",
                 self.channel_id,
@@ -264,6 +278,7 @@ class IPCControlChannel:
         self._clients[client_id] = writer

         try:
+            response_queue = self._ensure_response_queue()
             # Send any pending requests to the new client
             for request_id, request_data in self._pending_requests.items():
                 try:
@@ -302,7 +317,7 @@ class IPCControlChannel:
                     timed_out=message.get("timed_out", False),
                     channel_id=self.channel_id,
                 )
-                await
+                await response_queue.put(response)
                 logger.info(
                     "%s: received response for %s",
                     self.channel_id,
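Note: for orientation, IPCControlChannel is an asyncio Unix-socket server that tracks connected client writers alongside its response queue. A rough sketch of that general shape (not the Tactus implementation; names are illustrative):

```python
# Rough sketch of a Unix-socket IPC server: a lazily started asyncio server
# plus a registry of connected StreamWriters keyed by a client id. Unix-only.
import asyncio
import os
from typing import Dict, Optional


class TinyIpcServer:
    def __init__(self, socket_path: str) -> None:
        self.socket_path = socket_path
        self._server: Optional[asyncio.AbstractServer] = None
        self._clients: Dict[str, asyncio.StreamWriter] = {}

    async def start(self) -> None:
        if os.path.exists(self.socket_path):
            os.remove(self.socket_path)  # clear a stale socket from a previous run
        self._server = await asyncio.start_unix_server(self._on_connect, path=self.socket_path)

    async def _on_connect(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
        client_id = str(id(writer))
        self._clients[client_id] = writer
        try:
            while line := await reader.readline():
                writer.write(line)  # echo one line back per message
                await writer.drain()
        finally:
            self._clients.pop(client_id, None)
            writer.close()
```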
tactus/adapters/channels/sse.py
CHANGED
@@ -267,6 +267,7 @@ class SSEControlChannel(InProcessChannel):
         self, request_id: str, response: ControlResponse
     ) -> None:
         try:
+            self._ensure_asyncio_primitives()
             event_loop = asyncio.get_event_loop()
             if event_loop.is_running():
                 asyncio.run_coroutine_threadsafe(self._response_queue.put(response), event_loop)
@@ -326,4 +327,5 @@ class SSEControlChannel(InProcessChannel):
     async def shutdown(self) -> None:
         """Shutdown SSE channel."""
         logger.info("%s: shutting down", self.channel_id)
+        self._ensure_asyncio_primitives()
         self._shutdown_event.set()
tactus/adapters/mcp_manager.py
CHANGED
@@ -5,18 +5,32 @@ Manages multiple MCP server connections using Pydantic AI's native MCPServerStdi
 Handles lifecycle, tool prefixing, and tool call tracking.
 """

+from __future__ import annotations
+
 import logging
 import os
 import re
 import asyncio
 from contextlib import AsyncExitStack
-from typing import Any
-
-from pydantic_ai.mcp import MCPServerStdio
+from typing import Any, Optional

 logger = logging.getLogger(__name__)


+MCPServerStdio: Optional[Any] = None
+
+
+def _require_mcp_server_stdio():
+    try:
+        from pydantic_ai.mcp import MCPServerStdio
+    except ImportError as import_error:
+        raise RuntimeError(
+            "MCP support requires optional dependencies. "
+            'Install with `pip install "pydantic-ai-slim[mcp]"`.'
+        ) from import_error
+    return MCPServerStdio
+
+
 def substitute_env_vars(value: Any) -> Any:
     """
     Replace ${VAR} with environment variable values.
@@ -55,8 +69,8 @@ class MCPServerManager:
         """
         self.configs = server_configs
         self.tool_primitive = tool_primitive
-        self.servers: list[
-        self.server_toolsets: dict[str,
+        self.servers: list[Any] = []
+        self.server_toolsets: dict[str, Any] = {}  # Map server names to toolsets
         self._exit_stack = AsyncExitStack()
         logger.info("MCPServerManager initialized with %s server(s)", len(server_configs))

@@ -64,7 +78,7 @@ class MCPServerManager:
         """Connect to all configured MCP servers."""
         for name, config in self.configs.items():
             # Retry a few times for transient stdio startup issues.
-            last_error: Exception
+            last_error: Optional[Exception] = None
             for attempt in range(1, 4):
                 try:
                     logger.info(
@@ -77,6 +91,9 @@ class MCPServerManager:
                     resolved_config = substitute_env_vars(config)

                     # Create base server
+                    MCPServerStdio = globals().get("MCPServerStdio")
+                    if MCPServerStdio is None:
+                        MCPServerStdio = _require_mcp_server_stdio()
                     server = MCPServerStdio(
                         command=resolved_config["command"],
                         args=resolved_config.get("args", []),
@@ -190,7 +207,7 @@ class MCPServerManager:

         return trace_tool_call

-    def get_toolsets(self) -> list[
+    def get_toolsets(self) -> list[Any]:
        """
        Return list of connected servers as toolsets.

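Note: mcp_manager.py now imports `pydantic_ai.mcp` lazily so the module can be loaded without the optional MCP extra installed. A standalone sketch of that guard; it caches the class in a private variable, whereas the 0.37.0 code keeps the module-level `MCPServerStdio` name and refreshes it through `globals()`:

```python
# Lazy optional-import guard: the heavy dependency is imported on first use,
# and a missing extra surfaces as an actionable RuntimeError rather than an
# ImportError at module load time.
from typing import Any, Optional

_mcp_server_stdio: Optional[Any] = None


def require_mcp_server_stdio() -> Any:
    global _mcp_server_stdio
    if _mcp_server_stdio is None:
        try:
            from pydantic_ai.mcp import MCPServerStdio  # optional dependency
        except ImportError as import_error:
            raise RuntimeError(
                'MCP support requires optional dependencies. '
                'Install with `pip install "pydantic-ai-slim[mcp]"`.'
            ) from import_error
        _mcp_server_stdio = MCPServerStdio
    return _mcp_server_stdio
```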
tactus/backends/http_backend.py
CHANGED
@@ -3,7 +3,7 @@ HTTP model backend for REST endpoint inference.
 """

 import logging
-from typing import Any
+from typing import Any, Dict, Optional

 import httpx

@@ -13,7 +13,7 @@ logger = logging.getLogger(__name__)
 class HTTPModelBackend:
     """Model backend that calls HTTP REST endpoints."""

-    def __init__(self, endpoint: str, timeout: float = 30.0, headers:
+    def __init__(self, endpoint: str, timeout: float = 30.0, headers: Optional[Dict] = None):
         """
         Initialize HTTP model backend.

tactus/backends/pytorch_backend.py
CHANGED
@@ -4,7 +4,7 @@ PyTorch model backend for .pt file inference.

 import logging
 from pathlib import Path
-from typing import Any
+from typing import Any, List, Optional

 logger = logging.getLogger(__name__)

@@ -12,7 +12,7 @@ logger = logging.getLogger(__name__)
 class PyTorchModelBackend:
     """Model backend that loads and runs PyTorch models."""

-    def __init__(self, path: str, device: str = "cpu", labels:
+    def __init__(self, path: str, device: str = "cpu", labels: Optional[List[str]] = None):
         """
         Initialize PyTorch model backend.

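Note: both backend constructors now spell out `Optional[...] = None` defaults. A small illustration (toy class, not Tactus code) of the companion idiom of normalizing the None in the body rather than using a mutable default:

```python
from typing import Dict, List, Optional


class ExampleBackend:
    # `headers: Dict = {}` would share one dict across every instance; an
    # Optional default keeps the signature honest and copies per instance.
    def __init__(
        self,
        endpoint: str,
        timeout: float = 30.0,
        headers: Optional[Dict[str, str]] = None,
        labels: Optional[List[str]] = None,
    ):
        self.endpoint = endpoint
        self.timeout = timeout
        self.headers = dict(headers or {})
        self.labels = list(labels or [])
```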
tactus/broker/client.py
CHANGED
@@ -16,7 +16,7 @@ import sys
 import threading
 import uuid
 from pathlib import Path
-from typing import Any, AsyncIterator, Optional
+from typing import Any, AsyncIterator, Optional, Union

 from tactus.broker.protocol import read_message, write_message
 from tactus.broker.stdio import STDIO_REQUEST_PREFIX, STDIO_TRANSPORT_VALUE
@@ -122,7 +122,7 @@ async def close_stdio_transport() -> None:


 class BrokerClient:
-    def __init__(self, socket_path: str
+    def __init__(self, socket_path: Union[str, Path]):
         self.socket_path = str(socket_path)

     @classmethod
@@ -158,7 +158,7 @@ class BrokerClient:
         except ValueError as error:
             raise ValueError(f"Invalid broker port in endpoint: {self.socket_path}") from error

-        ssl_context: ssl.SSLContext
+        ssl_context: Optional[ssl.SSLContext] = None
         if use_tls:
             ssl_context = ssl.create_default_context()
             cafile = os.environ.get("TACTUS_BROKER_TLS_CA_FILE")
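Note: BrokerClient now accepts either a string or a `pathlib.Path` and normalizes it once at the boundary. A tiny sketch of that accept-both, store-one-type pattern (illustrative class name):

```python
from pathlib import Path
from typing import Union


class SocketAddress:
    def __init__(self, socket_path: Union[str, Path]):
        # Normalize immediately so the rest of the code only ever sees a str.
        self.socket_path = str(socket_path)


print(SocketAddress(Path("/tmp/broker.sock")).socket_path)
print(SocketAddress("/tmp/broker.sock").socket_path)
```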
tactus/broker/server.py
CHANGED
@@ -29,6 +29,18 @@ from tactus.broker.protocol import (
 logger = logging.getLogger(__name__)


+try:
+    from builtins import BaseExceptionGroup as BaseExceptionGroup
+except ImportError:  # pragma: no cover - Python < 3.11 fallback
+
+    class BaseExceptionGroup(Exception):
+        """Minimal BaseExceptionGroup fallback for Python < 3.11."""
+
+        def __init__(self, message: str, exceptions: list[BaseException]):
+            super().__init__(message)
+            self.exceptions = exceptions
+
+
 def _json_dumps(obj: Any) -> str:
     return json.dumps(obj, ensure_ascii=False, separators=(",", ":"))

@@ -170,7 +182,7 @@ class _BaseBrokerServer:
         control_handler: Optional[Callable[[dict], Awaitable[dict]]] = None,
     ):
         self._listener = None
-        self._serve_task: asyncio.Task[None]
+        self._serve_task: Optional[asyncio.Task[None]] = None
         self._openai = openai_backend or OpenAIChatBackend()
         self._tools = tool_registry or HostToolRegistry.default()
         self._event_handler = event_handler
@@ -1012,7 +1024,7 @@ class BrokerServer(_BaseBrokerServer):
             openai_backend=openai_backend, tool_registry=tool_registry, event_handler=event_handler
         )
         self.socket_path = Path(socket_path)
-        self._server: asyncio.AbstractServer
+        self._server: Optional[asyncio.AbstractServer] = None

     async def start(self) -> None:
         # Most platforms enforce a short maximum length for AF_UNIX socket paths.
@@ -1445,7 +1457,7 @@ class TcpBrokerServer(_BaseBrokerServer):
         *,
         host: str = "127.0.0.1",
         port: int = 0,
-        ssl_context: ssl.SSLContext
+        ssl_context: Optional[ssl.SSLContext] = None,
         openai_backend: Optional[OpenAIChatBackend] = None,
         tool_registry: Optional[HostToolRegistry] = None,
         event_handler: Optional[Callable[[dict[str, Any]], None]] = None,
@@ -1460,8 +1472,8 @@ class TcpBrokerServer(_BaseBrokerServer):
         self.host = host
         self.port = port
         self.ssl_context = ssl_context
-        self.bound_port: int
-        self._serve_task: asyncio.Task[None]
+        self.bound_port: Optional[int] = None
+        self._serve_task: Optional[asyncio.Task[None]] = None

     async def start(self) -> None:
         # Create AnyIO TCP listener (doesn't block, just binds to port)
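Note: broker/server.py now defines a `BaseExceptionGroup` fallback so the same except-clauses work on interpreters older than Python 3.11. A sketch of how such a fallback is typically used; the unwrapping helper below is illustrative, not Tactus code:

```python
# On 3.11+ the builtin BaseExceptionGroup is used directly; earlier versions
# get a minimal stand-in so isinstance checks and .exceptions access still work.
try:
    from builtins import BaseExceptionGroup  # Python 3.11+
except ImportError:  # pragma: no cover - Python < 3.11 fallback
    class BaseExceptionGroup(Exception):  # type: ignore[no-redef]
        def __init__(self, message, exceptions):
            super().__init__(message)
            self.exceptions = exceptions


def first_leaf(error: BaseException) -> BaseException:
    # Unwrap nested groups (e.g. raised by task groups) down to the first
    # concrete exception, which is usually the one worth logging.
    while isinstance(error, BaseExceptionGroup) and getattr(error, "exceptions", ()):
        error = error.exceptions[0]
    return error
```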
tactus/core/dsl_stubs.py
CHANGED
@@ -31,7 +31,7 @@ Agent/Tool calls use direct variable access:
     done.last_result() -- Get last tool result
 """

-from typing import Any, Callable
+from typing import Any, Callable, Dict, Optional

 from .registry import RegistryBuilder
 from tactus.primitives.handles import AgentHandle, ModelHandle, AgentLookup, ModelLookup
@@ -112,7 +112,7 @@ def create_dsl_stubs(
     builder: RegistryBuilder,
     tool_primitive: Any = None,
     mock_manager: Any = None,
-    runtime_context:
+    runtime_context: Optional[Dict[str, Any]] = None,
 ) -> dict[str, Callable]:
     """
     Create DSL stub functions that populate the registry.
@@ -144,7 +144,7 @@ def create_dsl_stubs(
     _procedure_registry = {}

     def _process_procedure_config(
-        name: str
+        name: Optional[str], config: Any, procedure_registry: Dict[str, Any]
     ):
         """
         Process procedure config and register the procedure.
tactus/core/execution_context.py
CHANGED
@@ -6,7 +6,7 @@ Uses pluggable storage and HITL handlers via protocols.
 """

 from abc import ABC, abstractmethod
-from typing import Any, Callable
+from typing import Any, Callable, Dict, List, Optional
 from datetime import datetime, timezone
 import logging
 import time
@@ -39,7 +39,7 @@ class ExecutionContext(ABC):
         self,
         fn: Callable[[], Any],
         checkpoint_type: str,
-        source_info:
+        source_info: Optional[Dict[str, Any]] = None,
     ) -> Any:
         """
         Execute fn with position-based checkpointing. On replay, return stored result.
@@ -59,9 +59,9 @@ class ExecutionContext(ABC):
         self,
         request_type: str,
         message: str,
-        timeout_seconds: int
+        timeout_seconds: Optional[int],
         default_value: Any,
-        options:
+        options: Optional[List[dict]],
         metadata: dict,
     ) -> HITLResponse:
         """
@@ -121,7 +121,7 @@ class BaseExecutionContext(ExecutionContext):
         self,
         procedure_id: str,
         storage_backend: StorageBackend,
-        hitl_handler: HITLHandler
+        hitl_handler: Optional[HITLHandler] = None,
         strict_determinism: bool = False,
         log_handler=None,
     ):
@@ -145,14 +145,14 @@ class BaseExecutionContext(ExecutionContext):
         self._inside_checkpoint = False

         # Run ID tracking for distinguishing between different executions
-        self.current_run_id: str
+        self.current_run_id: Optional[str] = None

         # .tac file tracking for accurate source locations
-        self.current_tac_file: str
-        self.current_tac_content: str
+        self.current_tac_file: Optional[str] = None
+        self.current_tac_content: Optional[str] = None

         # Lua sandbox reference for debug.getinfo access
-        self.lua_sandbox: Any
+        self.lua_sandbox: Optional[Any] = None

         # Rich metadata for HITL notifications
         self._initialize_run_metadata(procedure_id)
@@ -177,7 +177,7 @@ class BaseExecutionContext(ExecutionContext):
         """Set the run_id for subsequent checkpoints in this execution."""
         self.current_run_id = run_id

-    def set_tac_file(self, file_path: str, content: str
+    def set_tac_file(self, file_path: str, content: Optional[str] = None) -> None:
         """
         Store the currently executing .tac file for accurate source location capture.

@@ -193,7 +193,7 @@ class BaseExecutionContext(ExecutionContext):
         self.lua_sandbox = lua_sandbox

     def set_procedure_metadata(
-        self, procedure_name: str
+        self, procedure_name: Optional[str] = None, input_data: Any = None
     ) -> None:
         """
         Set rich metadata for HITL notifications.
@@ -211,7 +211,7 @@ class BaseExecutionContext(ExecutionContext):
         self,
         fn: Callable[[], Any],
         checkpoint_type: str,
-        source_info:
+        source_info: Optional[Dict[str, Any]] = None,
     ) -> Any:
         """
         Execute fn with position-based checkpointing and source tracking.
@@ -406,7 +406,7 @@ class BaseExecutionContext(ExecutionContext):

     def _get_code_context(
         self, file_path: str, line_number: int, context_lines: int = 3
-    ) -> str
+    ) -> Optional[str]:
         """Read source file and extract surrounding lines for debugging."""
         try:
             with open(file_path, "r") as source_file:
@@ -421,9 +421,9 @@ class BaseExecutionContext(ExecutionContext):
         self,
         request_type: str,
         message: str,
-        timeout_seconds: int
+        timeout_seconds: Optional[int],
         default_value: Any,
-        options:
+        options: Optional[List[dict]],
         metadata: dict,
     ) -> HITLResponse:
         """
@@ -505,7 +505,7 @@ class BaseExecutionContext(ExecutionContext):
             async_procedure_handles[handle.procedure_id] = handle.to_dict()
         self.storage.save_procedure_metadata(self.procedure_id, self.metadata)

-    def get_procedure_handle(self, procedure_id: str) ->
+    def get_procedure_handle(self, procedure_id: str) -> Optional[Dict[str, Any]]:
         """
         Retrieve procedure handle.

@@ -609,7 +609,7 @@ class BaseExecutionContext(ExecutionContext):

         return run_id

-    def get_subject(self) -> str
+    def get_subject(self) -> Optional[str]:
         """
         Return a human-readable subject line for this execution.

@@ -621,7 +621,7 @@ class BaseExecutionContext(ExecutionContext):
             return f"{self.procedure_name} (checkpoint {checkpoint_position})"
         return f"Procedure {self.procedure_id} (checkpoint {checkpoint_position})"

-    def get_started_at(self) -> datetime
+    def get_started_at(self) -> Optional[datetime]:
         """
         Return when this execution started.

@@ -630,7 +630,7 @@ class BaseExecutionContext(ExecutionContext):
         """
         return self._started_at

-    def get_input_summary(self) ->
+    def get_input_summary(self) -> Optional[Dict[str, Any]]:
         """
         Return a summary of the initial input to this procedure.

@@ -647,7 +647,7 @@ class BaseExecutionContext(ExecutionContext):
         # Otherwise wrap it in a dict
         return {"value": self._input_data}

-    def get_conversation_history(self) ->
+    def get_conversation_history(self) -> Optional[List[dict]]:
         """
         Return conversation history if available.

@@ -658,7 +658,7 @@ class BaseExecutionContext(ExecutionContext):
         # in future implementations
         return None

-    def get_prior_control_interactions(self) ->
+    def get_prior_control_interactions(self) -> Optional[List[dict]]:
         """
         Return list of prior HITL interactions in this execution.

@@ -682,7 +682,7 @@ class BaseExecutionContext(ExecutionContext):

         return hitl_checkpoints if hitl_checkpoints else None

-    def get_lua_source_line(self) -> int
+    def get_lua_source_line(self) -> Optional[int]:
         """
         Get the current source line from Lua debug.getinfo.

@@ -768,7 +768,7 @@ class InMemoryExecutionContext(BaseExecutionContext):
     and simple CLI workflows that don't need to survive restarts.
     """

-    def __init__(self, procedure_id: str, hitl_handler: HITLHandler
+    def __init__(self, procedure_id: str, hitl_handler: Optional[HITLHandler] = None):
         """
         Initialize with in-memory storage.

tactus/core/message_history_manager.py
CHANGED
@@ -8,7 +8,7 @@ Aligned with pydantic-ai's message_history concept.
 """

 from datetime import datetime, timezone
-from typing import Any, Optional
+from typing import Any, Optional, Tuple

 try:
     from pydantic_ai.messages import ModelMessage
@@ -146,7 +146,7 @@ class MessageHistoryManager:
         return self._apply_named_filter(messages, filter_name, filter_value)

     @staticmethod
-    def _parse_filter_spec(filter_specification: Any) ->
+    def _parse_filter_spec(filter_specification: Any) -> Tuple[Optional[str], Any]:
         if not isinstance(filter_specification, tuple) or len(filter_specification) < 2:
             return None, None

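Note: most execution-context changes replace bare attribute annotations with real `Optional[...] = None` assignments. The difference is visible at runtime, since a bare annotation never creates the attribute. A minimal illustration (toy classes, not Tactus code):

```python
from typing import Optional


class Before:
    def __init__(self) -> None:
        self.current_run_id: str  # annotation only: no attribute is created


class After:
    def __init__(self) -> None:
        self.current_run_id: Optional[str] = None  # attribute exists, defaults to None


try:
    Before().current_run_id
except AttributeError as error:
    print("bare annotation:", error)

print("explicit default:", After().current_run_id)  # -> None
```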
tactus/core/output_validator.py
CHANGED
@@ -6,7 +6,7 @@ Enables type safety and composability for sub-agent workflows.
 """

 import logging
-from typing import Any, Optional
+from typing import Any, Optional, Tuple

 logger = logging.getLogger(__name__)

@@ -78,7 +78,7 @@ class OutputValidator:
         logger.debug("OutputValidator initialized with %s output fields", field_count)

     @staticmethod
-    def _unwrap_result(output: Any) ->
+    def _unwrap_result(output: Any) -> Tuple[Any, Optional[Any]]:
         from tactus.protocols.result import TactusResult

         wrapped_result = output if isinstance(output, TactusResult) else None
@@ -94,7 +94,7 @@ class OutputValidator:

     @staticmethod
     def _wrap_validated_output(
-        wrapped_result: Any
+        wrapped_result: Optional[Any],
         validated_payload: Any,
     ) -> Any:
         if wrapped_result is not None:
@@ -129,7 +129,7 @@ class OutputValidator:
     def _validate_without_schema(
         self,
         output: Any,
-        wrapped_result: Any
+        wrapped_result: Optional[Any],
     ) -> Any:
         """Accept any output when no schema is defined."""
         logger.debug("No output schema defined, skipping validation")
@@ -139,7 +139,7 @@ class OutputValidator:
     def _validate_scalar_schema(
         self,
         output: Any,
-        wrapped_result: Any
+        wrapped_result: Optional[Any],
     ) -> Any:
         """Validate scalar outputs (`field.string{}` etc.)."""
         # Lua tables are not valid scalar outputs.
@@ -168,7 +168,7 @@ class OutputValidator:
     def _validate_structured_schema(
         self,
         output: Any,
-        wrapped_result: Any
+        wrapped_result: Optional[Any],
     ) -> Any:
         """Validate dict/table outputs against a schema."""
         if hasattr(output, "items") or isinstance(output, dict):