tactus 0.35.1__py3-none-any.whl → 0.37.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tactus/__init__.py +1 -1
- tactus/adapters/channels/base.py +20 -2
- tactus/adapters/channels/broker.py +1 -0
- tactus/adapters/channels/host.py +3 -1
- tactus/adapters/channels/ipc.py +18 -3
- tactus/adapters/channels/sse.py +13 -5
- tactus/adapters/control_loop.py +44 -30
- tactus/adapters/mcp_manager.py +24 -7
- tactus/backends/http_backend.py +2 -2
- tactus/backends/pytorch_backend.py +2 -2
- tactus/broker/client.py +3 -3
- tactus/broker/server.py +17 -5
- tactus/core/dsl_stubs.py +3 -3
- tactus/core/execution_context.py +32 -27
- tactus/core/lua_sandbox.py +42 -34
- tactus/core/message_history_manager.py +51 -28
- tactus/core/output_validator.py +65 -51
- tactus/core/registry.py +29 -29
- tactus/core/runtime.py +69 -61
- tactus/dspy/broker_lm.py +13 -7
- tactus/dspy/config.py +7 -4
- tactus/ide/server.py +63 -33
- tactus/primitives/host.py +19 -16
- tactus/primitives/message_history.py +11 -14
- tactus/primitives/model.py +1 -1
- tactus/primitives/procedure.py +11 -8
- tactus/primitives/session.py +9 -9
- tactus/primitives/state.py +2 -2
- tactus/primitives/tool_handle.py +27 -24
- tactus/sandbox/container_runner.py +11 -6
- tactus/testing/context.py +6 -6
- tactus/testing/evaluation_runner.py +5 -5
- tactus/testing/mock_hitl.py +2 -2
- tactus/testing/models.py +2 -0
- tactus/testing/steps/builtin.py +2 -2
- tactus/testing/test_runner.py +6 -4
- tactus/utils/asyncio_helpers.py +2 -1
- tactus/utils/safe_libraries.py +2 -2
- {tactus-0.35.1.dist-info → tactus-0.37.0.dist-info}/METADATA +11 -5
- {tactus-0.35.1.dist-info → tactus-0.37.0.dist-info}/RECORD +43 -43
- {tactus-0.35.1.dist-info → tactus-0.37.0.dist-info}/WHEEL +0 -0
- {tactus-0.35.1.dist-info → tactus-0.37.0.dist-info}/entry_points.txt +0 -0
- {tactus-0.35.1.dist-info → tactus-0.37.0.dist-info}/licenses/LICENSE +0 -0
tactus/core/execution_context.py
CHANGED
@@ -6,7 +6,7 @@ Uses pluggable storage and HITL handlers via protocols.
 """
 
 from abc import ABC, abstractmethod
-from typing import Any, Callable
+from typing import Any, Callable, Dict, List, Optional
 from datetime import datetime, timezone
 import logging
 import time
@@ -39,7 +39,7 @@ class ExecutionContext(ABC):
         self,
         fn: Callable[[], Any],
         checkpoint_type: str,
-        source_info:
+        source_info: Optional[Dict[str, Any]] = None,
     ) -> Any:
         """
         Execute fn with position-based checkpointing. On replay, return stored result.
@@ -59,9 +59,9 @@ class ExecutionContext(ABC):
         self,
         request_type: str,
         message: str,
-        timeout_seconds: int
+        timeout_seconds: Optional[int],
         default_value: Any,
-        options:
+        options: Optional[List[dict]],
         metadata: dict,
     ) -> HITLResponse:
         """
@@ -121,7 +121,7 @@ class BaseExecutionContext(ExecutionContext):
         self,
         procedure_id: str,
         storage_backend: StorageBackend,
-        hitl_handler: HITLHandler
+        hitl_handler: Optional[HITLHandler] = None,
         strict_determinism: bool = False,
         log_handler=None,
     ):
@@ -145,21 +145,26 @@ class BaseExecutionContext(ExecutionContext):
         self._inside_checkpoint = False
 
         # Run ID tracking for distinguishing between different executions
-        self.current_run_id: str
+        self.current_run_id: Optional[str] = None
 
         # .tac file tracking for accurate source locations
-        self.current_tac_file: str
-        self.current_tac_content: str
+        self.current_tac_file: Optional[str] = None
+        self.current_tac_content: Optional[str] = None
 
         # Lua sandbox reference for debug.getinfo access
-        self.lua_sandbox: Any
+        self.lua_sandbox: Optional[Any] = None
 
         # Rich metadata for HITL notifications
-        self.
-        self.
-        self._started_at: datetime = datetime.now(timezone.utc)
-        self._input_data: Any = None
+        self._initialize_run_metadata(procedure_id)
+        self._load_and_reset_metadata(procedure_id)
 
+    def _initialize_run_metadata(self, procedure_id: str) -> None:
+        self.procedure_name = procedure_id
+        self.invocation_id = str(uuid.uuid4())
+        self._started_at = datetime.now(timezone.utc)
+        self._input_data = None
+
+    def _load_and_reset_metadata(self, procedure_id: str) -> None:
         # Load procedure metadata (contains execution_log and replay_index)
         self.metadata = self.storage.load_procedure_metadata(procedure_id)
 
@@ -172,7 +177,7 @@ class BaseExecutionContext(ExecutionContext):
         """Set the run_id for subsequent checkpoints in this execution."""
         self.current_run_id = run_id
 
-    def set_tac_file(self, file_path: str, content: str
+    def set_tac_file(self, file_path: str, content: Optional[str] = None) -> None:
         """
         Store the currently executing .tac file for accurate source location capture.
 
@@ -188,7 +193,7 @@ class BaseExecutionContext(ExecutionContext):
         self.lua_sandbox = lua_sandbox
 
     def set_procedure_metadata(
-        self, procedure_name: str
+        self, procedure_name: Optional[str] = None, input_data: Any = None
     ) -> None:
         """
         Set rich metadata for HITL notifications.
@@ -206,7 +211,7 @@ class BaseExecutionContext(ExecutionContext):
         self,
         fn: Callable[[], Any],
         checkpoint_type: str,
-        source_info:
+        source_info: Optional[Dict[str, Any]] = None,
    ) -> Any:
         """
         Execute fn with position-based checkpointing and source tracking.
@@ -401,7 +406,7 @@ class BaseExecutionContext(ExecutionContext):
 
     def _get_code_context(
         self, file_path: str, line_number: int, context_lines: int = 3
-    ) -> str
+    ) -> Optional[str]:
         """Read source file and extract surrounding lines for debugging."""
         try:
             with open(file_path, "r") as source_file:
@@ -416,9 +421,9 @@ class BaseExecutionContext(ExecutionContext):
         self,
         request_type: str,
         message: str,
-        timeout_seconds: int
+        timeout_seconds: Optional[int],
         default_value: Any,
-        options:
+        options: Optional[List[dict]],
         metadata: dict,
     ) -> HITLResponse:
         """
@@ -500,7 +505,7 @@ class BaseExecutionContext(ExecutionContext):
         async_procedure_handles[handle.procedure_id] = handle.to_dict()
         self.storage.save_procedure_metadata(self.procedure_id, self.metadata)
 
-    def get_procedure_handle(self, procedure_id: str) ->
+    def get_procedure_handle(self, procedure_id: str) -> Optional[Dict[str, Any]]:
         """
         Retrieve procedure handle.
 
@@ -604,7 +609,7 @@ class BaseExecutionContext(ExecutionContext):
 
         return run_id
 
-    def get_subject(self) -> str
+    def get_subject(self) -> Optional[str]:
         """
         Return a human-readable subject line for this execution.
 
@@ -616,7 +621,7 @@ class BaseExecutionContext(ExecutionContext):
             return f"{self.procedure_name} (checkpoint {checkpoint_position})"
         return f"Procedure {self.procedure_id} (checkpoint {checkpoint_position})"
 
-    def get_started_at(self) -> datetime
+    def get_started_at(self) -> Optional[datetime]:
         """
         Return when this execution started.
 
@@ -625,7 +630,7 @@ class BaseExecutionContext(ExecutionContext):
         """
         return self._started_at
 
-    def get_input_summary(self) ->
+    def get_input_summary(self) -> Optional[Dict[str, Any]]:
         """
         Return a summary of the initial input to this procedure.
 
@@ -642,7 +647,7 @@ class BaseExecutionContext(ExecutionContext):
         # Otherwise wrap it in a dict
         return {"value": self._input_data}
 
-    def get_conversation_history(self) ->
+    def get_conversation_history(self) -> Optional[List[dict]]:
         """
         Return conversation history if available.
 
@@ -653,7 +658,7 @@ class BaseExecutionContext(ExecutionContext):
         # in future implementations
         return None
 
-    def get_prior_control_interactions(self) ->
+    def get_prior_control_interactions(self) -> Optional[List[dict]]:
         """
         Return list of prior HITL interactions in this execution.
 
@@ -677,7 +682,7 @@ class BaseExecutionContext(ExecutionContext):
 
         return hitl_checkpoints if hitl_checkpoints else None
 
-    def get_lua_source_line(self) -> int
+    def get_lua_source_line(self) -> Optional[int]:
         """
         Get the current source line from Lua debug.getinfo.
 
@@ -763,7 +768,7 @@ class InMemoryExecutionContext(BaseExecutionContext):
     and simple CLI workflows that don't need to survive restarts.
     """
 
-    def __init__(self, procedure_id: str, hitl_handler: HITLHandler
+    def __init__(self, procedure_id: str, hitl_handler: Optional[HITLHandler] = None):
         """
         Initialize with in-memory storage.
 
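The net effect of these hunks is that the HITL handler and several accessors are now explicitly optional. Below is a minimal sketch of how the relaxed 0.37.0 signatures might be exercised; it is illustrative only, and anything not shown in the hunks above (such as the exact runtime behavior of the getters) is an assumption.

# Sketch only: construct an in-memory context without a HITL handler and
# treat the Optional-typed getters defensively.
from tactus.core.execution_context import InMemoryExecutionContext

context = InMemoryExecutionContext("demo-procedure")  # hitl_handler now defaults to None

started_at = context.get_started_at()  # Optional[datetime]
subject = context.get_subject()        # Optional[str]
if subject is not None and started_at is not None:
    print(f"{subject} started at {started_at.isoformat()}")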
tactus/core/lua_sandbox.py
CHANGED
@@ -241,7 +241,7 @@ class LuaSandbox:
         """Setup safe global functions and utilities."""
         # Keep safe standard library functions
         # (These are already available by default, just documenting them)
-
+        safe_global_symbols = {
             # Math
             "math",  # Math library (will be replaced with safe version if context available)
             "tonumber",  # Convert to number
@@ -264,54 +264,62 @@ class LuaSandbox:
         }
 
         # Just log what's available - no need to explicitly set
-        logger.debug("Safe Lua functions available: %s", ", ".join(
+        logger.debug("Safe Lua functions available: %s", ", ".join(safe_global_symbols))
 
         # Replace math and os libraries with safe versions if context available
         if self.execution_context is not None:
-
-
-                create_safe_os_library,
-            )
+            self._install_context_safe_libraries()
+            return  # Skip default os.date setup below
 
-
-            return self.execution_context
+        self._install_fallback_os_date()
 
-
-
+    def _install_context_safe_libraries(self) -> None:
+        """Install safe math and os libraries based on execution context."""
+        from tactus.utils.safe_libraries import (
+            create_safe_math_library,
+            create_safe_os_library,
+        )
 
-
-
+        def get_execution_context() -> Any:
+            return self.execution_context
 
-
-
+        safe_math_dict = create_safe_math_library(get_execution_context, self.strict_determinism)
+        safe_os_dict = create_safe_os_library(get_execution_context, self.strict_determinism)
 
-
-
+        safe_math_table = self._dict_to_lua_table(safe_math_dict)
+        safe_os_table = self._dict_to_lua_table(safe_os_dict)
 
-
-
-
+        self.lua.globals()["math"] = safe_math_table
+        self.lua.globals()["os"] = safe_os_table
+
+        logger.debug("Installed safe math and os libraries with determinism checking")
+
+    def _install_fallback_os_date(self) -> None:
+        """Install a safe os.date() fallback when no execution context is available."""
+        safe_os_table = self._build_fallback_os_table()
+        self.lua.globals()["os"] = safe_os_table
+        logger.debug("Added safe os.date() function")
 
-
+    def _build_fallback_os_table(self) -> Any:
+        """Build a Lua os table with a safe date() implementation."""
+        from datetime import datetime, timezone
+
+        def safe_date(format_string: Optional[str] = None) -> str:
            """Safe implementation of os.date() for timestamp generation."""
-            now = datetime.
-            if
+            now = datetime.now(timezone.utc)
+            if format_string is None:
                # Return default format like Lua's os.date()
                return now.strftime("%a %b %d %H:%M:%S %Y")
-
+            if format_string == "%Y-%m-%dT%H:%M:%SZ":
                # ISO 8601 format
                return now.strftime("%Y-%m-%dT%H:%M:%SZ")
-
-
-
-
-
-
-
-        # Create safe os table with only date function
-        safe_os = self.lua.table(date=safe_date)
-        self.lua.globals()["os"] = safe_os
-        logger.debug("Added safe os.date() function")
+            # Support Python strftime formats
+            try:
+                return now.strftime(format_string)
+            except Exception:  # noqa: E722
+                return now.strftime("%a %b %d %H:%M:%S %Y")
+
+        return self.lua.table(date=safe_date)
 
     def setup_assignment_interception(self, callback: Any) -> None:
         """
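The fallback now accepts arbitrary strftime format strings instead of only the default and ISO 8601 forms. For reference, the safe_date() logic added in the hunk above can be reproduced as a standalone function (the in-tree version returns it wrapped in a Lua table via self.lua.table(date=safe_date)):

from datetime import datetime, timezone
from typing import Optional

def safe_date(format_string: Optional[str] = None) -> str:
    """Mirror of the sandbox fallback: Lua-style default, ISO 8601, or any strftime format."""
    now = datetime.now(timezone.utc)
    if format_string is None:
        return now.strftime("%a %b %d %H:%M:%S %Y")  # Lua-style default
    if format_string == "%Y-%m-%dT%H:%M:%SZ":
        return now.strftime("%Y-%m-%dT%H:%M:%SZ")    # ISO 8601 fast path
    try:
        return now.strftime(format_string)           # new in 0.37.0: arbitrary strftime formats
    except Exception:
        return now.strftime("%a %b %d %H:%M:%S %Y")  # fall back to the default format

print(safe_date("%Y-%m-%d"))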
tactus/core/message_history_manager.py
CHANGED
@@ -8,7 +8,7 @@ Aligned with pydantic-ai's message_history concept.
 """
 
 from datetime import datetime, timezone
-from typing import Any, Optional
+from typing import Any, Optional, Tuple
 
 try:
     from pydantic_ai.messages import ModelMessage
@@ -136,37 +136,60 @@ class MessageHistoryManager:
                 print(f"Warning: Filter function failed: {exception}")
                 return messages
 
-
-        if
+        filter_name, filter_value = self._parse_filter_spec(filter_specification)
+        if filter_name is None:
            return messages
 
-        filter_name
-
-
-
-
-
-
-            return
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        if filter_name == "compose":
+            return self._apply_composed_filters(messages, filter_value, context)
+
+        return self._apply_named_filter(messages, filter_name, filter_value)
+
+    @staticmethod
+    def _parse_filter_spec(filter_specification: Any) -> Tuple[Optional[str], Any]:
+        if not isinstance(filter_specification, tuple) or len(filter_specification) < 2:
+            return None, None
+
+        return filter_specification[0], filter_specification[1]
+
+    def _apply_composed_filters(
+        self,
+        messages: list[ModelMessage],
+        filter_steps: Any,
+        context: Optional[Any],
+    ) -> list[ModelMessage]:
+        filtered_messages = messages
+        for filter_step in filter_steps:
+            filtered_messages = self._apply_filter(filtered_messages, filter_step, context)
+        return filtered_messages
+
+    def _apply_named_filter(
+        self,
+        messages: list[ModelMessage],
+        filter_name: str,
+        filter_value: Any,
+    ) -> list[ModelMessage]:
+        filter_function = self._filter_dispatch.get(filter_name)
+        if filter_function is None:
            return messages
 
+        if filter_name == "system_prefix":
+            return filter_function(messages)
+
+        return filter_function(messages, filter_value)
+
+    @property
+    def _filter_dispatch(self) -> dict[str, Any]:
+        return {
+            "last_n": self._filter_last_n,
+            "first_n": self._filter_first_n,
+            "token_budget": self._filter_by_token_budget,
+            "head_tokens": self._filter_head_tokens,
+            "tail_tokens": self._filter_tail_tokens,
+            "by_role": self._filter_by_role,
+            "system_prefix": self._filter_system_prefix,
+        }
+
     def _filter_last_n(
         self,
         messages: list[ModelMessage],
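The refactor replaces a long filter-selection chain with a dispatch table keyed by filter name, plus a "compose" spec that folds several filters in order. A minimal standalone illustration of that pattern follows (hypothetical names and filters; not the Tactus API):

from typing import Any, Callable

Messages = list[dict]

def last_n(messages: Messages, n: int) -> Messages:
    return messages[-n:]

def by_role(messages: Messages, role: str) -> Messages:
    return [m for m in messages if m.get("role") == role]

DISPATCH: dict[str, Callable[..., Messages]] = {"last_n": last_n, "by_role": by_role}

def apply_filter(messages: Messages, spec: Any) -> Messages:
    # A ("name", value) tuple selects a filter; anything else passes through unchanged.
    if not isinstance(spec, tuple) or len(spec) < 2:
        return messages
    name, value = spec[0], spec[1]
    if name == "compose":
        # Fold each step over the running result, left to right.
        for step in value:
            messages = apply_filter(messages, step)
        return messages
    filter_fn = DISPATCH.get(name)
    return filter_fn(messages, value) if filter_fn else messages

history = [{"role": "user", "content": "hi"}, {"role": "assistant", "content": "hello"}]
print(apply_filter(history, ("compose", [("by_role", "user"), ("last_n", 1)])))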
tactus/core/output_validator.py
CHANGED
@@ -6,7 +6,7 @@ Enables type safety and composability for sub-agent workflows.
 """
 
 import logging
-from typing import Any, Optional
+from typing import Any, Optional, Tuple
 
 logger = logging.getLogger(__name__)
 
@@ -29,13 +29,14 @@ class OutputValidator:
     """
 
     # Type mapping from YAML to Python
-
+    SCHEMA_TYPE_TO_PYTHON_TYPE = {
        "string": str,
        "number": (int, float),
        "boolean": bool,
        "object": dict,
        "array": list,
    }
+    TYPE_MAP = SCHEMA_TYPE_TO_PYTHON_TYPE
 
     @classmethod
     def _is_scalar_schema(cls, schema: Any) -> bool:
@@ -77,7 +78,7 @@ class OutputValidator:
         logger.debug("OutputValidator initialized with %s output fields", field_count)
 
     @staticmethod
-    def _unwrap_result(output: Any) ->
+    def _unwrap_result(output: Any) -> Tuple[Any, Optional[Any]]:
         from tactus.protocols.result import TactusResult
 
         wrapped_result = output if isinstance(output, TactusResult) else None
@@ -91,6 +92,15 @@ class OutputValidator:
             return dict(output.items())
         return output
 
+    @staticmethod
+    def _wrap_validated_output(
+        wrapped_result: Optional[Any],
+        validated_payload: Any,
+    ) -> Any:
+        if wrapped_result is not None:
+            return wrapped_result.model_copy(update={"output": validated_payload})
+        return validated_payload
+
     def validate(self, output: Any) -> Any:
         """
         Validate workflow output against schema.
@@ -108,49 +118,63 @@ class OutputValidator:
         # while preserving the wrapper (so callers can still access usage/cost/etc.).
         output, wrapped_result = self._unwrap_result(output)
 
-        # If no schema defined, accept any output
         if not self.schema:
-
-            validated_payload = self._normalize_unstructured_output(output)
+            return self._validate_without_schema(output, wrapped_result)
 
-            if wrapped_result is not None:
-                return wrapped_result.model_copy(update={"output": validated_payload})
-            return validated_payload
-
-        # Scalar output schema: `output = field.string{...}` etc.
         if self._is_scalar_schema(self.schema):
-
-
-
-
-
-
-
+            return self._validate_scalar_schema(output, wrapped_result)
+
+        return self._validate_structured_schema(output, wrapped_result)
+
+    def _validate_without_schema(
+        self,
+        output: Any,
+        wrapped_result: Optional[Any],
+    ) -> Any:
+        """Accept any output when no schema is defined."""
+        logger.debug("No output schema defined, skipping validation")
+        validated_payload = self._normalize_unstructured_output(output)
+        return self._wrap_validated_output(wrapped_result, validated_payload)
+
+    def _validate_scalar_schema(
+        self,
+        output: Any,
+        wrapped_result: Optional[Any],
+    ) -> Any:
+        """Validate scalar outputs (`field.string{}` etc.)."""
+        # Lua tables are not valid scalar outputs.
+        if hasattr(output, "items") and not isinstance(output, dict):
+            output = dict(output.items())
+
+        is_required = self.schema.get("required", False)
+        if output is None and not is_required:
+            return None
+
+        expected_type = self.schema.get("type")
+        if expected_type and not self._check_type(output, expected_type):
+            raise OutputValidationError(
+                f"Output should be {expected_type}, got {type(output).__name__}"
+            )
 
-
-
+        if "enum" in self.schema and self.schema["enum"]:
+            allowed_values = self.schema["enum"]
+            if output not in allowed_values:
                raise OutputValidationError(
-                    f"Output
+                    f"Output has invalid value '{output}'. Allowed values: {allowed_values}"
                )
 
-
-            allowed_values = self.schema["enum"]
-            if output not in allowed_values:
-                raise OutputValidationError(
-                    f"Output has invalid value '{output}'. Allowed values: {allowed_values}"
-                )
-
-            validated_payload = output
-            if wrapped_result is not None:
-                return wrapped_result.model_copy(update={"output": validated_payload})
-            return validated_payload
+        return self._wrap_validated_output(wrapped_result, output)
 
-
+    def _validate_structured_schema(
+        self,
+        output: Any,
+        wrapped_result: Optional[Any],
+    ) -> Any:
+        """Validate dict/table outputs against a schema."""
         if hasattr(output, "items") or isinstance(output, dict):
             logger.debug("Converting Lua tables to Python dicts recursively")
             output = self._convert_lua_tables(output)
 
-        # Output must be a dict/table
         if not isinstance(output, dict):
             raise OutputValidationError(
                 f"Output must be an object/table, got {type(output).__name__}"
@@ -159,7 +183,6 @@ class OutputValidator:
         validation_errors: list[str] = []
         validated_output: dict[str, Any] = {}
 
-        # Check required fields and validate types
         for field_name, field_def in self.schema.items():
             if not isinstance(field_def, dict) or "type" not in field_def:
                 validation_errors.append(
@@ -167,28 +190,23 @@ class OutputValidator:
                     f"Use field.{field_def.get('type', 'string')}{{}} instead."
                 )
                 continue
-            is_required = bool(field_def.get("required", False))
 
+            is_required = bool(field_def.get("required", False))
             if is_required and field_name not in output:
                 validation_errors.append(f"Required field '{field_name}' is missing")
                 continue
 
-            # Skip validation if field not present and not required
             if field_name not in output:
                 continue
 
             value = output[field_name]
-
-            # Type checking
             expected_type = field_def.get("type")
-            if expected_type:
-
-
-
-
-                )
+            if expected_type and not self._check_type(value, expected_type):
+                actual_type = type(value).__name__
+                validation_errors.append(
+                    f"Field '{field_name}' should be {expected_type}, got {actual_type}"
+                )
 
-            # Enum validation
             if "enum" in field_def and field_def["enum"]:
                 allowed_values = field_def["enum"]
                 if value not in allowed_values:
@@ -197,10 +215,8 @@ class OutputValidator:
                         f"Allowed values: {allowed_values}"
                     )
 
-            # Add to validated output (only declared fields)
             validated_output[field_name] = value
 
-        # Filter undeclared fields (only return declared fields)
         for field_name in output:
             if field_name not in self.schema:
                 logger.debug("Filtering undeclared field '%s' from output", field_name)
@@ -210,9 +226,7 @@ class OutputValidator:
             raise OutputValidationError(error_message)
 
         logger.info("Output validation passed for %s fields", len(validated_output))
-
-            return wrapped_result.model_copy(update={"output": validated_output})
-        return validated_output
+        return self._wrap_validated_output(wrapped_result, validated_output)
 
     def _check_type(self, value: Any, expected_type: str) -> bool:
         """