agentrust-py 0.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agentrust/__init__.py +72 -0
- agentrust_py-0.0.3.dist-info/METADATA +193 -0
- agentrust_py-0.0.3.dist-info/RECORD +29 -0
- agentrust_py-0.0.3.dist-info/WHEEL +4 -0
- agentrust_py-0.0.3.dist-info/entry_points.txt +2 -0
- agentrust_py-0.0.3.dist-info/licenses/LICENSE +177 -0
- agentrust_sdk/__init__.py +124 -0
- agentrust_sdk/adapters/__init__.py +1 -0
- agentrust_sdk/adapters/autogen.py +235 -0
- agentrust_sdk/adapters/claude_agents.py +225 -0
- agentrust_sdk/adapters/crewai.py +98 -0
- agentrust_sdk/adapters/langgraph.py +109 -0
- agentrust_sdk/adapters/mcp.py +193 -0
- agentrust_sdk/adapters/openai_agents.py +263 -0
- agentrust_sdk/auth.py +192 -0
- agentrust_sdk/auto.py +397 -0
- agentrust_sdk/autoload.py +95 -0
- agentrust_sdk/cli.py +736 -0
- agentrust_sdk/client.py +790 -0
- agentrust_sdk/config.py +192 -0
- agentrust_sdk/decorator.py +276 -0
- agentrust_sdk/embedded.py +428 -0
- agentrust_sdk/hooks.py +461 -0
- agentrust_sdk/models.py +81 -0
- agentrust_sdk/py.typed +0 -0
- agentrust_sdk/queue_replay.py +204 -0
- agentrust_sdk/tiers.py +180 -0
- agentrust_sdk/version_negotiation.py +290 -0
- agentrust_sdk/webhooks.py +782 -0
|
@@ -0,0 +1,193 @@
|
|
|
1
|
+
"""
|
|
2
|
+
MCP (Model Context Protocol) adapter — AgentTrust tool-call governance.
|
|
3
|
+
|
|
4
|
+
Architecture doc §3: "MCP wrapper" as a key Adapter Layer capability.
|
|
5
|
+
|
|
6
|
+
Wraps an MCP server so every tool execution is validated by AgentTrust
|
|
7
|
+
before the result is returned to the calling model.
|
|
8
|
+
|
|
9
|
+
Two usage patterns:
|
|
10
|
+
|
|
11
|
+
1. Tool guard decorator (wrap individual tools):
|
|
12
|
+
|
|
13
|
+
from agentrust_sdk.adapters.mcp import mcp_tool_guard
|
|
14
|
+
|
|
15
|
+
@mcp_tool_guard(agent_id="my-mcp-server")
|
|
16
|
+
async def read_file(path: str) -> str:
|
|
17
|
+
return open(path).read()
|
|
18
|
+
|
|
19
|
+
2. MCPServerGuard (wrap a FastMCP / mcp.Server instance):
|
|
20
|
+
|
|
21
|
+
from mcp.server.fastmcp import FastMCP
|
|
22
|
+
from agentrust_sdk.adapters.mcp import MCPServerGuard
|
|
23
|
+
|
|
24
|
+
mcp = FastMCP("my-server")
|
|
25
|
+
guard = MCPServerGuard(mcp, agent_id="my-mcp-server")
|
|
26
|
+
# All @mcp.tool() handlers are now governed by AgentTrust.
|
|
27
|
+
|
|
28
|
+
Requires Developer tier (tier check at construction time).
|
|
29
|
+
"""
|
|
30
|
+
from __future__ import annotations
|
|
31
|
+
|
|
32
|
+
import functools
|
|
33
|
+
import inspect
|
|
34
|
+
import logging
|
|
35
|
+
import time
|
|
36
|
+
from typing import Any, Callable
|
|
37
|
+
|
|
38
|
+
logger = logging.getLogger(__name__)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def _check_tier(api_key: str | None = None) -> None:
    """Raise if the resolved API key's tier may not use the MCP adapter.

    The gate only applies when ``agentrust_sdk.tiers.Capability`` defines an
    ``MCP_ADAPTER`` member (or that member is truthy); otherwise this is a
    no-op.

    Raises:
        RuntimeError: the key's tier is not allowed ``MCP_ADAPTER``. The text
            comes from ``UPGRADE_MESSAGES`` when it has an entry for the
            capability, else a generic Developer-tier message.
    """
    from agentrust_sdk.auth import resolve_key
    from agentrust_sdk.tiers import UPGRADE_MESSAGES, Capability, is_allowed

    key_info = resolve_key(api_key)
    capability = getattr(Capability, "MCP_ADAPTER", None)
    if not capability:
        return
    if is_allowed(capability, key_info.tier):
        return
    fallback = "MCP adapter requires Developer tier or higher."
    reason = UPGRADE_MESSAGES.get(capability, fallback)
    raise RuntimeError(f"[AgentTrust] {reason}")
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def mcp_tool_guard(
|
|
52
|
+
agent_id: str,
|
|
53
|
+
*,
|
|
54
|
+
base_url: str = "http://localhost:8000",
|
|
55
|
+
api_key: str | None = None,
|
|
56
|
+
user: str = "mcp",
|
|
57
|
+
block_on_block: bool = True,
|
|
58
|
+
framework: str = "MCP",
|
|
59
|
+
) -> Callable:
|
|
60
|
+
"""Decorator that validates every MCP tool call via AgentTrust.
|
|
61
|
+
|
|
62
|
+
Works with both sync and async tool functions.
|
|
63
|
+
"""
|
|
64
|
+
def decorator(fn: Callable) -> Callable:
|
|
65
|
+
is_async = inspect.iscoroutinefunction(fn)
|
|
66
|
+
|
|
67
|
+
async def _validate(tool_name: str, args: dict, result: Any, latency_ms: float) -> None:
|
|
68
|
+
from agentrust_sdk.client import AgentTrustClient
|
|
69
|
+
from agentrust_sdk.decorator import _check_decision, BlockedError
|
|
70
|
+
|
|
71
|
+
output = result if isinstance(result, dict) else {"result": str(result)[:2000]}
|
|
72
|
+
try:
|
|
73
|
+
with AgentTrustClient(base_url=base_url, api_key=api_key) as client:
|
|
74
|
+
resp = client.validate(
|
|
75
|
+
agent_id=agent_id,
|
|
76
|
+
user=user,
|
|
77
|
+
input=tool_name,
|
|
78
|
+
output={**output, "_tool_args": args},
|
|
79
|
+
framework=framework,
|
|
80
|
+
tools_called=[{"name": tool_name, "args": args, "result": result}],
|
|
81
|
+
latency_ms=latency_ms,
|
|
82
|
+
)
|
|
83
|
+
_check_decision(resp, block_on_block, False)
|
|
84
|
+
except BlockedError:
|
|
85
|
+
raise
|
|
86
|
+
except Exception as exc:
|
|
87
|
+
logger.warning("[AgentTrust] MCP tool guard failed (non-fatal): %s", exc)
|
|
88
|
+
|
|
89
|
+
if is_async:
|
|
90
|
+
@functools.wraps(fn)
|
|
91
|
+
async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
|
|
92
|
+
t0 = time.perf_counter()
|
|
93
|
+
result = await fn(*args, **kwargs)
|
|
94
|
+
latency_ms = (time.perf_counter() - t0) * 1000
|
|
95
|
+
call_args = inspect.signature(fn).bind(*args, **kwargs).arguments
|
|
96
|
+
await _validate(fn.__name__, dict(call_args), result, latency_ms)
|
|
97
|
+
return result
|
|
98
|
+
return async_wrapper
|
|
99
|
+
else:
|
|
100
|
+
@functools.wraps(fn)
|
|
101
|
+
def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
|
|
102
|
+
import asyncio
|
|
103
|
+
t0 = time.perf_counter()
|
|
104
|
+
result = fn(*args, **kwargs)
|
|
105
|
+
latency_ms = (time.perf_counter() - t0) * 1000
|
|
106
|
+
call_args = inspect.signature(fn).bind(*args, **kwargs).arguments
|
|
107
|
+
try:
|
|
108
|
+
loop = asyncio.get_event_loop()
|
|
109
|
+
if loop.is_running():
|
|
110
|
+
loop.create_task(_validate(fn.__name__, dict(call_args), result, latency_ms))
|
|
111
|
+
else:
|
|
112
|
+
loop.run_until_complete(_validate(fn.__name__, dict(call_args), result, latency_ms))
|
|
113
|
+
except Exception as exc:
|
|
114
|
+
logger.warning("[AgentTrust] MCP sync validate error: %s", exc)
|
|
115
|
+
return result
|
|
116
|
+
return sync_wrapper
|
|
117
|
+
|
|
118
|
+
return decorator
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
class MCPServerGuard:
|
|
122
|
+
"""Wraps a FastMCP / mcp.Server instance and governs all registered tools.
|
|
123
|
+
|
|
124
|
+
Usage::
|
|
125
|
+
|
|
126
|
+
from mcp.server.fastmcp import FastMCP
|
|
127
|
+
from agentrust_sdk.adapters.mcp import MCPServerGuard
|
|
128
|
+
|
|
129
|
+
mcp = FastMCP("payments-server")
|
|
130
|
+
|
|
131
|
+
@mcp.tool()
|
|
132
|
+
async def process_payment(amount: float, account: str) -> dict:
|
|
133
|
+
...
|
|
134
|
+
|
|
135
|
+
guard = MCPServerGuard(mcp, agent_id="payments-mcp")
|
|
136
|
+
# process_payment is now governed — AgentTrust validates every call.
|
|
137
|
+
"""
|
|
138
|
+
|
|
139
|
+
def __init__(
|
|
140
|
+
self,
|
|
141
|
+
server: Any,
|
|
142
|
+
agent_id: str,
|
|
143
|
+
*,
|
|
144
|
+
base_url: str = "http://localhost:8000",
|
|
145
|
+
api_key: str | None = None,
|
|
146
|
+
user: str = "mcp",
|
|
147
|
+
block_on_block: bool = True,
|
|
148
|
+
) -> None:
|
|
149
|
+
_check_tier(api_key)
|
|
150
|
+
self._server = server
|
|
151
|
+
self._agent_id = agent_id
|
|
152
|
+
self._base_url = base_url
|
|
153
|
+
self._api_key = api_key
|
|
154
|
+
self._user = user
|
|
155
|
+
self._block_on_block = block_on_block
|
|
156
|
+
self._wrap_tools()
|
|
157
|
+
|
|
158
|
+
def _wrap_tools(self) -> None:
|
|
159
|
+
"""Monkey-patch all registered tool handlers on the server."""
|
|
160
|
+
tools_attr = None
|
|
161
|
+
for attr in ("_tool_handlers", "tools", "_tools", "_handlers"):
|
|
162
|
+
if hasattr(self._server, attr):
|
|
163
|
+
tools_attr = getattr(self._server, attr)
|
|
164
|
+
break
|
|
165
|
+
|
|
166
|
+
if tools_attr is None:
|
|
167
|
+
logger.warning(
|
|
168
|
+
"[AgentTrust] MCPServerGuard: could not locate tool registry on %s. "
|
|
169
|
+
"Use @mcp_tool_guard decorator instead.",
|
|
170
|
+
type(self._server).__name__,
|
|
171
|
+
)
|
|
172
|
+
return
|
|
173
|
+
|
|
174
|
+
guard = mcp_tool_guard(
|
|
175
|
+
self._agent_id,
|
|
176
|
+
base_url=self._base_url,
|
|
177
|
+
api_key=self._api_key,
|
|
178
|
+
user=self._user,
|
|
179
|
+
block_on_block=self._block_on_block,
|
|
180
|
+
)
|
|
181
|
+
|
|
182
|
+
if isinstance(tools_attr, dict):
|
|
183
|
+
for name, fn in list(tools_attr.items()):
|
|
184
|
+
tools_attr[name] = guard(fn)
|
|
185
|
+
elif isinstance(tools_attr, list):
|
|
186
|
+
for i, fn in enumerate(tools_attr):
|
|
187
|
+
tools_attr[i] = guard(fn)
|
|
188
|
+
|
|
189
|
+
logger.info(
|
|
190
|
+
"[AgentTrust] MCPServerGuard: wrapped %d tool(s) on agent '%s'",
|
|
191
|
+
len(tools_attr) if hasattr(tools_attr, "__len__") else -1,
|
|
192
|
+
self._agent_id,
|
|
193
|
+
)
|
|
@@ -0,0 +1,263 @@
|
|
|
1
|
+
"""
|
|
2
|
+
OpenAI Agents SDK adapter — AgentTrust governance for OpenAI Agents (Swarm / Agents SDK).
|
|
3
|
+
|
|
4
|
+
Architecture doc (two-pager): "Connects to … OpenAI Agents."
|
|
5
|
+
|
|
6
|
+
Supports three patterns:
|
|
7
|
+
|
|
8
|
+
1. @agent_guard decorator — wrap a single Agent run() call:
|
|
9
|
+
|
|
10
|
+
from agents import Agent, Runner
|
|
11
|
+
from agentrust_sdk.adapters.openai_agents import agent_guard
|
|
12
|
+
|
|
13
|
+
@agent_guard(agent_id="support-bot")
|
|
14
|
+
async def run_support(query: str) -> str:
|
|
15
|
+
result = await Runner.run(my_agent, query)
|
|
16
|
+
return result.final_output
|
|
17
|
+
|
|
18
|
+
2. OpenAIAgentsGuard — wrap a Runner:
|
|
19
|
+
|
|
20
|
+
from agentrust_sdk.adapters.openai_agents import OpenAIAgentsGuard
|
|
21
|
+
|
|
22
|
+
guard = OpenAIAgentsGuard(agent_id="support-bot")
|
|
23
|
+
result = await guard.run(agent, "process this query")
|
|
24
|
+
|
|
25
|
+
3. Swarm (legacy openai-swarm) wrapper:
|
|
26
|
+
|
|
27
|
+
from agentrust_sdk.adapters.openai_agents import SwarmGuard
|
|
28
|
+
|
|
29
|
+
guard = SwarmGuard(agent_id="swarm-agent")
|
|
30
|
+
response = guard.run(swarm_client, agent, messages)
|
|
31
|
+
|
|
32
|
+
Requires Team tier.
|
|
33
|
+
"""
|
|
34
|
+
from __future__ import annotations
|
|
35
|
+
|
|
36
|
+
import functools
|
|
37
|
+
import inspect
|
|
38
|
+
import logging
|
|
39
|
+
import time
|
|
40
|
+
from typing import Any
|
|
41
|
+
|
|
42
|
+
logger = logging.getLogger(__name__)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def _check_tier(api_key: str | None = None) -> None:
    """Raise unless the resolved key's tier is allowed ``Capability.CREWAI_ADAPTER``.

    NOTE(review): this OpenAI Agents gate reuses the CrewAI capability rather
    than an OpenAI-Agents-specific one. Confirm against ``agentrust_sdk.tiers``
    that this is the intended Team-tier check.

    Raises:
        RuntimeError: the resolved tier lacks the capability.
    """
    from agentrust_sdk.auth import resolve_key
    from agentrust_sdk.tiers import Capability, is_allowed

    tier = resolve_key(api_key).tier
    if is_allowed(Capability.CREWAI_ADAPTER, tier):
        return
    raise RuntimeError("[AgentTrust] OpenAI Agents adapter requires Team tier or higher.")
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def _validate(
    agent_id: str,
    base_url: str,
    api_key: str | None,
    user: str,
    input_text: str,
    output: dict,
    framework: str,
    block_on_block: bool,
    latency_ms: float = 0.0,
    tool_calls: list | None = None,
) -> None:
    """Send one agent run to AgentTrust and enforce the returned decision.

    ``input_text`` is truncated to its first 1000 characters, and a missing
    ``tool_calls`` is sent as an empty list. The client response is passed to
    ``_check_decision(..., block_on_block, False)``.

    A ``BlockedError`` from that check propagates. Any other exception is
    logged as a warning and swallowed (fail-open).
    """
    from agentrust_sdk.client import AgentTrustClient
    from agentrust_sdk.decorator import _check_decision, BlockedError

    try:
        with AgentTrustClient(base_url=base_url, api_key=api_key) as client:
            request = {
                "agent_id": agent_id,
                "user": user,
                "input": input_text[:1000],
                "output": output,
                "framework": framework,
                "tools_called": tool_calls or [],
                "latency_ms": latency_ms,
            }
            _check_decision(client.validate(**request), block_on_block, False)
    except BlockedError:
        raise
    except Exception as exc:
        logger.warning("[AgentTrust] OpenAI Agents validate failed (non-fatal): %s", exc)
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def _extract_run_output(result: Any) -> tuple[str, list]:
    """Extract final_output and tool calls from an openai-agents RunResult.

    Args:
        result: A run result exposing ``final_output`` and optionally
            ``new_items``, or a plain string.

    Returns:
        ``(final_output, tool_calls)``:

        * ``final_output`` is ``str(result.final_output)``, or ``""`` when it is
          ``None`` or absent. A plain string result is returned as-is.
        * ``tool_calls`` holds one ``{"name": ..., "args": ...}`` dict per
          item in ``new_items`` whose ``type`` is ``"tool_call_item"`` or that
          has a ``tool_call`` attribute.
    """
    def _field(obj: Any, name: str, default: Any) -> Any:
        # Tool-call payloads may be SDK objects or plain dicts.
        if isinstance(obj, dict):
            return obj.get(name, default)
        return getattr(obj, name, default)

    final_output = ""
    tool_calls: list[dict] = []

    if hasattr(result, "final_output"):
        # Only None means "no output"; falsy values such as 0 or False are
        # genuine outputs and must not collapse to "".
        if result.final_output is not None:
            final_output = str(result.final_output)
    elif isinstance(result, str):
        final_output = result

    for item in getattr(result, "new_items", None) or []:
        if getattr(item, "type", None) != "tool_call_item" and not hasattr(item, "tool_call"):
            continue
        # openai-agents' ToolCallItem keeps the actual call (``name``,
        # ``arguments``) on ``raw_item``; older shapes expose ``tool_call``.
        call = getattr(item, "tool_call", None)
        if call is None:
            call = getattr(item, "raw_item", None)
        if call is None:
            call = item
        tool_calls.append({
            "name": _field(call, "name", ""),
            "args": _field(call, "arguments", {}),
        })

    return final_output, tool_calls
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def agent_guard(
    agent_id: str,
    *,
    base_url: str = "http://localhost:8000",
    api_key: str | None = None,
    user: str = "openai-agent",
    block_on_block: bool = True,
    framework: str = "OpenAI-Agents",
) -> Any:
    """Decorator that validates the result of an agent runner function.

    Works for both ``async def`` and plain functions. After the wrapped
    function returns, the following are sent to AgentTrust via ``_validate``:

    * the result, summarised with ``_extract_run_output`` (final output
      truncated to 2000 characters, plus any tool calls);
    * the call's latency;
    * the run's input.

    The original result is then returned unchanged.

    The reported input is the first positional argument. If there is none,
    it is the value of the function's first declared parameter when that
    parameter was passed by keyword; otherwise ``""``.

    The tier check runs once, when ``agent_guard(...)`` itself is called, and
    raises ``RuntimeError`` if the key's tier is not allowed.
    """
    _check_tier(api_key)

    def decorator(fn: Any) -> Any:
        try:
            params = list(inspect.signature(fn).parameters.values())
        except (TypeError, ValueError):
            # Some callables (e.g. certain builtins) expose no signature.
            params = []
        first = params[0] if params else None
        keyword_kinds = (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
        first_name = first.name if first is not None and first.kind in keyword_kinds else None

        def _input_text(args: tuple, kwargs: dict) -> str:
            """Return the run's input: first positional arg, or the first parameter passed by keyword."""
            if args:
                return str(args[0])
            if first_name is not None and first_name in kwargs:
                return str(kwargs[first_name])
            return ""

        def _report(args: tuple, kwargs: dict, result: Any, latency_ms: float) -> None:
            """Summarise *result* and validate it (shared by both wrappers)."""
            final_output, tool_calls = _extract_run_output(result)
            _validate(
                agent_id, base_url, api_key, user, _input_text(args, kwargs),
                {"final_output": final_output[:2000]},
                framework, block_on_block, latency_ms, tool_calls,
            )

        if inspect.iscoroutinefunction(fn):
            @functools.wraps(fn)
            async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
                t0 = time.perf_counter()
                result = await fn(*args, **kwargs)
                _report(args, kwargs, result, (time.perf_counter() - t0) * 1000)
                return result

            return async_wrapper

        @functools.wraps(fn)
        def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
            t0 = time.perf_counter()
            result = fn(*args, **kwargs)
            _report(args, kwargs, result, (time.perf_counter() - t0) * 1000)
            return result

        return sync_wrapper

    return decorator
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
class OpenAIAgentsGuard:
|
|
159
|
+
"""
|
|
160
|
+
Wraps openai-agents Runner.run() with AgentTrust governance.
|
|
161
|
+
|
|
162
|
+
Usage::
|
|
163
|
+
|
|
164
|
+
from agents import Agent
|
|
165
|
+
from agentrust_sdk.adapters.openai_agents import OpenAIAgentsGuard
|
|
166
|
+
|
|
167
|
+
guard = OpenAIAgentsGuard(agent_id="invoice-agent")
|
|
168
|
+
result = await guard.run(my_agent, "process invoice #42")
|
|
169
|
+
"""
|
|
170
|
+
|
|
171
|
+
def __init__(
|
|
172
|
+
self,
|
|
173
|
+
agent_id: str,
|
|
174
|
+
*,
|
|
175
|
+
base_url: str = "http://localhost:8000",
|
|
176
|
+
api_key: str | None = None,
|
|
177
|
+
user: str = "openai-agent",
|
|
178
|
+
block_on_block: bool = True,
|
|
179
|
+
framework: str = "OpenAI-Agents",
|
|
180
|
+
) -> None:
|
|
181
|
+
_check_tier(api_key)
|
|
182
|
+
self._agent_id = agent_id
|
|
183
|
+
self._base_url = base_url
|
|
184
|
+
self._api_key = api_key
|
|
185
|
+
self._user = user
|
|
186
|
+
self._block_on_block = block_on_block
|
|
187
|
+
self._framework = framework
|
|
188
|
+
|
|
189
|
+
async def run(self, agent: Any, input_text: str, **kwargs: Any) -> Any:
|
|
190
|
+
"""Run *agent* via openai-agents Runner and validate the result."""
|
|
191
|
+
try:
|
|
192
|
+
from agents import Runner
|
|
193
|
+
except ImportError:
|
|
194
|
+
raise RuntimeError(
|
|
195
|
+
"openai-agents package not installed. pip install openai-agents"
|
|
196
|
+
)
|
|
197
|
+
t0 = time.perf_counter()
|
|
198
|
+
result = await Runner.run(agent, input_text, **kwargs)
|
|
199
|
+
latency_ms = (time.perf_counter() - t0) * 1000
|
|
200
|
+
|
|
201
|
+
final_output, tool_calls = _extract_run_output(result)
|
|
202
|
+
_validate(
|
|
203
|
+
self._agent_id, self._base_url, self._api_key, self._user,
|
|
204
|
+
input_text, {"final_output": final_output[:2000]},
|
|
205
|
+
self._framework, self._block_on_block, latency_ms, tool_calls,
|
|
206
|
+
)
|
|
207
|
+
return result
|
|
208
|
+
|
|
209
|
+
def run_sync(self, agent: Any, input_text: str, **kwargs: Any) -> Any:
|
|
210
|
+
"""Synchronous version for non-async code."""
|
|
211
|
+
import asyncio
|
|
212
|
+
return asyncio.run(self.run(agent, input_text, **kwargs))
|
|
213
|
+
|
|
214
|
+
|
|
215
|
+
class SwarmGuard:
|
|
216
|
+
"""
|
|
217
|
+
Wraps legacy openai-swarm client.run() with AgentTrust governance.
|
|
218
|
+
|
|
219
|
+
Usage::
|
|
220
|
+
|
|
221
|
+
from swarm import Swarm
|
|
222
|
+
from agentrust_sdk.adapters.openai_agents import SwarmGuard
|
|
223
|
+
|
|
224
|
+
client = Swarm()
|
|
225
|
+
guard = SwarmGuard(agent_id="swarm-agent")
|
|
226
|
+
response = guard.run(client, agent, [{"role": "user", "content": query}])
|
|
227
|
+
"""
|
|
228
|
+
|
|
229
|
+
def __init__(
|
|
230
|
+
self,
|
|
231
|
+
agent_id: str,
|
|
232
|
+
*,
|
|
233
|
+
base_url: str = "http://localhost:8000",
|
|
234
|
+
api_key: str | None = None,
|
|
235
|
+
user: str = "swarm",
|
|
236
|
+
block_on_block: bool = True,
|
|
237
|
+
framework: str = "OpenAI-Swarm",
|
|
238
|
+
) -> None:
|
|
239
|
+
_check_tier(api_key)
|
|
240
|
+
self._agent_id = agent_id
|
|
241
|
+
self._base_url = base_url
|
|
242
|
+
self._api_key = api_key
|
|
243
|
+
self._user = user
|
|
244
|
+
self._block_on_block = block_on_block
|
|
245
|
+
self._framework = framework
|
|
246
|
+
|
|
247
|
+
def run(self, swarm_client: Any, agent: Any, messages: list, **kwargs: Any) -> Any:
|
|
248
|
+
t0 = time.perf_counter()
|
|
249
|
+
response = swarm_client.run(agent, messages, **kwargs)
|
|
250
|
+
latency_ms = (time.perf_counter() - t0) * 1000
|
|
251
|
+
|
|
252
|
+
last_user = next(
|
|
253
|
+
(m.get("content", "") for m in reversed(messages) if m.get("role") == "user"), ""
|
|
254
|
+
)
|
|
255
|
+
final_messages = getattr(response, "messages", []) or []
|
|
256
|
+
last_content = final_messages[-1].get("content", "") if final_messages else ""
|
|
257
|
+
|
|
258
|
+
_validate(
|
|
259
|
+
self._agent_id, self._base_url, self._api_key, self._user,
|
|
260
|
+
str(last_user)[:500], {"content": last_content[:2000]},
|
|
261
|
+
self._framework, self._block_on_block, latency_ms,
|
|
262
|
+
)
|
|
263
|
+
return response
|
agentrust_sdk/auth.py
ADDED
|
@@ -0,0 +1,192 @@
|
|
|
1
|
+
"""
|
|
2
|
+
API key resolution and tier detection.
|
|
3
|
+
|
|
4
|
+
Key resolution order (first found wins):
|
|
5
|
+
1. AGENTRUST_KEY env var
|
|
6
|
+
2. .agentrust/config.yaml (project-level)
|
|
7
|
+
3. ~/.agentrust/config.yaml (global)
|
|
8
|
+
4. Passed explicitly to client/decorator
|
|
9
|
+
|
|
10
|
+
No key found → Tier.OSS (schema validation only, no telemetry).
|
|
11
|
+
|
|
12
|
+
API keys are signed JWTs (HS256) containing:
|
|
13
|
+
{ "sub": "<org_id>", "tier": "team", "iat": ..., "exp": ... }
|
|
14
|
+
|
|
15
|
+
For local/dev keys issued by `agentrust init --local`, the key is a
|
|
16
|
+
plain opaque string and tier defaults to FREE.
|
|
17
|
+
"""
|
|
18
|
+
from __future__ import annotations
|
|
19
|
+
|
|
20
|
+
import base64
|
|
21
|
+
import json
|
|
22
|
+
import logging
|
|
23
|
+
import os
|
|
24
|
+
from pathlib import Path
|
|
25
|
+
from typing import Any
|
|
26
|
+
|
|
27
|
+
import yaml
|
|
28
|
+
|
|
29
|
+
from .tiers import Tier
|
|
30
|
+
|
|
31
|
+
logger = logging.getLogger(__name__)
|
|
32
|
+
|
|
33
|
+
# Environment variable consulted for an API key (after an explicit key).
_ENV_VAR = "AGENTRUST_KEY"
# File and directory names shared by the project-level and global config.
_CONFIG_FILENAME = "config.yaml"
_CONFIG_DIR = ".agentrust"
# Global per-user config path; Path.home() is resolved once, at import time.
_GLOBAL_CONFIG = Path.home() / _CONFIG_DIR / _CONFIG_FILENAME
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
# ---------------------------------------------------------------------------
|
|
40
|
+
# Config file helpers
|
|
41
|
+
# ---------------------------------------------------------------------------
|
|
42
|
+
|
|
43
|
+
def _project_config() -> Path | None:
    """Return the nearest ``.agentrust/config.yaml`` at or above the cwd.

    The search starts in the current working directory and climbs through
    every parent up to the filesystem root. Returns ``None`` when no such file
    exists anywhere on that path.
    """
    start = Path.cwd()
    search_dirs = (start, *start.parents)
    config_paths = (directory / _CONFIG_DIR / _CONFIG_FILENAME for directory in search_dirs)
    return next((path for path in config_paths if path.exists()), None)
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def _read_config(path: Path) -> dict[str, Any]:
    """Load the YAML config at *path* as a dict, returning ``{}`` on any problem.

    The following all yield ``{}``, so callers can use ``.get`` on the result
    unconditionally:

    * unreadable files;
    * invalid YAML;
    * empty files;
    * files whose top level is not a mapping (e.g. a bare list or string).

    The file is read as UTF-8.
    """
    try:
        data = yaml.safe_load(path.read_text(encoding="utf-8"))
    except Exception as exc:
        # Best-effort: a broken config must not crash key resolution, but
        # leave a trace for debugging instead of failing silently.
        logger.debug("Ignoring unreadable config %s: %s", path, exc)
        return {}
    return data if isinstance(data, dict) else {}
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def _write_config(path: Path, data: dict[str, Any]) -> None:
    """Write *data* as YAML to *path*, readable and writable by the owner only.

    Parent directories are created as needed. A new file is created with mode
    0o600 up front, rather than written first and chmod-ed afterwards, so the
    API key is never momentarily readable by other users. An existing file is
    truncated and re-restricted to 0o600 before any content is written.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    # Serialise first so a dump error cannot leave a truncated file behind.
    text = yaml.dump(data, default_flow_style=False)
    fd = os.open(path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
    with os.fdopen(fd, "w", encoding="utf-8") as fh:
        # os.open's mode only applies on creation; tighten pre-existing files too.
        path.chmod(0o600)
        fh.write(text)
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
# ---------------------------------------------------------------------------
|
|
68
|
+
# JWT decode (no signature verification in SDK — server verifies)
|
|
69
|
+
# ---------------------------------------------------------------------------
|
|
70
|
+
|
|
71
|
+
def _decode_jwt_payload(token: str) -> dict[str, Any]:
    """
    Decode JWT payload without verifying signature.

    The control plane verifies; the SDK only reads the tier claim.

    Returns the payload claims as a dict. Returns ``{}`` when any of these
    hold:

    * *token* is not a three-part dotted string;
    * the middle segment is not valid base64url/JSON;
    * the decoded JSON is not an object.
    """
    try:
        parts = token.split(".")
        if len(parts) != 3:
            return {}
        segment = parts[1]
        # JWT segments are unpadded base64url. Restore exactly the missing '='
        # (``-len % 4`` is 0 when already aligned, unlike ``4 - len % 4``).
        segment += "=" * (-len(segment) % 4)
        claims = json.loads(base64.urlsafe_b64decode(segment))
    except Exception:
        # Deliberately broad: callers rely on this never raising.
        return {}
    return claims if isinstance(claims, dict) else {}
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
# ---------------------------------------------------------------------------
|
|
90
|
+
# Public API
|
|
91
|
+
# ---------------------------------------------------------------------------
|
|
92
|
+
|
|
93
|
+
class KeyInfo:
    """Resolved API key together with the tier and org derived from it.

    ``source`` records where the key came from. Expected values are
    ``"env"``, ``"project_config"``, ``"global_config"``, ``"explicit"``,
    ``"none"``, or ``"save"`` when the key is parsed while being saved.
    """

    __slots__ = ("key", "tier", "org_id", "source")

    def __init__(self, key: str | None, tier: Tier, org_id: str, source: str) -> None:
        self.key, self.tier, self.org_id, self.source = key, tier, org_id, source

    def __repr__(self) -> str:
        # The key itself is intentionally left out of the repr.
        fields = (
            f"tier={self.tier.value}",
            f"org={self.org_id}",
            f"source={self.source}",
        )
        return "KeyInfo(" + ", ".join(fields) + ")"
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
# Fallback returned by resolve_key when no usable key is found. The same
# instance is returned every time, so treat it as read-only.
_OSS_KEY_INFO = KeyInfo(key=None, tier=Tier.OSS, org_id="anonymous", source="none")
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def _config_api_key(path: Path | None) -> str | None:
    """Return the stripped, non-empty string ``api_key`` from the config at *path*, else ``None``."""
    if path is None:
        return None
    try:
        if not path.exists():
            return None
        cfg = _read_config(path)
    except OSError:
        # e.g. PermissionError from exists(): treat an inaccessible config as "no key".
        return None
    raw = cfg.get("api_key") if isinstance(cfg, dict) else None
    if not isinstance(raw, str):
        # Missing or malformed (e.g. a number or list in YAML): ignore it.
        return None
    return raw.strip() or None


def resolve_key(explicit_key: str | None = None) -> KeyInfo:
    """
    Resolve API key and decode tier. Never raises — falls back to OSS on any error.

    Candidates are tried in this priority order:

    1. *explicit_key*;
    2. the ``AGENTRUST_KEY`` environment variable;
    3. the nearest project ``.agentrust/config.yaml``;
    4. the global ``~/.agentrust/config.yaml``.

    The first candidate that ``_parse_key`` accepts wins; if none does, the
    shared OSS-tier ``KeyInfo`` is returned. Malformed config values are
    ignored rather than raising: a non-string or blank ``api_key``, or a
    config whose top level is not a mapping.
    """
    candidates: list[tuple[str, str]] = []  # (key, source)

    if explicit_key:
        candidates.append((explicit_key, "explicit"))

    env_key = os.environ.get(_ENV_VAR, "").strip()
    if env_key:
        candidates.append((env_key, "env"))

    try:
        project_cfg_path = _project_config()
    except OSError:
        # Path.cwd() fails if the working directory no longer exists, and
        # exists() may hit permission errors while walking up.
        project_cfg_path = None

    for cfg_path, source in ((project_cfg_path, "project_config"), (_GLOBAL_CONFIG, "global_config")):
        cfg_key = _config_api_key(cfg_path)
        if cfg_key:
            candidates.append((cfg_key, source))

    for key, source in candidates:
        try:
            info = _parse_key(key, source)
        except Exception as exc:
            # Honour the "never raises" contract; try the next candidate.
            logger.debug("Ignoring unparseable %s API key: %s", source, exc)
            continue
        if info is not None:
            return info

    return _OSS_KEY_INFO
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
def _parse_key(key: str, source: str) -> KeyInfo | None:
    """Build a ``KeyInfo`` for *key*; return ``None`` if it is empty or not a string.

    Keys containing exactly two dots are treated as JWTs:

    * the ``sub`` claim becomes ``org_id`` (``"unknown"`` when absent or null);
    * the ``tier`` claim is mapped onto ``Tier``, falling back to
      ``Tier.FREE`` when it is missing or unrecognised.

    Any other key is treated as a plain local key with ``Tier.FREE`` and
    ``org_id="local"``.
    """
    if not isinstance(key, str) or not key:
        return None

    # JWT keys have 2 dots.
    # The payload tier is used for client-side capability hints (UX only — early errors,
    # upgrade prompts). The gateway verifies the JWT signature server-side before
    # granting any access, so a forged tier in the payload is harmless here.
    if key.count(".") == 2:
        payload = _decode_jwt_payload(key)
        if not isinstance(payload, dict):
            # Defensive: a non-object payload carries no usable claims.
            payload = {}
        sub = payload.get("sub")
        org_id = "unknown" if sub is None else str(sub)
        try:
            tier = Tier(payload.get("tier", "free"))
        except (ValueError, TypeError):
            tier = Tier.FREE
        return KeyInfo(key=key, tier=tier, org_id=org_id, source=source)

    # Plain opaque key → FREE tier (issued by agentrust init --local)
    return KeyInfo(key=key, tier=Tier.FREE, org_id="local", source=source)
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
# ---------------------------------------------------------------------------
|
|
164
|
+
# Config file management (used by CLI)
|
|
165
|
+
# ---------------------------------------------------------------------------
|
|
166
|
+
|
|
167
|
+
def save_key_to_config(api_key: str, scope: str = "global") -> Path:
    """Persist *api_key* to a config file and return that file's path.

    ``scope="project"`` targets ``./.agentrust/config.yaml`` under the current
    working directory; any other value targets the global
    ``~/.agentrust/config.yaml``. Other entries already in the file are kept.
    The decoded ``tier`` and ``org_id`` are written alongside the key for
    humans reading the file.
    """
    target = Path.cwd() / _CONFIG_DIR / _CONFIG_FILENAME if scope == "project" else _GLOBAL_CONFIG

    config = _read_config(target) if target.exists() else {}
    config["api_key"] = api_key

    parsed = _parse_key(api_key, "save")
    if parsed is not None:
        config.update(tier=parsed.tier.value, org_id=parsed.org_id)

    _write_config(target, config)
    logger.info("API key saved to %s", target)
    return target
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
def read_config(scope: str = "global") -> dict[str, Any]:
    """Return the parsed config for *scope*, or ``{}`` if there is none.

    ``scope="project"`` reads the nearest ``.agentrust/config.yaml`` found by
    walking up from the cwd; any other value reads ``~/.agentrust/config.yaml``.
    """
    if scope != "project":
        if not _GLOBAL_CONFIG.exists():
            return {}
        return _read_config(_GLOBAL_CONFIG)

    found = _project_config()
    if found is None:
        return {}
    return _read_config(found)
|