tactus 0.35.0__py3-none-any.whl → 0.36.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tactus/__init__.py CHANGED
@@ -5,7 +5,7 @@ Tactus provides a declarative workflow engine for AI agents with pluggable
5
5
  backends for storage, HITL, and chat recording.
6
6
  """
7
7
 
8
- __version__ = "0.35.0"
8
+ __version__ = "0.36.0"
9
9
 
10
10
  # Core exports
11
11
  from tactus.core.runtime import TactusRuntime
@@ -102,7 +102,7 @@ class HostControlChannel(InProcessChannel):
102
102
 
103
103
  # Start background thread for input collection
104
104
  self._input_thread = threading.Thread(
105
- target=self._input_thread_main,
105
+ target=self._collect_input_in_thread,
106
106
  args=(request,),
107
107
  daemon=True,
108
108
  )
@@ -142,7 +142,7 @@ class HostControlChannel(InProcessChannel):
142
142
  if self._input_thread and self._input_thread.is_alive():
143
143
  self._input_thread.join(timeout=1.0)
144
144
 
145
- def _input_thread_main(self, request: ControlRequest) -> None:
145
+ def _collect_input_in_thread(self, request: ControlRequest) -> None:
146
146
  """
147
147
  Background thread main function.
148
148
 
@@ -154,32 +154,38 @@ class HostControlChannel(InProcessChannel):
154
154
  """
155
155
  try:
156
156
  # Collect input (may block)
157
- response_value = self._prompt_for_input(request)
157
+ user_response_value = self._prompt_for_input(request)
158
158
 
159
159
  # Check if cancelled while waiting
160
160
  if self._cancel_event.is_set():
161
161
  return
162
162
 
163
- if response_value is not None:
164
- # Create response and push to queue
165
- response = ControlResponse(
166
- request_id=request.request_id,
167
- value=response_value,
168
- responded_at=datetime.now(timezone.utc),
169
- timed_out=False,
170
- channel_id=self.channel_id,
171
- )
172
-
173
- # Push thread-safe
174
- if self._event_loop:
175
- self.push_response_threadsafe(response, self._event_loop)
176
- else:
177
- self.push_response(response)
163
+ if user_response_value is not None:
164
+ response = self._build_response(request, user_response_value)
165
+ self._push_response_from_thread(response)
178
166
 
179
167
  except Exception as error:
180
168
  if not self._cancel_event.is_set():
181
169
  logger.error("%s: input error: %s", self.channel_id, error)
182
170
 
171
+ def _input_thread_main(self, request: ControlRequest) -> None:
172
+ self._collect_input_in_thread(request)
173
+
174
+ def _build_response(self, request: ControlRequest, response_value: Any) -> ControlResponse:
175
+ return ControlResponse(
176
+ request_id=request.request_id,
177
+ value=response_value,
178
+ responded_at=datetime.now(timezone.utc),
179
+ timed_out=False,
180
+ channel_id=self.channel_id,
181
+ )
182
+
183
+ def _push_response_from_thread(self, response: ControlResponse) -> None:
184
+ if self._event_loop:
185
+ self.push_response_threadsafe(response, self._event_loop)
186
+ else:
187
+ self.push_response(response)
188
+
183
189
  @abstractmethod
184
190
  def _display_request(self, request: ControlRequest) -> None:
185
191
  """
@@ -248,7 +248,14 @@ class SSEControlChannel(InProcessChannel):
248
248
  """
249
249
  logger.info("%s: received response for %s", self.channel_id, request_id)
250
250
 
251
- response = ControlResponse(
251
+ response = self._build_response(request_id, value)
252
+
253
+ # Push to queue from sync context (Flask thread)
254
+ # Get the running event loop and schedule the put operation
255
+ self._enqueue_response_from_sync_context(request_id, response)
256
+
257
+ def _build_response(self, request_id: str, value: Any) -> ControlResponse:
258
+ return ControlResponse(
252
259
  request_id=request_id,
253
260
  value=value,
254
261
  responded_at=datetime.now(timezone.utc),
@@ -256,15 +263,14 @@ class SSEControlChannel(InProcessChannel):
256
263
  channel_id=self.channel_id,
257
264
  )
258
265
 
259
- # Push to queue from sync context (Flask thread)
260
- # Get the running event loop and schedule the put operation
266
+ def _enqueue_response_from_sync_context(
267
+ self, request_id: str, response: ControlResponse
268
+ ) -> None:
261
269
  try:
262
270
  event_loop = asyncio.get_event_loop()
263
271
  if event_loop.is_running():
264
- # Schedule the coroutine in the running loop
265
272
  asyncio.run_coroutine_threadsafe(self._response_queue.put(response), event_loop)
266
273
  else:
267
- # If no loop is running, use put_nowait (shouldn't happen)
268
274
  self._response_queue.put_nowait(response)
269
275
  except Exception as error:
270
276
  logger.error(
@@ -201,37 +201,12 @@ class ControlLoopHandler:
201
201
  message[:50],
202
202
  )
203
203
 
204
- # Run the async request flow
205
- # Check if we're already in an async context
206
- try:
207
- event_loop = asyncio.get_running_loop()
208
- if event_loop.is_closed():
209
- raise RuntimeError("Running event loop is closed")
204
+ # Run the async request flow.
205
+ running_event_loop = self._get_running_event_loop()
206
+ if running_event_loop is not None:
207
+ return self._run_request_in_running_loop(running_event_loop, request)
210
208
 
211
- # Already in async context - create task and run it
212
- # This shouldn't normally happen since request_interaction is sync
213
- import nest_asyncio
214
-
215
- nest_asyncio.apply()
216
- return event_loop.run_until_complete(self._request_interaction_async(request))
217
- except RuntimeError:
218
- # Not in async context - create a temporary event loop.
219
- previous_event_loop: asyncio.AbstractEventLoop | None = None
220
- try:
221
- previous_event_loop = asyncio.get_event_loop()
222
- except RuntimeError:
223
- previous_event_loop = None
224
- else:
225
- if getattr(previous_event_loop, "is_closed", lambda: False)():
226
- previous_event_loop = None
227
-
228
- event_loop = asyncio.new_event_loop()
229
- try:
230
- asyncio.set_event_loop(event_loop)
231
- return event_loop.run_until_complete(self._request_interaction_async(request))
232
- finally:
233
- event_loop.close()
234
- asyncio.set_event_loop(previous_event_loop)
209
+ return self._run_request_in_new_loop(request)
235
210
 
236
211
  async def _request_interaction_async(self, request: ControlRequest) -> ControlResponse:
237
212
  """
@@ -312,6 +287,45 @@ class ControlLoopHandler:
312
287
 
313
288
  raise ProcedureWaitingForHuman(request.procedure_id, request.request_id)
314
289
 
290
+ def _get_running_event_loop(self) -> Optional[asyncio.AbstractEventLoop]:
291
+ try:
292
+ event_loop = asyncio.get_running_loop()
293
+ except RuntimeError:
294
+ return None
295
+
296
+ if event_loop.is_closed():
297
+ return None
298
+ return event_loop
299
+
300
+ def _run_request_in_running_loop(
301
+ self, event_loop: asyncio.AbstractEventLoop, request: ControlRequest
302
+ ) -> ControlResponse:
303
+ # Already in async context - create task and run it.
304
+ # This shouldn't normally happen since request_interaction is sync.
305
+ import nest_asyncio
306
+
307
+ nest_asyncio.apply()
308
+ return event_loop.run_until_complete(self._request_interaction_async(request))
309
+
310
+ def _run_request_in_new_loop(self, request: ControlRequest) -> ControlResponse:
311
+ # Not in async context - create a temporary event loop.
312
+ previous_event_loop: Optional[asyncio.AbstractEventLoop] = None
313
+ try:
314
+ previous_event_loop = asyncio.get_event_loop()
315
+ except RuntimeError:
316
+ previous_event_loop = None
317
+ else:
318
+ if getattr(previous_event_loop, "is_closed", lambda: False)():
319
+ previous_event_loop = None
320
+
321
+ event_loop = asyncio.new_event_loop()
322
+ try:
323
+ asyncio.set_event_loop(event_loop)
324
+ return event_loop.run_until_complete(self._request_interaction_async(request))
325
+ finally:
326
+ event_loop.close()
327
+ asyncio.set_event_loop(previous_event_loop)
328
+
315
329
  async def _fanout(
316
330
  self,
317
331
  request: ControlRequest,
@@ -155,11 +155,16 @@ class BaseExecutionContext(ExecutionContext):
155
155
  self.lua_sandbox: Any | None = None
156
156
 
157
157
  # Rich metadata for HITL notifications
158
- self.procedure_name: str = procedure_id # Use procedure_id as default name
159
- self.invocation_id: str = str(uuid.uuid4())
160
- self._started_at: datetime = datetime.now(timezone.utc)
161
- self._input_data: Any = None
158
+ self._initialize_run_metadata(procedure_id)
159
+ self._load_and_reset_metadata(procedure_id)
162
160
 
161
+ def _initialize_run_metadata(self, procedure_id: str) -> None:
162
+ self.procedure_name = procedure_id
163
+ self.invocation_id = str(uuid.uuid4())
164
+ self._started_at = datetime.now(timezone.utc)
165
+ self._input_data = None
166
+
167
+ def _load_and_reset_metadata(self, procedure_id: str) -> None:
163
168
  # Load procedure metadata (contains execution_log and replay_index)
164
169
  self.metadata = self.storage.load_procedure_metadata(procedure_id)
165
170
 
@@ -241,7 +241,7 @@ class LuaSandbox:
241
241
  """Setup safe global functions and utilities."""
242
242
  # Keep safe standard library functions
243
243
  # (These are already available by default, just documenting them)
244
- safe_functions = {
244
+ safe_global_symbols = {
245
245
  # Math
246
246
  "math", # Math library (will be replaced with safe version if context available)
247
247
  "tonumber", # Convert to number
@@ -264,54 +264,62 @@ class LuaSandbox:
264
264
  }
265
265
 
266
266
  # Just log what's available - no need to explicitly set
267
- logger.debug("Safe Lua functions available: %s", ", ".join(safe_functions))
267
+ logger.debug("Safe Lua functions available: %s", ", ".join(safe_global_symbols))
268
268
 
269
269
  # Replace math and os libraries with safe versions if context available
270
270
  if self.execution_context is not None:
271
- from tactus.utils.safe_libraries import (
272
- create_safe_math_library,
273
- create_safe_os_library,
274
- )
271
+ self._install_context_safe_libraries()
272
+ return # Skip default os.date setup below
275
273
 
276
- def get_context():
277
- return self.execution_context
274
+ self._install_fallback_os_date()
278
275
 
279
- safe_math_dict = create_safe_math_library(get_context, self.strict_determinism)
280
- safe_os_dict = create_safe_os_library(get_context, self.strict_determinism)
276
+ def _install_context_safe_libraries(self) -> None:
277
+ """Install safe math and os libraries based on execution context."""
278
+ from tactus.utils.safe_libraries import (
279
+ create_safe_math_library,
280
+ create_safe_os_library,
281
+ )
281
282
 
282
- safe_math_table = self._dict_to_lua_table(safe_math_dict)
283
- safe_os_table = self._dict_to_lua_table(safe_os_dict)
283
+ def get_execution_context() -> Any:
284
+ return self.execution_context
284
285
 
285
- self.lua.globals()["math"] = safe_math_table
286
- self.lua.globals()["os"] = safe_os_table
286
+ safe_math_dict = create_safe_math_library(get_execution_context, self.strict_determinism)
287
+ safe_os_dict = create_safe_os_library(get_execution_context, self.strict_determinism)
287
288
 
288
- logger.debug("Installed safe math and os libraries with determinism checking")
289
- return # Skip default os.date setup below
289
+ safe_math_table = self._dict_to_lua_table(safe_math_dict)
290
+ safe_os_table = self._dict_to_lua_table(safe_os_dict)
290
291
 
291
- # Add safe subset of os module (only date function for timestamps)
292
- # This is a fallback when no execution context is available (testing/REPL)
293
- from datetime import datetime
292
+ self.lua.globals()["math"] = safe_math_table
293
+ self.lua.globals()["os"] = safe_os_table
294
+
295
+ logger.debug("Installed safe math and os libraries with determinism checking")
296
+
297
+ def _install_fallback_os_date(self) -> None:
298
+ """Install a safe os.date() fallback when no execution context is available."""
299
+ safe_os_table = self._build_fallback_os_table()
300
+ self.lua.globals()["os"] = safe_os_table
301
+ logger.debug("Added safe os.date() function")
294
302
 
295
- def safe_date(format_str=None):
303
+ def _build_fallback_os_table(self) -> Any:
304
+ """Build a Lua os table with a safe date() implementation."""
305
+ from datetime import datetime, timezone
306
+
307
+ def safe_date(format_string: Optional[str] = None) -> str:
296
308
  """Safe implementation of os.date() for timestamp generation."""
297
- now = datetime.utcnow()
298
- if format_str is None:
309
+ now = datetime.now(timezone.utc)
310
+ if format_string is None:
299
311
  # Return default format like Lua's os.date()
300
312
  return now.strftime("%a %b %d %H:%M:%S %Y")
301
- elif format_str == "%Y-%m-%dT%H:%M:%SZ":
313
+ if format_string == "%Y-%m-%dT%H:%M:%SZ":
302
314
  # ISO 8601 format
303
315
  return now.strftime("%Y-%m-%dT%H:%M:%SZ")
304
- else:
305
- # Support Python strftime formats
306
- try:
307
- return now.strftime(format_str)
308
- except Exception: # noqa: E722
309
- return now.strftime("%a %b %d %H:%M:%S %Y")
310
-
311
- # Create safe os table with only date function
312
- safe_os = self.lua.table(date=safe_date)
313
- self.lua.globals()["os"] = safe_os
314
- logger.debug("Added safe os.date() function")
316
+ # Support Python strftime formats
317
+ try:
318
+ return now.strftime(format_string)
319
+ except Exception: # noqa: E722
320
+ return now.strftime("%a %b %d %H:%M:%S %Y")
321
+
322
+ return self.lua.table(date=safe_date)
315
323
 
316
324
  def setup_assignment_interception(self, callback: Any) -> None:
317
325
  """
@@ -136,37 +136,60 @@ class MessageHistoryManager:
136
136
  print(f"Warning: Filter function failed: {exception}")
137
137
  return messages
138
138
 
139
- # Otherwise it's a tuple (filter_type, filter_arg)
140
- if not isinstance(filter_specification, tuple) or len(filter_specification) < 2:
139
+ filter_name, filter_value = self._parse_filter_spec(filter_specification)
140
+ if filter_name is None:
141
141
  return messages
142
142
 
143
- filter_name = filter_specification[0]
144
- filter_value = filter_specification[1]
145
-
146
- if filter_name == "last_n":
147
- return self._filter_last_n(messages, filter_value)
148
- elif filter_name == "first_n":
149
- return self._filter_first_n(messages, filter_value)
150
- elif filter_name == "token_budget":
151
- return self._filter_by_token_budget(messages, filter_value)
152
- elif filter_name == "head_tokens":
153
- return self._filter_head_tokens(messages, filter_value)
154
- elif filter_name == "tail_tokens":
155
- return self._filter_tail_tokens(messages, filter_value)
156
- elif filter_name == "by_role":
157
- return self._filter_by_role(messages, filter_value)
158
- elif filter_name == "system_prefix":
159
- return self._filter_system_prefix(messages)
160
- elif filter_name == "compose":
161
- # Apply multiple filters in sequence
162
- filtered_messages = messages
163
- for filter_step in filter_value:
164
- filtered_messages = self._apply_filter(filtered_messages, filter_step, context)
165
- return filtered_messages
166
- else:
167
- # Unknown filter type, return unfiltered
143
+ if filter_name == "compose":
144
+ return self._apply_composed_filters(messages, filter_value, context)
145
+
146
+ return self._apply_named_filter(messages, filter_name, filter_value)
147
+
148
+ @staticmethod
149
+ def _parse_filter_spec(filter_specification: Any) -> tuple[str | None, Any]:
150
+ if not isinstance(filter_specification, tuple) or len(filter_specification) < 2:
151
+ return None, None
152
+
153
+ return filter_specification[0], filter_specification[1]
154
+
155
+ def _apply_composed_filters(
156
+ self,
157
+ messages: list[ModelMessage],
158
+ filter_steps: Any,
159
+ context: Optional[Any],
160
+ ) -> list[ModelMessage]:
161
+ filtered_messages = messages
162
+ for filter_step in filter_steps:
163
+ filtered_messages = self._apply_filter(filtered_messages, filter_step, context)
164
+ return filtered_messages
165
+
166
+ def _apply_named_filter(
167
+ self,
168
+ messages: list[ModelMessage],
169
+ filter_name: str,
170
+ filter_value: Any,
171
+ ) -> list[ModelMessage]:
172
+ filter_function = self._filter_dispatch.get(filter_name)
173
+ if filter_function is None:
168
174
  return messages
169
175
 
176
+ if filter_name == "system_prefix":
177
+ return filter_function(messages)
178
+
179
+ return filter_function(messages, filter_value)
180
+
181
+ @property
182
+ def _filter_dispatch(self) -> dict[str, Any]:
183
+ return {
184
+ "last_n": self._filter_last_n,
185
+ "first_n": self._filter_first_n,
186
+ "token_budget": self._filter_by_token_budget,
187
+ "head_tokens": self._filter_head_tokens,
188
+ "tail_tokens": self._filter_tail_tokens,
189
+ "by_role": self._filter_by_role,
190
+ "system_prefix": self._filter_system_prefix,
191
+ }
192
+
170
193
  def _filter_last_n(
171
194
  self,
172
195
  messages: list[ModelMessage],
@@ -29,13 +29,14 @@ class OutputValidator:
29
29
  """
30
30
 
31
31
  # Type mapping from YAML to Python
32
- TYPE_MAP = {
32
+ SCHEMA_TYPE_TO_PYTHON_TYPE = {
33
33
  "string": str,
34
34
  "number": (int, float),
35
35
  "boolean": bool,
36
36
  "object": dict,
37
37
  "array": list,
38
38
  }
39
+ TYPE_MAP = SCHEMA_TYPE_TO_PYTHON_TYPE
39
40
 
40
41
  @classmethod
41
42
  def _is_scalar_schema(cls, schema: Any) -> bool:
@@ -91,6 +92,15 @@ class OutputValidator:
91
92
  return dict(output.items())
92
93
  return output
93
94
 
95
+ @staticmethod
96
+ def _wrap_validated_output(
97
+ wrapped_result: Any | None,
98
+ validated_payload: Any,
99
+ ) -> Any:
100
+ if wrapped_result is not None:
101
+ return wrapped_result.model_copy(update={"output": validated_payload})
102
+ return validated_payload
103
+
94
104
  def validate(self, output: Any) -> Any:
95
105
  """
96
106
  Validate workflow output against schema.
@@ -108,49 +118,63 @@ class OutputValidator:
108
118
  # while preserving the wrapper (so callers can still access usage/cost/etc.).
109
119
  output, wrapped_result = self._unwrap_result(output)
110
120
 
111
- # If no schema defined, accept any output
112
121
  if not self.schema:
113
- logger.debug("No output schema defined, skipping validation")
114
- validated_payload = self._normalize_unstructured_output(output)
122
+ return self._validate_without_schema(output, wrapped_result)
115
123
 
116
- if wrapped_result is not None:
117
- return wrapped_result.model_copy(update={"output": validated_payload})
118
- return validated_payload
119
-
120
- # Scalar output schema: `output = field.string{...}` etc.
121
124
  if self._is_scalar_schema(self.schema):
122
- # Lua tables are not valid scalar outputs.
123
- if hasattr(output, "items") and not isinstance(output, dict):
124
- output = dict(output.items())
125
-
126
- is_required = self.schema.get("required", False)
127
- if output is None and not is_required:
128
- return None
125
+ return self._validate_scalar_schema(output, wrapped_result)
126
+
127
+ return self._validate_structured_schema(output, wrapped_result)
128
+
129
+ def _validate_without_schema(
130
+ self,
131
+ output: Any,
132
+ wrapped_result: Any | None,
133
+ ) -> Any:
134
+ """Accept any output when no schema is defined."""
135
+ logger.debug("No output schema defined, skipping validation")
136
+ validated_payload = self._normalize_unstructured_output(output)
137
+ return self._wrap_validated_output(wrapped_result, validated_payload)
138
+
139
+ def _validate_scalar_schema(
140
+ self,
141
+ output: Any,
142
+ wrapped_result: Any | None,
143
+ ) -> Any:
144
+ """Validate scalar outputs (`field.string{}` etc.)."""
145
+ # Lua tables are not valid scalar outputs.
146
+ if hasattr(output, "items") and not isinstance(output, dict):
147
+ output = dict(output.items())
148
+
149
+ is_required = self.schema.get("required", False)
150
+ if output is None and not is_required:
151
+ return None
152
+
153
+ expected_type = self.schema.get("type")
154
+ if expected_type and not self._check_type(output, expected_type):
155
+ raise OutputValidationError(
156
+ f"Output should be {expected_type}, got {type(output).__name__}"
157
+ )
129
158
 
130
- expected_type = self.schema.get("type")
131
- if expected_type and not self._check_type(output, expected_type):
159
+ if "enum" in self.schema and self.schema["enum"]:
160
+ allowed_values = self.schema["enum"]
161
+ if output not in allowed_values:
132
162
  raise OutputValidationError(
133
- f"Output should be {expected_type}, got {type(output).__name__}"
163
+ f"Output has invalid value '{output}'. Allowed values: {allowed_values}"
134
164
  )
135
165
 
136
- if "enum" in self.schema and self.schema["enum"]:
137
- allowed_values = self.schema["enum"]
138
- if output not in allowed_values:
139
- raise OutputValidationError(
140
- f"Output has invalid value '{output}'. Allowed values: {allowed_values}"
141
- )
142
-
143
- validated_payload = output
144
- if wrapped_result is not None:
145
- return wrapped_result.model_copy(update={"output": validated_payload})
146
- return validated_payload
166
+ return self._wrap_validated_output(wrapped_result, output)
147
167
 
148
- # Convert Lua tables to dicts recursively
168
+ def _validate_structured_schema(
169
+ self,
170
+ output: Any,
171
+ wrapped_result: Any | None,
172
+ ) -> Any:
173
+ """Validate dict/table outputs against a schema."""
149
174
  if hasattr(output, "items") or isinstance(output, dict):
150
175
  logger.debug("Converting Lua tables to Python dicts recursively")
151
176
  output = self._convert_lua_tables(output)
152
177
 
153
- # Output must be a dict/table
154
178
  if not isinstance(output, dict):
155
179
  raise OutputValidationError(
156
180
  f"Output must be an object/table, got {type(output).__name__}"
@@ -159,7 +183,6 @@ class OutputValidator:
159
183
  validation_errors: list[str] = []
160
184
  validated_output: dict[str, Any] = {}
161
185
 
162
- # Check required fields and validate types
163
186
  for field_name, field_def in self.schema.items():
164
187
  if not isinstance(field_def, dict) or "type" not in field_def:
165
188
  validation_errors.append(
@@ -167,28 +190,23 @@ class OutputValidator:
167
190
  f"Use field.{field_def.get('type', 'string')}{{}} instead."
168
191
  )
169
192
  continue
170
- is_required = bool(field_def.get("required", False))
171
193
 
194
+ is_required = bool(field_def.get("required", False))
172
195
  if is_required and field_name not in output:
173
196
  validation_errors.append(f"Required field '{field_name}' is missing")
174
197
  continue
175
198
 
176
- # Skip validation if field not present and not required
177
199
  if field_name not in output:
178
200
  continue
179
201
 
180
202
  value = output[field_name]
181
-
182
- # Type checking
183
203
  expected_type = field_def.get("type")
184
- if expected_type:
185
- if not self._check_type(value, expected_type):
186
- actual_type = type(value).__name__
187
- validation_errors.append(
188
- f"Field '{field_name}' should be {expected_type}, got {actual_type}"
189
- )
204
+ if expected_type and not self._check_type(value, expected_type):
205
+ actual_type = type(value).__name__
206
+ validation_errors.append(
207
+ f"Field '{field_name}' should be {expected_type}, got {actual_type}"
208
+ )
190
209
 
191
- # Enum validation
192
210
  if "enum" in field_def and field_def["enum"]:
193
211
  allowed_values = field_def["enum"]
194
212
  if value not in allowed_values:
@@ -197,10 +215,8 @@ class OutputValidator:
197
215
  f"Allowed values: {allowed_values}"
198
216
  )
199
217
 
200
- # Add to validated output (only declared fields)
201
218
  validated_output[field_name] = value
202
219
 
203
- # Filter undeclared fields (only return declared fields)
204
220
  for field_name in output:
205
221
  if field_name not in self.schema:
206
222
  logger.debug("Filtering undeclared field '%s' from output", field_name)
@@ -210,9 +226,7 @@ class OutputValidator:
210
226
  raise OutputValidationError(error_message)
211
227
 
212
228
  logger.info("Output validation passed for %s fields", len(validated_output))
213
- if wrapped_result is not None:
214
- return wrapped_result.model_copy(update={"output": validated_output})
215
- return validated_output
229
+ return self._wrap_validated_output(wrapped_result, validated_output)
216
230
 
217
231
  def _check_type(self, value: Any, expected_type: str) -> bool:
218
232
  """