sondera-harness 0.6.0 (sondera_harness-0.6.0-py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sondera/__init__.py +111 -0
- sondera/__main__.py +4 -0
- sondera/adk/__init__.py +3 -0
- sondera/adk/analyze.py +222 -0
- sondera/adk/plugin.py +387 -0
- sondera/cli.py +22 -0
- sondera/exceptions.py +167 -0
- sondera/harness/__init__.py +6 -0
- sondera/harness/abc.py +102 -0
- sondera/harness/cedar/__init__.py +0 -0
- sondera/harness/cedar/harness.py +363 -0
- sondera/harness/cedar/schema.py +225 -0
- sondera/harness/sondera/__init__.py +0 -0
- sondera/harness/sondera/_grpc.py +354 -0
- sondera/harness/sondera/harness.py +890 -0
- sondera/langgraph/__init__.py +15 -0
- sondera/langgraph/analyze.py +543 -0
- sondera/langgraph/exceptions.py +19 -0
- sondera/langgraph/graph.py +210 -0
- sondera/langgraph/middleware.py +454 -0
- sondera/proto/google/protobuf/any_pb2.py +37 -0
- sondera/proto/google/protobuf/any_pb2.pyi +14 -0
- sondera/proto/google/protobuf/any_pb2_grpc.py +24 -0
- sondera/proto/google/protobuf/duration_pb2.py +37 -0
- sondera/proto/google/protobuf/duration_pb2.pyi +14 -0
- sondera/proto/google/protobuf/duration_pb2_grpc.py +24 -0
- sondera/proto/google/protobuf/empty_pb2.py +37 -0
- sondera/proto/google/protobuf/empty_pb2.pyi +9 -0
- sondera/proto/google/protobuf/empty_pb2_grpc.py +24 -0
- sondera/proto/google/protobuf/struct_pb2.py +47 -0
- sondera/proto/google/protobuf/struct_pb2.pyi +49 -0
- sondera/proto/google/protobuf/struct_pb2_grpc.py +24 -0
- sondera/proto/google/protobuf/timestamp_pb2.py +37 -0
- sondera/proto/google/protobuf/timestamp_pb2.pyi +14 -0
- sondera/proto/google/protobuf/timestamp_pb2_grpc.py +24 -0
- sondera/proto/google/protobuf/wrappers_pb2.py +53 -0
- sondera/proto/google/protobuf/wrappers_pb2.pyi +59 -0
- sondera/proto/google/protobuf/wrappers_pb2_grpc.py +24 -0
- sondera/proto/sondera/__init__.py +0 -0
- sondera/proto/sondera/core/__init__.py +0 -0
- sondera/proto/sondera/core/v1/__init__.py +0 -0
- sondera/proto/sondera/core/v1/primitives_pb2.py +88 -0
- sondera/proto/sondera/core/v1/primitives_pb2.pyi +259 -0
- sondera/proto/sondera/core/v1/primitives_pb2_grpc.py +24 -0
- sondera/proto/sondera/harness/__init__.py +0 -0
- sondera/proto/sondera/harness/v1/__init__.py +0 -0
- sondera/proto/sondera/harness/v1/harness_pb2.py +81 -0
- sondera/proto/sondera/harness/v1/harness_pb2.pyi +192 -0
- sondera/proto/sondera/harness/v1/harness_pb2_grpc.py +498 -0
- sondera/py.typed +0 -0
- sondera/settings.py +20 -0
- sondera/strands/__init__.py +5 -0
- sondera/strands/analyze.py +244 -0
- sondera/strands/harness.py +333 -0
- sondera/tui/__init__.py +0 -0
- sondera/tui/app.py +309 -0
- sondera/tui/screens/__init__.py +5 -0
- sondera/tui/screens/adjudication.py +184 -0
- sondera/tui/screens/agent.py +158 -0
- sondera/tui/screens/trajectory.py +158 -0
- sondera/tui/widgets/__init__.py +23 -0
- sondera/tui/widgets/agent_card.py +94 -0
- sondera/tui/widgets/agent_list.py +73 -0
- sondera/tui/widgets/recent_adjudications.py +52 -0
- sondera/tui/widgets/recent_trajectories.py +54 -0
- sondera/tui/widgets/summary.py +57 -0
- sondera/tui/widgets/tool_card.py +33 -0
- sondera/tui/widgets/violation_panel.py +72 -0
- sondera/tui/widgets/violations_list.py +78 -0
- sondera/tui/widgets/violations_summary.py +104 -0
- sondera/types.py +346 -0
- sondera_harness-0.6.0.dist-info/METADATA +323 -0
- sondera_harness-0.6.0.dist-info/RECORD +77 -0
- sondera_harness-0.6.0.dist-info/WHEEL +5 -0
- sondera_harness-0.6.0.dist-info/entry_points.txt +2 -0
- sondera_harness-0.6.0.dist-info/licenses/LICENSE +21 -0
- sondera_harness-0.6.0.dist-info/top_level.txt +1 -0
sondera/langgraph/graph.py
@@ -0,0 +1,210 @@
"""LangGraph state graph wrapper with Sondera trajectory tracking."""

from __future__ import annotations

import logging
from typing import Any

from langchain_core.messages import AIMessage, BaseMessage, HumanMessage

from sondera.harness import Harness
from sondera.types import (
    Adjudication,
    Content,
    Decision,
    PromptContent,
    Role,
    Stage,
    ToolResponseContent,
)

from .exceptions import GuardrailViolationError

LOGGER = logging.getLogger(__name__)


class SonderaGraph:
    """Wrapper for LangGraph compiled graphs that tracks node executions.

    Uses LangGraph's streaming API (astream) to intercept each node execution
    and record it as a trajectory step. This enables policy enforcement and
    observability for state-based workflows.

    Example:
        ```python
        from langgraph.graph import StateGraph, END
        from sondera.langgraph.graph import SonderaGraph
        from sondera.harness import Harness

        # Build your graph
        graph = StateGraph(MyState)
        graph.add_node("node1", my_function)
        graph.add_edge("node1", END)
        compiled = graph.compile()

        # Create harness
        harness = Harness(
            sondera_harness_endpoint="localhost:50051",
            agent=agent,
        )

        # Wrap with Sondera
        wrapped = SonderaGraph(compiled, harness=harness)

        # Execute - node executions will be tracked
        result = await wrapped.ainvoke(initial_state)
        ```
    """

    def __init__(
        self,
        compiled_graph: Any,
        *,
        harness: Harness,
        track_nodes: bool = True,
        enforce: bool = True,
    ) -> None:
        """Initialize the graph wrapper.

        Args:
            compiled_graph: The LangGraph compiled graph to wrap
            harness: Sondera harness for policy enforcement
            track_nodes: Whether to track node executions (default: True)
            enforce: Whether to enforce policy decisions (default: True)
        """
        self._graph = compiled_graph
        self._harness = harness
        self._track_nodes = track_nodes
        self._enforce = enforce
        self._logger = LOGGER

    async def ainvoke(
        self,
        input: dict[str, Any],
        config: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """Execute the graph with trajectory tracking via streaming.

        Args:
            input: Initial state for the graph
            config: Optional configuration dict

        Returns:
            Final state after graph execution
        """
        # Initialize trajectory
        await self._harness.initialize(agent=self._harness._agent)

        # Record initial user message if present
        if "messages" in input and input["messages"]:
            initial_msg = input["messages"][0]
            if isinstance(initial_msg, HumanMessage | BaseMessage):
                await self._record_step(
                    content=PromptContent(text=_message_to_text(initial_msg)),
                    role=Role.USER,
                    stage=Stage.PRE_MODEL,
                    node="user_input",
                )

        # Use streaming to track each node execution
        final_state = dict(input) if isinstance(input, dict) else {}
        if self._track_nodes:
            async for chunk in self._graph.astream(input, config=config):
                # chunk is {node_name: node_state_output}
                for node_name, node_state in chunk.items():
                    await self._record_node_execution(
                        node_name=node_name,
                        node_state=node_state,
                    )
                    # Merge node updates into accumulated state
                    if isinstance(node_state, dict):
                        final_state.update(node_state)
                    else:
                        final_state = node_state
        else:
            final_state = await self._graph.ainvoke(input, config=config)

        # Record final output if present
        if final_state and "messages" in final_state and final_state["messages"]:
            final_msg = final_state["messages"][-1]
            if isinstance(final_msg, AIMessage | BaseMessage):
                await self._record_step(
                    content=PromptContent(text=_message_to_text(final_msg)),
                    role=Role.MODEL,
                    stage=Stage.POST_MODEL,
                    node="final_output",
                )

        # Finalize trajectory
        await self._harness.finalize()

        return final_state

    async def _record_node_execution(
        self,
        node_name: str,
        node_state: dict[str, Any],
    ) -> None:
        """Record a node execution as a trajectory step."""
        # Extract meaningful content from the node's state update
        if "messages" in node_state and node_state["messages"]:
            last_msg = node_state["messages"][-1]
            if isinstance(last_msg, BaseMessage):
                content = _message_to_text(last_msg)
            else:
                content = str(last_msg)
        else:
            # For non-message nodes, summarize the state change
            content = f"Node '{node_name}' updated state"

        await self._record_step(
            content=ToolResponseContent(tool_id=node_name, response=content),
            role=Role.TOOL,  # Nodes are like tool executions
            stage=Stage.POST_TOOL,
            node=node_name,
        )

    async def _record_step(
        self,
        *,
        content: Content,
        role: Role,
        stage: Stage,
        node: str,
    ) -> Adjudication:
        """Record and adjudicate a trajectory step."""
        # Adjudicate with policy engine via harness
        adjudication = await self._harness.adjudicate(
            stage=stage,
            role=role,
            content=content,
        )

        # Enforce DENY decisions if enabled
        if adjudication.decision is Decision.DENY and self._enforce:
            raise GuardrailViolationError(
                stage=stage,
                node=node,
                reason=adjudication.reason,
            )

        return adjudication

    def invoke(
        self, input: dict[str, Any], config: dict[str, Any] | None = None
    ) -> dict[str, Any]:
        """Synchronous version of ainvoke (not recommended for production)."""
        import asyncio

        return asyncio.run(self.ainvoke(input, config))


def _message_to_text(message: BaseMessage | Any) -> str:
    """Extract text content from a message."""
    if isinstance(message, BaseMessage):
        if isinstance(message.content, str):
            return message.content
        return str(message.content)
    if isinstance(message, dict) and "content" in message:
        return str(message["content"])
    return str(message)
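For orientation, here is a minimal usage sketch of the wrapper above, assuming a toy one-node graph. The `DemoState` schema and `chat_node` are hypothetical, and the `Harness`/`Agent` construction simply mirrors the class docstrings in this diff; none of it is package documentation.

```python
# Illustrative sketch only. DemoState and chat_node are made up for this example;
# the Harness/Agent arguments follow the docstrings above and may differ in practice.
import asyncio
from typing import TypedDict

from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langgraph.graph import END, StateGraph

from sondera.harness import Harness
from sondera.langgraph.exceptions import GuardrailViolationError
from sondera.langgraph.graph import SonderaGraph
from sondera.types import Agent


class DemoState(TypedDict):
    messages: list[BaseMessage]


def chat_node(state: DemoState) -> DemoState:
    # Stand-in for a real model call.
    return {"messages": state["messages"] + [AIMessage(content="Hello!")]}


async def main() -> None:
    graph = StateGraph(DemoState)
    graph.add_node("chat", chat_node)
    graph.set_entry_point("chat")
    graph.add_edge("chat", END)
    compiled = graph.compile()

    harness = Harness(
        sondera_harness_endpoint="localhost:50051",
        agent=Agent(
            id="demo-agent",
            provider_id="langgraph",
            name="Demo Agent",
            description="Toy graph wrapped for trajectory tracking",
            instruction="Echo the user",
            tools=[],
        ),
    )
    wrapped = SonderaGraph(compiled, harness=harness, enforce=True)

    try:
        result = await wrapped.ainvoke({"messages": [HumanMessage(content="hi")]})
    except GuardrailViolationError as exc:
        # Raised when an adjudication returns Decision.DENY and enforce=True.
        print(f"Blocked by policy: {exc}")
    else:
        print(result["messages"][-1].content)


if __name__ == "__main__":
    asyncio.run(main())
```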
sondera/langgraph/middleware.py
@@ -0,0 +1,454 @@
"""Sondera Harness Middleware for LangGraph."""

from __future__ import annotations

import logging
from collections.abc import Awaitable, Callable
from enum import Enum
from typing import Any

from langchain.agents import AgentState
from langchain.agents.middleware import (
    AgentMiddleware,
    ModelRequest,
    ModelResponse,
    hook_config,
)
from langchain.messages import ToolMessage
from langchain.tools.tool_node import ToolCallRequest
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langgraph.runtime import Runtime
from langgraph.types import Command

try:
    from langgraph.graph import END
except ImportError:
    # Fallback for older versions
    END = "__end__"

from sondera.harness import Harness
from sondera.types import (
    PromptContent,
    Role,
    Stage,
    ToolRequestContent,
    ToolResponseContent,
)

_LOGGER = logging.getLogger(__name__)


class Strategy(str, Enum):
    """Strategy for handling policy violations."""

    BLOCK = "block"
    """Jump to end immediately when a policy violation is detected."""
    STEER = "steer"
    """Allow continuation with modified content when a policy violation is detected."""


class State(AgentState):
    """Agent state with additional Sondera Harness-related fields."""

    trajectory_id: str | None


class SonderaHarnessMiddleware(AgentMiddleware[State]):
    """LangGraph middleware that integrates with Sondera Harness for policy enforcement.

    This middleware intercepts agent execution at key points (before/after agent,
    model calls, tool calls) and delegates policy evaluation to the Sondera Harness
    Service. Based on the adjudication result, it can either allow execution to
    proceed, block and jump to end, or steer the response with modified content.

    Example:
        ```python
        from sondera.langgraph.middleware import SonderaHarnessMiddleware, Strategy
        from sondera.harness import RemoteHarness
        from sondera.types import Agent
        from langchain.agents import create_agent

        # Create a harness instance
        harness = RemoteHarness(
            endpoint="localhost:50051",
            organization_id="my-tenant",
            agent=Agent(
                id="my-agent",
                provider_id="langchain",
                name="My Agent",
                description="An agent with Sondera governance",
                instruction="Be helpful",
                tools=[],
            ),
        )

        # Create middleware with the harness
        middleware = SonderaHarnessMiddleware(
            harness=harness,
            strategy=Strategy.BLOCK,
        )

        agent = create_agent(
            model="gpt-4o",
            tools=[...],
            middleware=[middleware],
        )
        ```
    """

    state_schema = State

    def __init__(
        self,
        harness: Harness,
        *,
        strategy: Strategy = Strategy.BLOCK,
        logger: logging.Logger | None = None,
    ) -> None:
        """Initialize the Sondera Harness Middleware.

        Args:
            harness: The Sondera Harness instance to use
            strategy: How to handle policy violations (BLOCK or STEER)
        """
        self._harness = harness
        self._strategy = strategy
        self._log = logger or _LOGGER
        super().__init__()

    @hook_config(can_jump_to=["end"])
    async def abefore_agent(
        self, state: State, runtime: Runtime
    ) -> dict[str, Any] | None:
        """Execute before agent starts.

        Initializes the trajectory and evaluates the user's input message
        against policies before the agent begins processing.

        Args:
            state: The current agent state containing messages
            runtime: The LangGraph runtime

        Returns:
            None to continue, or a dict with state updates (including optional jump_to)
        """
        trajectory_id = state.get("trajectory_id")
        updates = {}

        if trajectory_id and trajectory_id.strip():  # Check for non-empty string
            # Resume an existing trajectory.
            await self._harness.resume(trajectory_id)
            self._log.debug(
                f"[SonderaHarness] Resumed trajectory: {self._harness.trajectory_id}"
            )
        else:
            # Initialize a new trajectory if needed.
            if self._harness.trajectory_id is None:
                await self._harness.initialize()
            updates["trajectory_id"] = self._harness.trajectory_id
            self._log.debug(
                f"[SonderaHarness] Initialized trajectory: {self._harness.trajectory_id}"
            )

        # Extract user message from state
        user_message = _extract_last_user_message(state)
        if user_message is None:
            self._log.debug(
                "[SonderaHarness] No user message found in state, skipping pre-agent check"
            )
            # Still return trajectory_id if we just created one
            return updates if updates else None

        content = _message_to_text(user_message)
        self._log.debug(
            f"[SonderaHarness] Evaluating user input for trajectory {self._harness.trajectory_id}"
        )

        adjudication = await self._harness.adjudicate(
            Stage.PRE_MODEL,
            Role.USER,
            PromptContent(text=content),
        )
        self._log.info(
            f"[SonderaHarness] Before Agent Adjudication for trajectory {self._harness.trajectory_id}"
        )

        if adjudication.is_denied:
            self._log.warning(
                f"[SonderaHarness] Policy violation detected (strategy={self._strategy.value}): "
                f"{adjudication.reason}"
            )
            if self._strategy == Strategy.BLOCK:
                # BLOCK: Jump to end immediately with policy message
                return {
                    "messages": [AIMessage(content=adjudication.reason)],
                    "jump_to": "end",
                    **updates,  # Include trajectory_id in the response
                }
            # STEER: Replace user message with policy guidance and continue
            return {
                "messages": [
                    AIMessage(
                        content=f"Policy violation in user message: {adjudication.reason}"
                    )
                ],
                **updates,  # Include trajectory_id in the response
            }

        # Return trajectory_id if we just created one
        return updates if updates else None

    async def awrap_model_call(
        self,
        request: ModelRequest,
        handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
    ) -> ModelResponse:
        """Wrap model calls with policy evaluation.

        Evaluates the model request before calling the model, then evaluates
        the model's response after it returns.

        Args:
            request: The model request containing messages and configuration
            handler: The handler function to call the actual model

        Returns:
            The model response, potentially modified based on policy
        """
        if isinstance(request.messages[-1], AIMessage):
            # Last message is an AIMessage, so we need to adjudicate it. HumanMessage was checked in abefore_agent.
            _LOGGER.debug(
                f"[SonderaHarness] Pre-model check for trajectory {self._harness.trajectory_id} {request.messages}"
            )
            pre_adjudication = await self._harness.adjudicate(
                Stage.PRE_MODEL,
                Role.MODEL,
                PromptContent(text=_message_to_text(request.messages[-1])),
            )

            if pre_adjudication.is_denied:
                _LOGGER.warning(
                    f"[SonderaHarness] Pre-model policy violation (strategy={self._strategy.value}): "
                    f"{pre_adjudication.reason}"
                )
                message = AIMessage(
                    content=f"Replaced message due to policy violation: {pre_adjudication.reason}"
                )
                if self._strategy == Strategy.STEER:
                    # STEER: Replace the last message with the policy message
                    request.messages[-1] = message
                else:
                    # BLOCK: Return early with the policy message
                    return ModelResponse(
                        result=[message],
                        structured_response=None,
                    )

        # Call the actual model
        response: ModelResponse = await handler(request)

        # Post-model check on each AI message in the response
        sanitized_messages: list[BaseMessage] = []
        for message in response.result:
            if isinstance(message, AIMessage):
                post_adjudication = await self._harness.adjudicate(
                    Stage.POST_MODEL,
                    Role.MODEL,
                    PromptContent(text=message.text),
                )
                self._log.info(
                    f"[SonderaHarness] Post-model Adjudication for trajectory {self._harness.trajectory_id}"
                )
                if post_adjudication.is_denied:
                    self._log.warning(
                        f"[SonderaHarness] Post-model policy violation (strategy={self._strategy.value}): "
                        f"{post_adjudication.reason}"
                    )
                    message = AIMessage(
                        content=f"Replaced message due to policy violation: {post_adjudication.reason}"
                    )
                    if self._strategy == Strategy.STEER:
                        # STEER: Replace the message with the policy message
                        sanitized_messages.append(message)
                    else:
                        # BLOCK: Return early with the policy message
                        return ModelResponse(
                            result=[message],
                            structured_response=response.structured_response,
                        )
                else:
                    sanitized_messages.append(message)
            else:
                self._log.debug(
                    f"[SonderaHarness] Non-AIMessage in response: {message} in trajectory {self._harness.trajectory_id}"
                )
                sanitized_messages.append(message)

        return ModelResponse(
            result=sanitized_messages,
            structured_response=response.structured_response,
        )

    async def awrap_tool_call(
        self,
        request: ToolCallRequest,
        handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
    ) -> ToolMessage | Command:
        """Wrap tool calls with policy evaluation.

        Evaluates the tool request before execution, then evaluates
        the tool's response after it returns.

        Args:
            request: The tool call request containing tool name and arguments
            handler: The handler function to execute the actual tool

        Returns:
            The tool response, potentially modified based on policy
        """
        tool_name = request.tool_call.get("name", "unknown_tool")
        tool_args = request.tool_call.get("args", {})
        tool_call_id = request.tool_call.get("id", "")

        # Pre-tool check
        self._log.debug(
            f"[SonderaHarness] Pre-tool check for {tool_name} in trajectory {self._harness.trajectory_id}"
        )
        pre_adjudication = await self._harness.adjudicate(
            Stage.PRE_TOOL,
            Role.TOOL,
            ToolRequestContent(tool_id=tool_name, args=tool_args),
        )

        self._log.info(
            f"[SonderaHarness] Before Tool Adjudication for trajectory {self._harness.trajectory_id}"
        )

        if pre_adjudication.is_denied:
            self._log.warning(
                f"[SonderaHarness] Pre-tool policy violation for {tool_name} "
                f"(strategy={self._strategy.value}): {pre_adjudication.reason}"
            )
            if self._strategy == Strategy.BLOCK:
                # BLOCK: Jump to end using Command
                return Command(
                    goto=END,
                    update={
                        "messages": [
                            ToolMessage(
                                content=f"Tool execution was blocked. {pre_adjudication.reason}",
                                tool_call_id=tool_call_id,
                                name=tool_name,
                            )
                        ]
                    },
                )
            # STEER: Return tool message with policy violation instead of allowing execution
            return ToolMessage(
                content=f"Tool execution modified due to policy concern: {pre_adjudication.reason}",
                tool_call_id=tool_call_id,
                name=tool_name,
            )

        # Execute the actual tool
        result = await handler(request)

        # Post-tool check
        if isinstance(result, ToolMessage):
            output_text = _tool_message_to_text(result)

            post_adjudication = await self._harness.adjudicate(
                Stage.POST_TOOL,
                Role.TOOL,
                ToolResponseContent(tool_id=tool_name, response=output_text),
            )

            self._log.info(
                f"[SonderaHarness] After Tool Adjudication for trajectory {self._harness.trajectory_id}"
            )

            if post_adjudication.is_denied:
                self._log.warning(
                    f"[SonderaHarness] Post-tool policy violation for {tool_name} "
                    f"(strategy={self._strategy.value}): {post_adjudication.reason}"
                )
                if self._strategy == Strategy.BLOCK:
                    # BLOCK: Jump to end using Command
                    return Command(
                        goto=END,
                        update={
                            "messages": [
                                ToolMessage(
                                    content=f"Tool result was blocked. {post_adjudication.reason}",
                                    tool_call_id=tool_call_id,
                                    name=tool_name,
                                )
                            ]
                        },
                    )
                # STEER: Return modified ToolMessage with policy violation message
                return ToolMessage(
                    content=f"Tool result was modified. {post_adjudication.reason}",
                    tool_call_id=tool_call_id,
                    name=tool_name,
                )

        return result

    async def aafter_agent(
        self, state: AgentState, runtime: Runtime
    ) -> dict[str, Any] | None:
        """Execute after agent completes.

        Args:
            state: The final agent state containing messages
            runtime: The LangGraph runtime

        Returns:
            None to continue, or a dict with state updates
        """
        # Finalize the trajectory
        trajectory_id = self._harness.trajectory_id
        await self._harness.finalize()
        self._log.info(f"[SonderaHarness] Trajectory finalized: {trajectory_id}")

        # Preserve trajectory_id in final state for next conversation
        return {"trajectory_id": trajectory_id} if trajectory_id else None


def _extract_last_user_message(state: AgentState) -> BaseMessage | None:
    """Extract the last user message from agent state."""
    messages = state.get("messages", [])
    if not messages:
        return None

    # Look for the last HumanMessage
    for message in reversed(messages):
        if isinstance(message, HumanMessage):
            return message
        if isinstance(message, dict) and message.get("role") == "user":
            return HumanMessage(content=message.get("content", ""))

    # Fallback to last message if it looks like user input
    last = messages[-1]
    if isinstance(last, dict):
        return HumanMessage(content=last.get("content", ""))
    return None


def _message_to_text(message: BaseMessage) -> str:
    """Convert a message to text content."""
    if isinstance(message.content, str):
        return message.content
    if isinstance(message.content, list):
        return " ".join(str(chunk) for chunk in message.content)
    return str(message.content)


def _tool_message_to_text(message: ToolMessage) -> str:
    """Convert a tool message to text content."""
    if isinstance(message.content, str):
        return message.content
    if isinstance(message.content, list):
        return " ".join(str(chunk) for chunk in message.content)
    return str(message.content)
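One detail worth noting in the middleware above: `abefore_agent` resumes an existing trajectory when the incoming state carries a non-empty `trajectory_id`, and `aafter_agent` returns that id so the caller can thread it into the next turn. A hedged sketch of that round trip follows; the `RemoteHarness`/`create_agent` wiring is copied from the class docstring, and the multi-turn pattern itself is an assumption rather than package documentation.

```python
# Illustrative sketch; the harness and agent setup mirror the docstring above,
# and the two-turn trajectory_id handoff is an assumed usage pattern.
import asyncio

from langchain.agents import create_agent

from sondera.harness import RemoteHarness
from sondera.langgraph.middleware import SonderaHarnessMiddleware, Strategy
from sondera.types import Agent


async def main() -> None:
    harness = RemoteHarness(
        endpoint="localhost:50051",
        organization_id="my-tenant",
        agent=Agent(
            id="my-agent",
            provider_id="langchain",
            name="My Agent",
            description="An agent with Sondera governance",
            instruction="Be helpful",
            tools=[],
        ),
    )
    middleware = SonderaHarnessMiddleware(harness, strategy=Strategy.STEER)
    agent = create_agent(model="gpt-4o", tools=[], middleware=[middleware])

    # Turn 1: no trajectory_id in the input state, so a new trajectory is initialized.
    first = await agent.ainvoke({"messages": [{"role": "user", "content": "hello"}]})
    trajectory_id = first.get("trajectory_id")

    # Turn 2: pass the id back so abefore_agent resumes the same trajectory.
    await agent.ainvoke(
        {
            "messages": [{"role": "user", "content": "one more question"}],
            "trajectory_id": trajectory_id,
        }
    )


if __name__ == "__main__":
    asyncio.run(main())
```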
sondera/proto/google/protobuf/any_pb2.py
@@ -0,0 +1,37 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# NO CHECKED-IN PROTOBUF GENCODE
# source: google/protobuf/any.proto
# Protobuf Python Version: 6.31.1
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import runtime_version as _runtime_version
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
_runtime_version.ValidateProtobufRuntimeVersion(
    _runtime_version.Domain.PUBLIC,
    6,
    31,
    1,
    '',
    'google/protobuf/any.proto'
)
# @@protoc_insertion_point(imports)

_sym_db = _symbol_database.Default()




DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x19google/protobuf/any.proto\x12\x0fgoogle.protobuf\"&\n\x03\x41ny\x12\x10\n\x08type_url\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x0c\x42v\n\x13\x63om.google.protobufB\x08\x41nyProtoP\x01Z,google.golang.org/protobuf/types/known/anypb\xa2\x02\x03GPB\xaa\x02\x1eGoogle.Protobuf.WellKnownTypesb\x06proto3')

_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'google.protobuf.any_pb2', _globals)
if not _descriptor._USE_C_DESCRIPTORS:
  _globals['DESCRIPTOR']._loaded_options = None
  _globals['DESCRIPTOR']._serialized_options = b'\n\023com.google.protobufB\010AnyProtoP\001Z,google.golang.org/protobuf/types/known/anypb\242\002\003GPB\252\002\036Google.Protobuf.WellKnownTypes'
  _globals['_ANY']._serialized_start=46
  _globals['_ANY']._serialized_end=84
# @@protoc_insertion_point(module_scope)