sondera_harness-0.6.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (77)
  1. sondera/__init__.py +111 -0
  2. sondera/__main__.py +4 -0
  3. sondera/adk/__init__.py +3 -0
  4. sondera/adk/analyze.py +222 -0
  5. sondera/adk/plugin.py +387 -0
  6. sondera/cli.py +22 -0
  7. sondera/exceptions.py +167 -0
  8. sondera/harness/__init__.py +6 -0
  9. sondera/harness/abc.py +102 -0
  10. sondera/harness/cedar/__init__.py +0 -0
  11. sondera/harness/cedar/harness.py +363 -0
  12. sondera/harness/cedar/schema.py +225 -0
  13. sondera/harness/sondera/__init__.py +0 -0
  14. sondera/harness/sondera/_grpc.py +354 -0
  15. sondera/harness/sondera/harness.py +890 -0
  16. sondera/langgraph/__init__.py +15 -0
  17. sondera/langgraph/analyze.py +543 -0
  18. sondera/langgraph/exceptions.py +19 -0
  19. sondera/langgraph/graph.py +210 -0
  20. sondera/langgraph/middleware.py +454 -0
  21. sondera/proto/google/protobuf/any_pb2.py +37 -0
  22. sondera/proto/google/protobuf/any_pb2.pyi +14 -0
  23. sondera/proto/google/protobuf/any_pb2_grpc.py +24 -0
  24. sondera/proto/google/protobuf/duration_pb2.py +37 -0
  25. sondera/proto/google/protobuf/duration_pb2.pyi +14 -0
  26. sondera/proto/google/protobuf/duration_pb2_grpc.py +24 -0
  27. sondera/proto/google/protobuf/empty_pb2.py +37 -0
  28. sondera/proto/google/protobuf/empty_pb2.pyi +9 -0
  29. sondera/proto/google/protobuf/empty_pb2_grpc.py +24 -0
  30. sondera/proto/google/protobuf/struct_pb2.py +47 -0
  31. sondera/proto/google/protobuf/struct_pb2.pyi +49 -0
  32. sondera/proto/google/protobuf/struct_pb2_grpc.py +24 -0
  33. sondera/proto/google/protobuf/timestamp_pb2.py +37 -0
  34. sondera/proto/google/protobuf/timestamp_pb2.pyi +14 -0
  35. sondera/proto/google/protobuf/timestamp_pb2_grpc.py +24 -0
  36. sondera/proto/google/protobuf/wrappers_pb2.py +53 -0
  37. sondera/proto/google/protobuf/wrappers_pb2.pyi +59 -0
  38. sondera/proto/google/protobuf/wrappers_pb2_grpc.py +24 -0
  39. sondera/proto/sondera/__init__.py +0 -0
  40. sondera/proto/sondera/core/__init__.py +0 -0
  41. sondera/proto/sondera/core/v1/__init__.py +0 -0
  42. sondera/proto/sondera/core/v1/primitives_pb2.py +88 -0
  43. sondera/proto/sondera/core/v1/primitives_pb2.pyi +259 -0
  44. sondera/proto/sondera/core/v1/primitives_pb2_grpc.py +24 -0
  45. sondera/proto/sondera/harness/__init__.py +0 -0
  46. sondera/proto/sondera/harness/v1/__init__.py +0 -0
  47. sondera/proto/sondera/harness/v1/harness_pb2.py +81 -0
  48. sondera/proto/sondera/harness/v1/harness_pb2.pyi +192 -0
  49. sondera/proto/sondera/harness/v1/harness_pb2_grpc.py +498 -0
  50. sondera/py.typed +0 -0
  51. sondera/settings.py +20 -0
  52. sondera/strands/__init__.py +5 -0
  53. sondera/strands/analyze.py +244 -0
  54. sondera/strands/harness.py +333 -0
  55. sondera/tui/__init__.py +0 -0
  56. sondera/tui/app.py +309 -0
  57. sondera/tui/screens/__init__.py +5 -0
  58. sondera/tui/screens/adjudication.py +184 -0
  59. sondera/tui/screens/agent.py +158 -0
  60. sondera/tui/screens/trajectory.py +158 -0
  61. sondera/tui/widgets/__init__.py +23 -0
  62. sondera/tui/widgets/agent_card.py +94 -0
  63. sondera/tui/widgets/agent_list.py +73 -0
  64. sondera/tui/widgets/recent_adjudications.py +52 -0
  65. sondera/tui/widgets/recent_trajectories.py +54 -0
  66. sondera/tui/widgets/summary.py +57 -0
  67. sondera/tui/widgets/tool_card.py +33 -0
  68. sondera/tui/widgets/violation_panel.py +72 -0
  69. sondera/tui/widgets/violations_list.py +78 -0
  70. sondera/tui/widgets/violations_summary.py +104 -0
  71. sondera/types.py +346 -0
  72. sondera_harness-0.6.0.dist-info/METADATA +323 -0
  73. sondera_harness-0.6.0.dist-info/RECORD +77 -0
  74. sondera_harness-0.6.0.dist-info/WHEEL +5 -0
  75. sondera_harness-0.6.0.dist-info/entry_points.txt +2 -0
  76. sondera_harness-0.6.0.dist-info/licenses/LICENSE +21 -0
  77. sondera_harness-0.6.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,210 @@
+ """LangGraph state graph wrapper with Sondera trajectory tracking."""
+
+ from __future__ import annotations
+
+ import logging
+ from typing import Any
+
+ from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
+
+ from sondera.harness import Harness
+ from sondera.types import (
+     Adjudication,
+     Content,
+     Decision,
+     PromptContent,
+     Role,
+     Stage,
+     ToolResponseContent,
+ )
+
+ from .exceptions import GuardrailViolationError
+
+ LOGGER = logging.getLogger(__name__)
+
+
+ class SonderaGraph:
+     """Wrapper for LangGraph compiled graphs that tracks node executions.
+
+     Uses LangGraph's streaming API (astream) to intercept each node execution
+     and record it as a trajectory step. This enables policy enforcement and
+     observability for state-based workflows.
+
+     Example:
+         ```python
+         from langgraph.graph import StateGraph, END
+         from sondera.langgraph import SonderaGraphWrapper
+         from sondera.harness import Harness
+
+         # Build your graph
+         graph = StateGraph(MyState)
+         graph.add_node("node1", my_function)
+         graph.add_edge("node1", END)
+         compiled = graph.compile()
+
+         # Create harness
+         harness = Harness(
+             sondera_harness_endpoint="localhost:50051",
+             agent=agent,
+         )
+
+         # Wrap with Sondera
+         wrapped = SonderaGraphWrapper(compiled, harness=harness)
+
+         # Execute - node executions will be tracked
+         result = await wrapped.ainvoke(initial_state)
+         ```
+     """
+
+     def __init__(
+         self,
+         compiled_graph: Any,
+         *,
+         harness: Harness,
+         track_nodes: bool = True,
+         enforce: bool = True,
+     ) -> None:
+         """Initialize the graph wrapper.
+
+         Args:
+             compiled_graph: The LangGraph compiled graph to wrap
+             harness: Sondera harness for policy enforcement
+             track_nodes: Whether to track node executions (default: True)
+             enforce: Whether to enforce policy decisions (default: True)
+         """
+         self._graph = compiled_graph
+         self._harness = harness
+         self._track_nodes = track_nodes
+         self._enforce = enforce
+         self._logger = LOGGER
+
+     async def ainvoke(
+         self,
+         input: dict[str, Any],
+         config: dict[str, Any] | None = None,
+     ) -> dict[str, Any]:
+         """Execute the graph with trajectory tracking via streaming.
+
+         Args:
+             input: Initial state for the graph
+             config: Optional configuration dict
+
+         Returns:
+             Final state after graph execution
+         """
+         # Initialize trajectory
+         await self._harness.initialize(agent=self._harness._agent)
+
+         # Record initial user message if present
+         if "messages" in input and input["messages"]:
+             initial_msg = input["messages"][0]
+             if isinstance(initial_msg, HumanMessage | BaseMessage):
+                 await self._record_step(
+                     content=PromptContent(text=_message_to_text(initial_msg)),
+                     role=Role.USER,
+                     stage=Stage.PRE_MODEL,
+                     node="user_input",
+                 )
+
+         # Use streaming to track each node execution
+         final_state = dict(input) if isinstance(input, dict) else {}
+         if self._track_nodes:
+             async for chunk in self._graph.astream(input, config=config):
+                 # chunk is {node_name: node_state_output}
+                 for node_name, node_state in chunk.items():
+                     await self._record_node_execution(
+                         node_name=node_name,
+                         node_state=node_state,
+                     )
+                     # Merge node updates into accumulated state
+                     if isinstance(node_state, dict):
+                         final_state.update(node_state)
+                     else:
+                         final_state = node_state
+         else:
+             final_state = await self._graph.ainvoke(input, config=config)
+
+         # Record final output if present
+         if final_state and "messages" in final_state and final_state["messages"]:
+             final_msg = final_state["messages"][-1]
+             if isinstance(final_msg, AIMessage | BaseMessage):
+                 await self._record_step(
+                     content=PromptContent(text=_message_to_text(final_msg)),
+                     role=Role.MODEL,
+                     stage=Stage.POST_MODEL,
+                     node="final_output",
+                 )
+
+         # Finalize trajectory
+         await self._harness.finalize()
+
+         return final_state
+
+     async def _record_node_execution(
+         self,
+         node_name: str,
+         node_state: dict[str, Any],
+     ) -> None:
+         """Record a node execution as a trajectory step."""
+         # Extract meaningful content from the node's state update
+         if "messages" in node_state and node_state["messages"]:
+             last_msg = node_state["messages"][-1]
+             if isinstance(last_msg, BaseMessage):
+                 content = _message_to_text(last_msg)
+             else:
+                 content = str(last_msg)
+         else:
+             # For non-message nodes, summarize the state change
+             content = f"Node '{node_name}' updated state"
+
+         await self._record_step(
+             content=ToolResponseContent(tool_id=node_name, response=content),
+             role=Role.TOOL,  # Nodes are like tool executions
+             stage=Stage.POST_TOOL,
+             node=node_name,
+         )
+
+     async def _record_step(
+         self,
+         *,
+         content: Content,
+         role: Role,
+         stage: Stage,
+         node: str,
+     ) -> Adjudication:
+         """Record and adjudicate a trajectory step."""
+         # Adjudicate with policy engine via harness
+         adjudication = await self._harness.adjudicate(
+             stage=stage,
+             role=role,
+             content=content,
+         )
+
+         # Enforce DENY decisions if enabled
+         if adjudication.decision is Decision.DENY and self._enforce:
+             raise GuardrailViolationError(
+                 stage=stage,
+                 node=node,
+                 reason=adjudication.reason,
+             )
+
+         return adjudication
+
+     def invoke(
+         self, input: dict[str, Any], config: dict[str, Any] | None = None
+     ) -> dict[str, Any]:
+         """Synchronous version of ainvoke (not recommended for production)."""
+         import asyncio
+
+         return asyncio.run(self.ainvoke(input, config))
+
+
+ def _message_to_text(message: BaseMessage | Any) -> str:
+     """Extract text content from a message."""
+     if isinstance(message, BaseMessage):
+         if isinstance(message.content, str):
+             return message.content
+         return str(message.content)
+     if isinstance(message, dict) and "content" in message:
+         return str(message["content"])
+     return str(message)
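When `enforce=True`, `SonderaGraph._record_step` raises `GuardrailViolationError` (from `sondera.langgraph.exceptions`) on a `Decision.DENY` adjudication, so callers typically wrap `ainvoke` in a try/except. The sketch below is a minimal, hypothetical illustration, not part of the package: `MyState`, the `respond` node, the endpoint, and the `Agent` values are placeholders, and the `Harness` constructor arguments simply mirror the class docstring above.

```python
# Hypothetical sketch: surfacing a denied step from SonderaGraph.
# MyState, respond, the endpoint, and the Agent fields are illustrative placeholders.
import asyncio
from typing import TypedDict

from langgraph.graph import END, StateGraph

from sondera.harness import Harness
from sondera.langgraph.exceptions import GuardrailViolationError
from sondera.langgraph.graph import SonderaGraph
from sondera.types import Agent


class MyState(TypedDict):
    messages: list


def respond(state: MyState) -> MyState:
    # Placeholder node: append a canned reply to the message list.
    return {"messages": state["messages"] + ["ok"]}


async def main() -> None:
    graph = StateGraph(MyState)
    graph.add_node("respond", respond)
    graph.set_entry_point("respond")
    graph.add_edge("respond", END)
    compiled = graph.compile()

    # Harness arguments follow the module docstring above; adjust for your deployment.
    harness = Harness(
        sondera_harness_endpoint="localhost:50051",
        agent=Agent(
            id="my-agent",
            provider_id="langchain",
            name="My Agent",
            description="An agent with Sondera governance",
            instruction="Be helpful",
            tools=[],
        ),
    )
    wrapped = SonderaGraph(compiled, harness=harness, enforce=True)

    try:
        result = await wrapped.ainvoke({"messages": ["hello"]})
        print(result)
    except GuardrailViolationError as exc:
        # Raised when an adjudication returns Decision.DENY and enforce=True.
        print(f"Blocked by policy: {exc}")


asyncio.run(main())
```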
@@ -0,0 +1,454 @@
+ """Sondera Harness Middleware for LangGraph."""
+
+ from __future__ import annotations
+
+ import logging
+ from collections.abc import Awaitable, Callable
+ from enum import Enum
+ from typing import Any
+
+ from langchain.agents import AgentState
+ from langchain.agents.middleware import (
+     AgentMiddleware,
+     ModelRequest,
+     ModelResponse,
+     hook_config,
+ )
+ from langchain.messages import ToolMessage
+ from langchain.tools.tool_node import ToolCallRequest
+ from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
+ from langgraph.runtime import Runtime
+ from langgraph.types import Command
+
+ try:
+     from langgraph.graph import END
+ except ImportError:
+     # Fallback for older versions
+     END = "__end__"
+
+ from sondera.harness import Harness
+ from sondera.types import (
+     PromptContent,
+     Role,
+     Stage,
+     ToolRequestContent,
+     ToolResponseContent,
+ )
+
+ _LOGGER = logging.getLogger(__name__)
+
+
+ class Strategy(str, Enum):
+     """Strategy for handling policy violations."""
+
+     BLOCK = "block"
+     """Jump to end immediately when a policy violation is detected."""
+     STEER = "steer"
+     """Allow continuation with modified content when a policy violation is detected."""
+
+
+ class State(AgentState):
+     """Agent state with additional Sondera Harness-related fields."""
+
+     trajectory_id: str | None
+
+
+ class SonderaHarnessMiddleware(AgentMiddleware[State]):
+     """LangGraph middleware that integrates with Sondera Harness for policy enforcement.
+
+     This middleware intercepts agent execution at key points (before/after agent,
+     model calls, tool calls) and delegates policy evaluation to the Sondera Harness
+     Service. Based on the adjudication result, it can either allow execution to
+     proceed, block and jump to end, or steer the response with modified content.
+
+     Example:
+         ```python
+         from sondera.langgraph.middleware import SonderaHarnessMiddleware, Strategy
+         from sondera.harness import RemoteHarness
+         from sondera.types import Agent
+         from langchain.agents import create_agent
+
+         # Create a harness instance
+         harness = RemoteHarness(
+             endpoint="localhost:50051",
+             organization_id="my-tenant",
+             agent=Agent(
+                 id="my-agent",
+                 provider_id="langchain",
+                 name="My Agent",
+                 description="An agent with Sondera governance",
+                 instruction="Be helpful",
+                 tools=[],
+             ),
+         )
+
+         # Create middleware with the harness
+         middleware = SonderaHarnessMiddleware(
+             harness=harness,
+             strategy=Strategy.BLOCK,
+         )
+
+         agent = create_agent(
+             model="gpt-4o",
+             tools=[...],
+             middleware=[middleware],
+         )
+         ```
+     """
+
+     state_schema = State
+
+     def __init__(
+         self,
+         harness: Harness,
+         *,
+         strategy: Strategy = Strategy.BLOCK,
+         logger: logging.Logger | None = None,
+     ) -> None:
+         """Initialize the Sondera Harness Middleware.
+
+         Args:
+             harness: The Sondera Harness instance to use
+             strategy: How to handle policy violations (BLOCK or STEER)
+         """
+         self._harness = harness
+         self._strategy = strategy
+         self._log = logger or _LOGGER
+         super().__init__()
+
+     @hook_config(can_jump_to=["end"])
+     async def abefore_agent(
+         self, state: State, runtime: Runtime
+     ) -> dict[str, Any] | None:
+         """Execute before agent starts.
+
+         Initializes the trajectory and evaluates the user's input message
+         against policies before the agent begins processing.
+
+         Args:
+             state: The current agent state containing messages
+             runtime: The LangGraph runtime
+
+         Returns:
+             None to continue, or a dict with state updates (including optional jump_to)
+         """
+         trajectory_id = state.get("trajectory_id")
+         updates = {}
+
+         if trajectory_id and trajectory_id.strip():  # Check for non-empty string
+             # Resume an existing trajectory.
+             await self._harness.resume(trajectory_id)
+             self._log.debug(
+                 f"[SonderaHarness] Resumed trajectory: {self._harness.trajectory_id}"
+             )
+         else:
+             # Initialize a new trajectory if needed.
+             if self._harness.trajectory_id is None:
+                 await self._harness.initialize()
+             updates["trajectory_id"] = self._harness.trajectory_id
+             self._log.debug(
+                 f"[SonderaHarness] Initialized trajectory: {self._harness.trajectory_id}"
+             )
+
+         # Extract user message from state
+         user_message = _extract_last_user_message(state)
+         if user_message is None:
+             self._log.debug(
+                 "[SonderaHarness] No user message found in state, skipping pre-agent check"
+             )
+             # Still return trajectory_id if we just created one
+             return updates if updates else None
+
+         content = _message_to_text(user_message)
+         self._log.debug(
+             f"[SonderaHarness] Evaluating user input for trajectory {self._harness.trajectory_id}"
+         )
+
+         adjudication = await self._harness.adjudicate(
+             Stage.PRE_MODEL,
+             Role.USER,
+             PromptContent(text=content),
+         )
+         self._log.info(
+             f"[SonderaHarness] Before Agent Adjudication for trajectory {self._harness.trajectory_id}"
+         )
+
+         if adjudication.is_denied:
+             self._log.warning(
+                 f"[SonderaHarness] Policy violation detected (strategy={self._strategy.value}): "
+                 f"{adjudication.reason}"
+             )
+             if self._strategy == Strategy.BLOCK:
+                 # BLOCK: Jump to end immediately with policy message
+                 return {
+                     "messages": [AIMessage(content=adjudication.reason)],
+                     "jump_to": "end",
+                     **updates,  # Include trajectory_id in the response
+                 }
+             # STEER: Replace user message with policy guidance and continue
+             return {
+                 "messages": [
+                     AIMessage(
+                         content=f"Policy violation in user message: {adjudication.reason}"
+                     )
+                 ],
+                 **updates,  # Include trajectory_id in the response
+             }
+
+         # Return trajectory_id if we just created one
+         return updates if updates else None
+
+     async def awrap_model_call(
+         self,
+         request: ModelRequest,
+         handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
+     ) -> ModelResponse:
+         """Wrap model calls with policy evaluation.
+
+         Evaluates the model request before calling the model, then evaluates
+         the model's response after it returns.
+
+         Args:
+             request: The model request containing messages and configuration
+             handler: The handler function to call the actual model
+
+         Returns:
+             The model response, potentially modified based on policy
+         """
+         if isinstance(request.messages[-1], AIMessage):
+             # Last message is an AIMessage, so we need to adjudicate it. HumanMessage was checked in abefore_agent.
+             _LOGGER.debug(
+                 f"[SonderaHarness] Pre-model check for trajectory {self._harness.trajectory_id} {request.messages}"
+             )
+             pre_adjudication = await self._harness.adjudicate(
+                 Stage.PRE_MODEL,
+                 Role.MODEL,
+                 PromptContent(text=_message_to_text(request.messages[-1])),
+             )
+
+             if pre_adjudication.is_denied:
+                 _LOGGER.warning(
+                     f"[SonderaHarness] Pre-model policy violation (strategy={self._strategy.value}): "
+                     f"{pre_adjudication.reason}"
+                 )
+                 message = AIMessage(
+                     content=f"Replaced message due to policy violation: {pre_adjudication.reason}"
+                 )
+                 if self._strategy == Strategy.STEER:
+                     # STEER: Replace the last message with the policy message
+                     request.messages[-1] = message
+                 else:
+                     # BLOCK: Return early with the policy message
+                     return ModelResponse(
+                         result=[message],
+                         structured_response=None,
+                     )
+
+         # Call the actual model
+         response: ModelResponse = await handler(request)
+
+         # Post-model check on each AI message in the response
+         sanitized_messages: list[BaseMessage] = []
+         for message in response.result:
+             if isinstance(message, AIMessage):
+                 post_adjudication = await self._harness.adjudicate(
+                     Stage.POST_MODEL,
+                     Role.MODEL,
+                     PromptContent(text=message.text),
+                 )
+                 self._log.info(
+                     f"[SonderaHarness] Post-model Adjudication for trajectory {self._harness.trajectory_id}"
+                 )
+                 if post_adjudication.is_denied:
+                     self._log.warning(
+                         f"[SonderaHarness] Post-model policy violation (strategy={self._strategy.value}): "
+                         f"{post_adjudication.reason}"
+                     )
+                     message = AIMessage(
+                         content=f"Replaced message due to policy violation: {post_adjudication.reason}"
+                     )
+                     if self._strategy == Strategy.STEER:
+                         # STEER: Replace the message with the policy message
+                         sanitized_messages.append(message)
+                     else:
+                         # BLOCK: Return early with the policy message
+                         return ModelResponse(
+                             result=[message],
+                             structured_response=response.structured_response,
+                         )
+                 else:
+                     sanitized_messages.append(message)
+             else:
+                 self._log.debug(
+                     f"[SonderaHarness] Non-AIMessage in response: {message} in trajectory {self._harness.trajectory_id}"
+                 )
+                 sanitized_messages.append(message)
+
+         return ModelResponse(
+             result=sanitized_messages,
+             structured_response=response.structured_response,
+         )
+
+     async def awrap_tool_call(
+         self,
+         request: ToolCallRequest,
+         handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
+     ) -> ToolMessage | Command:
+         """Wrap tool calls with policy evaluation.
+
+         Evaluates the tool request before execution, then evaluates
+         the tool's response after it returns.
+
+         Args:
+             request: The tool call request containing tool name and arguments
+             handler: The handler function to execute the actual tool
+
+         Returns:
+             The tool response, potentially modified based on policy
+         """
+         tool_name = request.tool_call.get("name", "unknown_tool")
+         tool_args = request.tool_call.get("args", {})
+         tool_call_id = request.tool_call.get("id", "")
+
+         # Pre-tool check
+         self._log.debug(
+             f"[SonderaHarness] Pre-tool check for {tool_name} in trajectory {self._harness.trajectory_id}"
+         )
+         pre_adjudication = await self._harness.adjudicate(
+             Stage.PRE_TOOL,
+             Role.TOOL,
+             ToolRequestContent(tool_id=tool_name, args=tool_args),
+         )
+
+         self._log.info(
+             f"[SonderaHarness] Before Tool Adjudication for trajectory {self._harness.trajectory_id}"
+         )
+
+         if pre_adjudication.is_denied:
+             self._log.warning(
+                 f"[SonderaHarness] Pre-tool policy violation for {tool_name} "
+                 f"(strategy={self._strategy.value}): {pre_adjudication.reason}"
+             )
+             if self._strategy == Strategy.BLOCK:
+                 # BLOCK: Jump to end using Command
+                 return Command(
+                     goto=END,
+                     update={
+                         "messages": [
+                             ToolMessage(
+                                 content=f"Tool execution was blocked. {pre_adjudication.reason}",
+                                 tool_call_id=tool_call_id,
+                                 name=tool_name,
+                             )
+                         ]
+                     },
+                 )
+             # STEER: Return tool message with policy violation instead of allowing execution
+             return ToolMessage(
+                 content=f"Tool execution modified due to policy concern: {pre_adjudication.reason}",
+                 tool_call_id=tool_call_id,
+                 name=tool_name,
+             )
+
+         # Execute the actual tool
+         result = await handler(request)
+
+         # Post-tool check
+         if isinstance(result, ToolMessage):
+             output_text = _tool_message_to_text(result)
+
+             post_adjudication = await self._harness.adjudicate(
+                 Stage.POST_TOOL,
+                 Role.TOOL,
+                 ToolResponseContent(tool_id=tool_name, response=output_text),
+             )
+
+             self._log.info(
+                 f"[SonderaHarness] After Tool Adjudication for trajectory {self._harness.trajectory_id}"
+             )
+
+             if post_adjudication.is_denied:
+                 self._log.warning(
+                     f"[SonderaHarness] Post-tool policy violation for {tool_name} "
+                     f"(strategy={self._strategy.value}): {post_adjudication.reason}"
+                 )
+                 if self._strategy == Strategy.BLOCK:
+                     # BLOCK: Jump to end using Command
+                     return Command(
+                         goto=END,
+                         update={
+                             "messages": [
+                                 ToolMessage(
+                                     content=f"Tool result was blocked. {post_adjudication.reason}",
+                                     tool_call_id=tool_call_id,
+                                     name=tool_name,
+                                 )
+                             ]
+                         },
+                     )
+                 # STEER: Return modified ToolMessage with policy violation message
+                 return ToolMessage(
+                     content=f"Tool result was modified. {post_adjudication.reason}",
+                     tool_call_id=tool_call_id,
+                     name=tool_name,
+                 )
+
+         return result
+
+     async def aafter_agent(
+         self, state: AgentState, runtime: Runtime
+     ) -> dict[str, Any] | None:
+         """Execute after agent completes.
+
+         Args:
+             state: The final agent state containing messages
+             runtime: The LangGraph runtime
+
+         Returns:
+             None to continue, or a dict with state updates
+         """
+         # Finalize the trajectory
+         trajectory_id = self._harness.trajectory_id
+         await self._harness.finalize()
+         self._log.info(f"[SonderaHarness] Trajectory finalized: {trajectory_id}")
+
+         # Preserve trajectory_id in final state for next conversation
+         return {"trajectory_id": trajectory_id} if trajectory_id else None
+
+
+ def _extract_last_user_message(state: AgentState) -> BaseMessage | None:
+     """Extract the last user message from agent state."""
+     messages = state.get("messages", [])
+     if not messages:
+         return None
+
+     # Look for the last HumanMessage
+     for message in reversed(messages):
+         if isinstance(message, HumanMessage):
+             return message
+         if isinstance(message, dict) and message.get("role") == "user":
+             return HumanMessage(content=message.get("content", ""))
+
+     # Fallback to last message if it looks like user input
+     last = messages[-1]
+     if isinstance(last, dict):
+         return HumanMessage(content=last.get("content", ""))
+     return None
+
+
+ def _message_to_text(message: BaseMessage) -> str:
+     """Convert a message to text content."""
+     if isinstance(message.content, str):
+         return message.content
+     if isinstance(message.content, list):
+         return " ".join(str(chunk) for chunk in message.content)
+     return str(message.content)
+
+
+ def _tool_message_to_text(message: ToolMessage) -> str:
+     """Convert a tool message to text content."""
+     if isinstance(message.content, str):
+         return message.content
+     if isinstance(message.content, list):
+         return " ".join(str(chunk) for chunk in message.content)
+     return str(message.content)
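Because `aafter_agent` writes the finalized `trajectory_id` back into state and `abefore_agent` resumes a trajectory whenever a non-empty `trajectory_id` is present, a caller can thread the id across turns. The sketch below is a minimal, hypothetical illustration of that flow, not part of the package: the `RemoteHarness` and `create_agent` arguments mirror the class docstring above, while the model name, empty tool list, and invocation payload shape are assumptions.

```python
# Hypothetical multi-turn sketch: threading trajectory_id across invocations.
# RemoteHarness/create_agent arguments follow the class docstring above; model and tools are placeholders.
import asyncio

from langchain.agents import create_agent
from langchain_core.messages import HumanMessage

from sondera.harness import RemoteHarness
from sondera.langgraph.middleware import SonderaHarnessMiddleware, Strategy
from sondera.types import Agent

harness = RemoteHarness(
    endpoint="localhost:50051",
    organization_id="my-tenant",
    agent=Agent(
        id="my-agent",
        provider_id="langchain",
        name="My Agent",
        description="An agent with Sondera governance",
        instruction="Be helpful",
        tools=[],
    ),
)

agent = create_agent(
    model="gpt-4o",
    tools=[],
    middleware=[SonderaHarnessMiddleware(harness=harness, strategy=Strategy.STEER)],
)


async def main() -> None:
    # First turn: abefore_agent initializes a trajectory and stores its id in state.
    first = await agent.ainvoke({"messages": [HumanMessage("What can you do?")]})
    trajectory_id = first.get("trajectory_id")

    # Second turn: passing the id back makes abefore_agent resume the same trajectory.
    second = await agent.ainvoke(
        {
            "messages": [HumanMessage("Continue where we left off.")],
            "trajectory_id": trajectory_id,
        }
    )
    print(second["messages"][-1].content)


asyncio.run(main())
```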
@@ -0,0 +1,37 @@
+ # -*- coding: utf-8 -*-
+ # Generated by the protocol buffer compiler. DO NOT EDIT!
+ # NO CHECKED-IN PROTOBUF GENCODE
+ # source: google/protobuf/any.proto
+ # Protobuf Python Version: 6.31.1
+ """Generated protocol buffer code."""
+ from google.protobuf import descriptor as _descriptor
+ from google.protobuf import descriptor_pool as _descriptor_pool
+ from google.protobuf import runtime_version as _runtime_version
+ from google.protobuf import symbol_database as _symbol_database
+ from google.protobuf.internal import builder as _builder
+ _runtime_version.ValidateProtobufRuntimeVersion(
+     _runtime_version.Domain.PUBLIC,
+     6,
+     31,
+     1,
+     '',
+     'google/protobuf/any.proto'
+ )
+ # @@protoc_insertion_point(imports)
+
+ _sym_db = _symbol_database.Default()
+
+
+
+
+ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x19google/protobuf/any.proto\x12\x0fgoogle.protobuf\"&\n\x03\x41ny\x12\x10\n\x08type_url\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x0c\x42v\n\x13\x63om.google.protobufB\x08\x41nyProtoP\x01Z,google.golang.org/protobuf/types/known/anypb\xa2\x02\x03GPB\xaa\x02\x1eGoogle.Protobuf.WellKnownTypesb\x06proto3')
+
+ _globals = globals()
+ _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
+ _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'google.protobuf.any_pb2', _globals)
+ if not _descriptor._USE_C_DESCRIPTORS:
+   _globals['DESCRIPTOR']._loaded_options = None
+   _globals['DESCRIPTOR']._serialized_options = b'\n\023com.google.protobufB\010AnyProtoP\001Z,google.golang.org/protobuf/types/known/anypb\242\002\003GPB\252\002\036Google.Protobuf.WellKnownTypes'
+   _globals['_ANY']._serialized_start=46
+   _globals['_ANY']._serialized_end=84
+ # @@protoc_insertion_point(module_scope)