tactus 0.35.1__py3-none-any.whl → 0.37.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43) hide show
  1. tactus/__init__.py +1 -1
  2. tactus/adapters/channels/base.py +20 -2
  3. tactus/adapters/channels/broker.py +1 -0
  4. tactus/adapters/channels/host.py +3 -1
  5. tactus/adapters/channels/ipc.py +18 -3
  6. tactus/adapters/channels/sse.py +13 -5
  7. tactus/adapters/control_loop.py +44 -30
  8. tactus/adapters/mcp_manager.py +24 -7
  9. tactus/backends/http_backend.py +2 -2
  10. tactus/backends/pytorch_backend.py +2 -2
  11. tactus/broker/client.py +3 -3
  12. tactus/broker/server.py +17 -5
  13. tactus/core/dsl_stubs.py +3 -3
  14. tactus/core/execution_context.py +32 -27
  15. tactus/core/lua_sandbox.py +42 -34
  16. tactus/core/message_history_manager.py +51 -28
  17. tactus/core/output_validator.py +65 -51
  18. tactus/core/registry.py +29 -29
  19. tactus/core/runtime.py +69 -61
  20. tactus/dspy/broker_lm.py +13 -7
  21. tactus/dspy/config.py +7 -4
  22. tactus/ide/server.py +63 -33
  23. tactus/primitives/host.py +19 -16
  24. tactus/primitives/message_history.py +11 -14
  25. tactus/primitives/model.py +1 -1
  26. tactus/primitives/procedure.py +11 -8
  27. tactus/primitives/session.py +9 -9
  28. tactus/primitives/state.py +2 -2
  29. tactus/primitives/tool_handle.py +27 -24
  30. tactus/sandbox/container_runner.py +11 -6
  31. tactus/testing/context.py +6 -6
  32. tactus/testing/evaluation_runner.py +5 -5
  33. tactus/testing/mock_hitl.py +2 -2
  34. tactus/testing/models.py +2 -0
  35. tactus/testing/steps/builtin.py +2 -2
  36. tactus/testing/test_runner.py +6 -4
  37. tactus/utils/asyncio_helpers.py +2 -1
  38. tactus/utils/safe_libraries.py +2 -2
  39. {tactus-0.35.1.dist-info → tactus-0.37.0.dist-info}/METADATA +11 -5
  40. {tactus-0.35.1.dist-info → tactus-0.37.0.dist-info}/RECORD +43 -43
  41. {tactus-0.35.1.dist-info → tactus-0.37.0.dist-info}/WHEEL +0 -0
  42. {tactus-0.35.1.dist-info → tactus-0.37.0.dist-info}/entry_points.txt +0 -0
  43. {tactus-0.35.1.dist-info → tactus-0.37.0.dist-info}/licenses/LICENSE +0 -0
@@ -6,7 +6,7 @@ Uses pluggable storage and HITL handlers via protocols.
6
6
  """
7
7
 
8
8
  from abc import ABC, abstractmethod
9
- from typing import Any, Callable
9
+ from typing import Any, Callable, Dict, List, Optional
10
10
  from datetime import datetime, timezone
11
11
  import logging
12
12
  import time
@@ -39,7 +39,7 @@ class ExecutionContext(ABC):
39
39
  self,
40
40
  fn: Callable[[], Any],
41
41
  checkpoint_type: str,
42
- source_info: dict[str, Any] | None = None,
42
+ source_info: Optional[Dict[str, Any]] = None,
43
43
  ) -> Any:
44
44
  """
45
45
  Execute fn with position-based checkpointing. On replay, return stored result.
@@ -59,9 +59,9 @@ class ExecutionContext(ABC):
59
59
  self,
60
60
  request_type: str,
61
61
  message: str,
62
- timeout_seconds: int | None,
62
+ timeout_seconds: Optional[int],
63
63
  default_value: Any,
64
- options: list[dict] | None,
64
+ options: Optional[List[dict]],
65
65
  metadata: dict,
66
66
  ) -> HITLResponse:
67
67
  """
@@ -121,7 +121,7 @@ class BaseExecutionContext(ExecutionContext):
121
121
  self,
122
122
  procedure_id: str,
123
123
  storage_backend: StorageBackend,
124
- hitl_handler: HITLHandler | None = None,
124
+ hitl_handler: Optional[HITLHandler] = None,
125
125
  strict_determinism: bool = False,
126
126
  log_handler=None,
127
127
  ):
@@ -145,21 +145,26 @@ class BaseExecutionContext(ExecutionContext):
145
145
  self._inside_checkpoint = False
146
146
 
147
147
  # Run ID tracking for distinguishing between different executions
148
- self.current_run_id: str | None = None
148
+ self.current_run_id: Optional[str] = None
149
149
 
150
150
  # .tac file tracking for accurate source locations
151
- self.current_tac_file: str | None = None
152
- self.current_tac_content: str | None = None
151
+ self.current_tac_file: Optional[str] = None
152
+ self.current_tac_content: Optional[str] = None
153
153
 
154
154
  # Lua sandbox reference for debug.getinfo access
155
- self.lua_sandbox: Any | None = None
155
+ self.lua_sandbox: Optional[Any] = None
156
156
 
157
157
  # Rich metadata for HITL notifications
158
- self.procedure_name: str = procedure_id # Use procedure_id as default name
159
- self.invocation_id: str = str(uuid.uuid4())
160
- self._started_at: datetime = datetime.now(timezone.utc)
161
- self._input_data: Any = None
158
+ self._initialize_run_metadata(procedure_id)
159
+ self._load_and_reset_metadata(procedure_id)
162
160
 
161
+ def _initialize_run_metadata(self, procedure_id: str) -> None:
162
+ self.procedure_name = procedure_id
163
+ self.invocation_id = str(uuid.uuid4())
164
+ self._started_at = datetime.now(timezone.utc)
165
+ self._input_data = None
166
+
167
+ def _load_and_reset_metadata(self, procedure_id: str) -> None:
163
168
  # Load procedure metadata (contains execution_log and replay_index)
164
169
  self.metadata = self.storage.load_procedure_metadata(procedure_id)
165
170
 
@@ -172,7 +177,7 @@ class BaseExecutionContext(ExecutionContext):
172
177
  """Set the run_id for subsequent checkpoints in this execution."""
173
178
  self.current_run_id = run_id
174
179
 
175
- def set_tac_file(self, file_path: str, content: str | None = None) -> None:
180
+ def set_tac_file(self, file_path: str, content: Optional[str] = None) -> None:
176
181
  """
177
182
  Store the currently executing .tac file for accurate source location capture.
178
183
 
@@ -188,7 +193,7 @@ class BaseExecutionContext(ExecutionContext):
188
193
  self.lua_sandbox = lua_sandbox
189
194
 
190
195
  def set_procedure_metadata(
191
- self, procedure_name: str | None = None, input_data: Any = None
196
+ self, procedure_name: Optional[str] = None, input_data: Any = None
192
197
  ) -> None:
193
198
  """
194
199
  Set rich metadata for HITL notifications.
@@ -206,7 +211,7 @@ class BaseExecutionContext(ExecutionContext):
206
211
  self,
207
212
  fn: Callable[[], Any],
208
213
  checkpoint_type: str,
209
- source_info: dict[str, Any] | None = None,
214
+ source_info: Optional[Dict[str, Any]] = None,
210
215
  ) -> Any:
211
216
  """
212
217
  Execute fn with position-based checkpointing and source tracking.
@@ -401,7 +406,7 @@ class BaseExecutionContext(ExecutionContext):
401
406
 
402
407
  def _get_code_context(
403
408
  self, file_path: str, line_number: int, context_lines: int = 3
404
- ) -> str | None:
409
+ ) -> Optional[str]:
405
410
  """Read source file and extract surrounding lines for debugging."""
406
411
  try:
407
412
  with open(file_path, "r") as source_file:
@@ -416,9 +421,9 @@ class BaseExecutionContext(ExecutionContext):
416
421
  self,
417
422
  request_type: str,
418
423
  message: str,
419
- timeout_seconds: int | None,
424
+ timeout_seconds: Optional[int],
420
425
  default_value: Any,
421
- options: list[dict] | None,
426
+ options: Optional[List[dict]],
422
427
  metadata: dict,
423
428
  ) -> HITLResponse:
424
429
  """
@@ -500,7 +505,7 @@ class BaseExecutionContext(ExecutionContext):
500
505
  async_procedure_handles[handle.procedure_id] = handle.to_dict()
501
506
  self.storage.save_procedure_metadata(self.procedure_id, self.metadata)
502
507
 
503
- def get_procedure_handle(self, procedure_id: str) -> dict[str, Any] | None:
508
+ def get_procedure_handle(self, procedure_id: str) -> Optional[Dict[str, Any]]:
504
509
  """
505
510
  Retrieve procedure handle.
506
511
 
@@ -604,7 +609,7 @@ class BaseExecutionContext(ExecutionContext):
604
609
 
605
610
  return run_id
606
611
 
607
- def get_subject(self) -> str | None:
612
+ def get_subject(self) -> Optional[str]:
608
613
  """
609
614
  Return a human-readable subject line for this execution.
610
615
 
@@ -616,7 +621,7 @@ class BaseExecutionContext(ExecutionContext):
616
621
  return f"{self.procedure_name} (checkpoint {checkpoint_position})"
617
622
  return f"Procedure {self.procedure_id} (checkpoint {checkpoint_position})"
618
623
 
619
- def get_started_at(self) -> datetime | None:
624
+ def get_started_at(self) -> Optional[datetime]:
620
625
  """
621
626
  Return when this execution started.
622
627
 
@@ -625,7 +630,7 @@ class BaseExecutionContext(ExecutionContext):
625
630
  """
626
631
  return self._started_at
627
632
 
628
- def get_input_summary(self) -> dict[str, Any] | None:
633
+ def get_input_summary(self) -> Optional[Dict[str, Any]]:
629
634
  """
630
635
  Return a summary of the initial input to this procedure.
631
636
 
@@ -642,7 +647,7 @@ class BaseExecutionContext(ExecutionContext):
642
647
  # Otherwise wrap it in a dict
643
648
  return {"value": self._input_data}
644
649
 
645
- def get_conversation_history(self) -> list[dict] | None:
650
+ def get_conversation_history(self) -> Optional[List[dict]]:
646
651
  """
647
652
  Return conversation history if available.
648
653
 
@@ -653,7 +658,7 @@ class BaseExecutionContext(ExecutionContext):
653
658
  # in future implementations
654
659
  return None
655
660
 
656
- def get_prior_control_interactions(self) -> list[dict] | None:
661
+ def get_prior_control_interactions(self) -> Optional[List[dict]]:
657
662
  """
658
663
  Return list of prior HITL interactions in this execution.
659
664
 
@@ -677,7 +682,7 @@ class BaseExecutionContext(ExecutionContext):
677
682
 
678
683
  return hitl_checkpoints if hitl_checkpoints else None
679
684
 
680
- def get_lua_source_line(self) -> int | None:
685
+ def get_lua_source_line(self) -> Optional[int]:
681
686
  """
682
687
  Get the current source line from Lua debug.getinfo.
683
688
 
@@ -763,7 +768,7 @@ class InMemoryExecutionContext(BaseExecutionContext):
763
768
  and simple CLI workflows that don't need to survive restarts.
764
769
  """
765
770
 
766
- def __init__(self, procedure_id: str, hitl_handler: HITLHandler | None = None):
771
+ def __init__(self, procedure_id: str, hitl_handler: Optional[HITLHandler] = None):
767
772
  """
768
773
  Initialize with in-memory storage.
769
774
 
@@ -241,7 +241,7 @@ class LuaSandbox:
241
241
  """Setup safe global functions and utilities."""
242
242
  # Keep safe standard library functions
243
243
  # (These are already available by default, just documenting them)
244
- safe_functions = {
244
+ safe_global_symbols = {
245
245
  # Math
246
246
  "math", # Math library (will be replaced with safe version if context available)
247
247
  "tonumber", # Convert to number
@@ -264,54 +264,62 @@ class LuaSandbox:
264
264
  }
265
265
 
266
266
  # Just log what's available - no need to explicitly set
267
- logger.debug("Safe Lua functions available: %s", ", ".join(safe_functions))
267
+ logger.debug("Safe Lua functions available: %s", ", ".join(safe_global_symbols))
268
268
 
269
269
  # Replace math and os libraries with safe versions if context available
270
270
  if self.execution_context is not None:
271
- from tactus.utils.safe_libraries import (
272
- create_safe_math_library,
273
- create_safe_os_library,
274
- )
271
+ self._install_context_safe_libraries()
272
+ return # Skip default os.date setup below
275
273
 
276
- def get_context():
277
- return self.execution_context
274
+ self._install_fallback_os_date()
278
275
 
279
- safe_math_dict = create_safe_math_library(get_context, self.strict_determinism)
280
- safe_os_dict = create_safe_os_library(get_context, self.strict_determinism)
276
+ def _install_context_safe_libraries(self) -> None:
277
+ """Install safe math and os libraries based on execution context."""
278
+ from tactus.utils.safe_libraries import (
279
+ create_safe_math_library,
280
+ create_safe_os_library,
281
+ )
281
282
 
282
- safe_math_table = self._dict_to_lua_table(safe_math_dict)
283
- safe_os_table = self._dict_to_lua_table(safe_os_dict)
283
+ def get_execution_context() -> Any:
284
+ return self.execution_context
284
285
 
285
- self.lua.globals()["math"] = safe_math_table
286
- self.lua.globals()["os"] = safe_os_table
286
+ safe_math_dict = create_safe_math_library(get_execution_context, self.strict_determinism)
287
+ safe_os_dict = create_safe_os_library(get_execution_context, self.strict_determinism)
287
288
 
288
- logger.debug("Installed safe math and os libraries with determinism checking")
289
- return # Skip default os.date setup below
289
+ safe_math_table = self._dict_to_lua_table(safe_math_dict)
290
+ safe_os_table = self._dict_to_lua_table(safe_os_dict)
290
291
 
291
- # Add safe subset of os module (only date function for timestamps)
292
- # This is a fallback when no execution context is available (testing/REPL)
293
- from datetime import datetime
292
+ self.lua.globals()["math"] = safe_math_table
293
+ self.lua.globals()["os"] = safe_os_table
294
+
295
+ logger.debug("Installed safe math and os libraries with determinism checking")
296
+
297
+ def _install_fallback_os_date(self) -> None:
298
+ """Install a safe os.date() fallback when no execution context is available."""
299
+ safe_os_table = self._build_fallback_os_table()
300
+ self.lua.globals()["os"] = safe_os_table
301
+ logger.debug("Added safe os.date() function")
294
302
 
295
- def safe_date(format_str=None):
303
+ def _build_fallback_os_table(self) -> Any:
304
+ """Build a Lua os table with a safe date() implementation."""
305
+ from datetime import datetime, timezone
306
+
307
+ def safe_date(format_string: Optional[str] = None) -> str:
296
308
  """Safe implementation of os.date() for timestamp generation."""
297
- now = datetime.utcnow()
298
- if format_str is None:
309
+ now = datetime.now(timezone.utc)
310
+ if format_string is None:
299
311
  # Return default format like Lua's os.date()
300
312
  return now.strftime("%a %b %d %H:%M:%S %Y")
301
- elif format_str == "%Y-%m-%dT%H:%M:%SZ":
313
+ if format_string == "%Y-%m-%dT%H:%M:%SZ":
302
314
  # ISO 8601 format
303
315
  return now.strftime("%Y-%m-%dT%H:%M:%SZ")
304
- else:
305
- # Support Python strftime formats
306
- try:
307
- return now.strftime(format_str)
308
- except Exception: # noqa: E722
309
- return now.strftime("%a %b %d %H:%M:%S %Y")
310
-
311
- # Create safe os table with only date function
312
- safe_os = self.lua.table(date=safe_date)
313
- self.lua.globals()["os"] = safe_os
314
- logger.debug("Added safe os.date() function")
316
+ # Support Python strftime formats
317
+ try:
318
+ return now.strftime(format_string)
319
+ except Exception: # noqa: E722
320
+ return now.strftime("%a %b %d %H:%M:%S %Y")
321
+
322
+ return self.lua.table(date=safe_date)
315
323
 
316
324
  def setup_assignment_interception(self, callback: Any) -> None:
317
325
  """
@@ -8,7 +8,7 @@ Aligned with pydantic-ai's message_history concept.
8
8
  """
9
9
 
10
10
  from datetime import datetime, timezone
11
- from typing import Any, Optional
11
+ from typing import Any, Optional, Tuple
12
12
 
13
13
  try:
14
14
  from pydantic_ai.messages import ModelMessage
@@ -136,37 +136,60 @@ class MessageHistoryManager:
136
136
  print(f"Warning: Filter function failed: {exception}")
137
137
  return messages
138
138
 
139
- # Otherwise it's a tuple (filter_type, filter_arg)
140
- if not isinstance(filter_specification, tuple) or len(filter_specification) < 2:
139
+ filter_name, filter_value = self._parse_filter_spec(filter_specification)
140
+ if filter_name is None:
141
141
  return messages
142
142
 
143
- filter_name = filter_specification[0]
144
- filter_value = filter_specification[1]
145
-
146
- if filter_name == "last_n":
147
- return self._filter_last_n(messages, filter_value)
148
- elif filter_name == "first_n":
149
- return self._filter_first_n(messages, filter_value)
150
- elif filter_name == "token_budget":
151
- return self._filter_by_token_budget(messages, filter_value)
152
- elif filter_name == "head_tokens":
153
- return self._filter_head_tokens(messages, filter_value)
154
- elif filter_name == "tail_tokens":
155
- return self._filter_tail_tokens(messages, filter_value)
156
- elif filter_name == "by_role":
157
- return self._filter_by_role(messages, filter_value)
158
- elif filter_name == "system_prefix":
159
- return self._filter_system_prefix(messages)
160
- elif filter_name == "compose":
161
- # Apply multiple filters in sequence
162
- filtered_messages = messages
163
- for filter_step in filter_value:
164
- filtered_messages = self._apply_filter(filtered_messages, filter_step, context)
165
- return filtered_messages
166
- else:
167
- # Unknown filter type, return unfiltered
143
+ if filter_name == "compose":
144
+ return self._apply_composed_filters(messages, filter_value, context)
145
+
146
+ return self._apply_named_filter(messages, filter_name, filter_value)
147
+
148
+ @staticmethod
149
+ def _parse_filter_spec(filter_specification: Any) -> Tuple[Optional[str], Any]:
150
+ if not isinstance(filter_specification, tuple) or len(filter_specification) < 2:
151
+ return None, None
152
+
153
+ return filter_specification[0], filter_specification[1]
154
+
155
+ def _apply_composed_filters(
156
+ self,
157
+ messages: list[ModelMessage],
158
+ filter_steps: Any,
159
+ context: Optional[Any],
160
+ ) -> list[ModelMessage]:
161
+ filtered_messages = messages
162
+ for filter_step in filter_steps:
163
+ filtered_messages = self._apply_filter(filtered_messages, filter_step, context)
164
+ return filtered_messages
165
+
166
+ def _apply_named_filter(
167
+ self,
168
+ messages: list[ModelMessage],
169
+ filter_name: str,
170
+ filter_value: Any,
171
+ ) -> list[ModelMessage]:
172
+ filter_function = self._filter_dispatch.get(filter_name)
173
+ if filter_function is None:
168
174
  return messages
169
175
 
176
+ if filter_name == "system_prefix":
177
+ return filter_function(messages)
178
+
179
+ return filter_function(messages, filter_value)
180
+
181
+ @property
182
+ def _filter_dispatch(self) -> dict[str, Any]:
183
+ return {
184
+ "last_n": self._filter_last_n,
185
+ "first_n": self._filter_first_n,
186
+ "token_budget": self._filter_by_token_budget,
187
+ "head_tokens": self._filter_head_tokens,
188
+ "tail_tokens": self._filter_tail_tokens,
189
+ "by_role": self._filter_by_role,
190
+ "system_prefix": self._filter_system_prefix,
191
+ }
192
+
170
193
  def _filter_last_n(
171
194
  self,
172
195
  messages: list[ModelMessage],
@@ -6,7 +6,7 @@ Enables type safety and composability for sub-agent workflows.
6
6
  """
7
7
 
8
8
  import logging
9
- from typing import Any, Optional
9
+ from typing import Any, Optional, Tuple
10
10
 
11
11
  logger = logging.getLogger(__name__)
12
12
 
@@ -29,13 +29,14 @@ class OutputValidator:
29
29
  """
30
30
 
31
31
  # Type mapping from YAML to Python
32
- TYPE_MAP = {
32
+ SCHEMA_TYPE_TO_PYTHON_TYPE = {
33
33
  "string": str,
34
34
  "number": (int, float),
35
35
  "boolean": bool,
36
36
  "object": dict,
37
37
  "array": list,
38
38
  }
39
+ TYPE_MAP = SCHEMA_TYPE_TO_PYTHON_TYPE
39
40
 
40
41
  @classmethod
41
42
  def _is_scalar_schema(cls, schema: Any) -> bool:
@@ -77,7 +78,7 @@ class OutputValidator:
77
78
  logger.debug("OutputValidator initialized with %s output fields", field_count)
78
79
 
79
80
  @staticmethod
80
- def _unwrap_result(output: Any) -> tuple[Any, Any | None]:
81
+ def _unwrap_result(output: Any) -> Tuple[Any, Optional[Any]]:
81
82
  from tactus.protocols.result import TactusResult
82
83
 
83
84
  wrapped_result = output if isinstance(output, TactusResult) else None
@@ -91,6 +92,15 @@ class OutputValidator:
91
92
  return dict(output.items())
92
93
  return output
93
94
 
95
+ @staticmethod
96
+ def _wrap_validated_output(
97
+ wrapped_result: Optional[Any],
98
+ validated_payload: Any,
99
+ ) -> Any:
100
+ if wrapped_result is not None:
101
+ return wrapped_result.model_copy(update={"output": validated_payload})
102
+ return validated_payload
103
+
94
104
  def validate(self, output: Any) -> Any:
95
105
  """
96
106
  Validate workflow output against schema.
@@ -108,49 +118,63 @@ class OutputValidator:
108
118
  # while preserving the wrapper (so callers can still access usage/cost/etc.).
109
119
  output, wrapped_result = self._unwrap_result(output)
110
120
 
111
- # If no schema defined, accept any output
112
121
  if not self.schema:
113
- logger.debug("No output schema defined, skipping validation")
114
- validated_payload = self._normalize_unstructured_output(output)
122
+ return self._validate_without_schema(output, wrapped_result)
115
123
 
116
- if wrapped_result is not None:
117
- return wrapped_result.model_copy(update={"output": validated_payload})
118
- return validated_payload
119
-
120
- # Scalar output schema: `output = field.string{...}` etc.
121
124
  if self._is_scalar_schema(self.schema):
122
- # Lua tables are not valid scalar outputs.
123
- if hasattr(output, "items") and not isinstance(output, dict):
124
- output = dict(output.items())
125
-
126
- is_required = self.schema.get("required", False)
127
- if output is None and not is_required:
128
- return None
125
+ return self._validate_scalar_schema(output, wrapped_result)
126
+
127
+ return self._validate_structured_schema(output, wrapped_result)
128
+
129
+ def _validate_without_schema(
130
+ self,
131
+ output: Any,
132
+ wrapped_result: Optional[Any],
133
+ ) -> Any:
134
+ """Accept any output when no schema is defined."""
135
+ logger.debug("No output schema defined, skipping validation")
136
+ validated_payload = self._normalize_unstructured_output(output)
137
+ return self._wrap_validated_output(wrapped_result, validated_payload)
138
+
139
+ def _validate_scalar_schema(
140
+ self,
141
+ output: Any,
142
+ wrapped_result: Optional[Any],
143
+ ) -> Any:
144
+ """Validate scalar outputs (`field.string{}` etc.)."""
145
+ # Lua tables are not valid scalar outputs.
146
+ if hasattr(output, "items") and not isinstance(output, dict):
147
+ output = dict(output.items())
148
+
149
+ is_required = self.schema.get("required", False)
150
+ if output is None and not is_required:
151
+ return None
152
+
153
+ expected_type = self.schema.get("type")
154
+ if expected_type and not self._check_type(output, expected_type):
155
+ raise OutputValidationError(
156
+ f"Output should be {expected_type}, got {type(output).__name__}"
157
+ )
129
158
 
130
- expected_type = self.schema.get("type")
131
- if expected_type and not self._check_type(output, expected_type):
159
+ if "enum" in self.schema and self.schema["enum"]:
160
+ allowed_values = self.schema["enum"]
161
+ if output not in allowed_values:
132
162
  raise OutputValidationError(
133
- f"Output should be {expected_type}, got {type(output).__name__}"
163
+ f"Output has invalid value '{output}'. Allowed values: {allowed_values}"
134
164
  )
135
165
 
136
- if "enum" in self.schema and self.schema["enum"]:
137
- allowed_values = self.schema["enum"]
138
- if output not in allowed_values:
139
- raise OutputValidationError(
140
- f"Output has invalid value '{output}'. Allowed values: {allowed_values}"
141
- )
142
-
143
- validated_payload = output
144
- if wrapped_result is not None:
145
- return wrapped_result.model_copy(update={"output": validated_payload})
146
- return validated_payload
166
+ return self._wrap_validated_output(wrapped_result, output)
147
167
 
148
- # Convert Lua tables to dicts recursively
168
+ def _validate_structured_schema(
169
+ self,
170
+ output: Any,
171
+ wrapped_result: Optional[Any],
172
+ ) -> Any:
173
+ """Validate dict/table outputs against a schema."""
149
174
  if hasattr(output, "items") or isinstance(output, dict):
150
175
  logger.debug("Converting Lua tables to Python dicts recursively")
151
176
  output = self._convert_lua_tables(output)
152
177
 
153
- # Output must be a dict/table
154
178
  if not isinstance(output, dict):
155
179
  raise OutputValidationError(
156
180
  f"Output must be an object/table, got {type(output).__name__}"
@@ -159,7 +183,6 @@ class OutputValidator:
159
183
  validation_errors: list[str] = []
160
184
  validated_output: dict[str, Any] = {}
161
185
 
162
- # Check required fields and validate types
163
186
  for field_name, field_def in self.schema.items():
164
187
  if not isinstance(field_def, dict) or "type" not in field_def:
165
188
  validation_errors.append(
@@ -167,28 +190,23 @@ class OutputValidator:
167
190
  f"Use field.{field_def.get('type', 'string')}{{}} instead."
168
191
  )
169
192
  continue
170
- is_required = bool(field_def.get("required", False))
171
193
 
194
+ is_required = bool(field_def.get("required", False))
172
195
  if is_required and field_name not in output:
173
196
  validation_errors.append(f"Required field '{field_name}' is missing")
174
197
  continue
175
198
 
176
- # Skip validation if field not present and not required
177
199
  if field_name not in output:
178
200
  continue
179
201
 
180
202
  value = output[field_name]
181
-
182
- # Type checking
183
203
  expected_type = field_def.get("type")
184
- if expected_type:
185
- if not self._check_type(value, expected_type):
186
- actual_type = type(value).__name__
187
- validation_errors.append(
188
- f"Field '{field_name}' should be {expected_type}, got {actual_type}"
189
- )
204
+ if expected_type and not self._check_type(value, expected_type):
205
+ actual_type = type(value).__name__
206
+ validation_errors.append(
207
+ f"Field '{field_name}' should be {expected_type}, got {actual_type}"
208
+ )
190
209
 
191
- # Enum validation
192
210
  if "enum" in field_def and field_def["enum"]:
193
211
  allowed_values = field_def["enum"]
194
212
  if value not in allowed_values:
@@ -197,10 +215,8 @@ class OutputValidator:
197
215
  f"Allowed values: {allowed_values}"
198
216
  )
199
217
 
200
- # Add to validated output (only declared fields)
201
218
  validated_output[field_name] = value
202
219
 
203
- # Filter undeclared fields (only return declared fields)
204
220
  for field_name in output:
205
221
  if field_name not in self.schema:
206
222
  logger.debug("Filtering undeclared field '%s' from output", field_name)
@@ -210,9 +226,7 @@ class OutputValidator:
210
226
  raise OutputValidationError(error_message)
211
227
 
212
228
  logger.info("Output validation passed for %s fields", len(validated_output))
213
- if wrapped_result is not None:
214
- return wrapped_result.model_copy(update={"output": validated_output})
215
- return validated_output
229
+ return self._wrap_validated_output(wrapped_result, validated_output)
216
230
 
217
231
  def _check_type(self, value: Any, expected_type: str) -> bool:
218
232
  """