agentauthlayer 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agent_auth/__init__.py +48 -0
- agent_auth/__main__.py +4 -0
- agent_auth/agents.py +40 -0
- agent_auth/audit.py +51 -0
- agent_auth/auth.py +36 -0
- agent_auth/cli.py +28 -0
- agent_auth/client.py +107 -0
- agent_auth/context.py +21 -0
- agent_auth/core.py +638 -0
- agent_auth/delegation.py +15 -0
- agent_auth/exceptions.py +72 -0
- agent_auth/models.py +72 -0
- agent_auth/policy.py +296 -0
- agent_auth/policy_service.py +176 -0
- agent_auth/principals.py +44 -0
- agent_auth/registry.py +90 -0
- agent_auth/session.py +135 -0
- agent_auth/storage.py +536 -0
- agent_auth/tokens.py +92 -0
- agent_auth/users.py +173 -0
- agentauthlayer-0.1.0.dist-info/METADATA +131 -0
- agentauthlayer-0.1.0.dist-info/RECORD +24 -0
- agentauthlayer-0.1.0.dist-info/WHEEL +5 -0
- agentauthlayer-0.1.0.dist-info/top_level.txt +1 -0
agent_auth/core.py
ADDED
|
@@ -0,0 +1,638 @@
|
|
|
1
|
+
"""agent_auth.core — the main ``AgentAuth`` class."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import json
|
|
6
|
+
import secrets
|
|
7
|
+
from datetime import datetime, timedelta, timezone
|
|
8
|
+
from functools import wraps
|
|
9
|
+
from typing import Any, Callable
|
|
10
|
+
from uuid import uuid4
|
|
11
|
+
|
|
12
|
+
from jose import ExpiredSignatureError, JWTError, jwt
|
|
13
|
+
|
|
14
|
+
from agent_auth.audit import AuditBackend, AuditLogger
|
|
15
|
+
from agent_auth.exceptions import (
|
|
16
|
+
AgentNotFoundError,
|
|
17
|
+
DuplicateAgentError,
|
|
18
|
+
DuplicateUserError,
|
|
19
|
+
InvalidCredentialsError,
|
|
20
|
+
ScopeError,
|
|
21
|
+
TokenExpiredError,
|
|
22
|
+
TokenInvalidError,
|
|
23
|
+
TokenMissingError,
|
|
24
|
+
TokenRevokedError,
|
|
25
|
+
TokenTypeMismatchError,
|
|
26
|
+
UserNotFoundError,
|
|
27
|
+
)
|
|
28
|
+
from agent_auth.models import Agent, TokenClaims, User, UserClaims
|
|
29
|
+
from agent_auth.policy import check_scope
|
|
30
|
+
from agent_auth.registry import ToolNotRegisteredError, ToolPolicy, ToolRegistry
|
|
31
|
+
from agent_auth.storage import (
|
|
32
|
+
AgentStore,
|
|
33
|
+
AuditStore,
|
|
34
|
+
TokenStore,
|
|
35
|
+
ToolStore,
|
|
36
|
+
UserStore,
|
|
37
|
+
create_db_engine,
|
|
38
|
+
)
|
|
39
|
+
from agent_auth.tokens import issue_refresh_token, issue_token, verify_token
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class AgentAuth:
    """Main entry-point for the agent_auth library.

    Handles agents, human users, JWT tokens (access + refresh),
    tool registry, scope enforcement, and audit logging.

    Usage::

        auth = AgentAuth(secret_key="your-secret")

        # --- agents ---
        auth.register_agent("email-bot", scopes=["send_email"], owner="team-a")
        token = auth.issue_token("email-bot")

        @auth.require_scope("send_email")
        def send_email(to, body, **ctx): ...

        # --- users ---
        user = auth.register_user("alice@co.com", "s3cret")
        user_token = auth.issue_user_token(user)
        claims = auth.verify_user_token(user_token)

        # --- refresh tokens ---
        refresh = auth.issue_refresh_token("email-bot")
        new_access, new_refresh = auth.rotate_refresh_token(refresh)
    """

    def __init__(
        self,
        secret_key: str,
        algorithm: str = "HS256",
        token_expire_minutes: int = 30,
        refresh_token_expire_days: int = 7,
        audit_backend: AuditBackend | None = None,
        database_url: str | None = None,
        store: AgentStore | None = None,
    ) -> None:
        """Configure token signing and wire up the persistence stores.

        Args:
            secret_key: Key used to sign and verify every JWT.
            algorithm: JWT signing algorithm.
            token_expire_minutes: Lifetime of access tokens and user tokens.
            refresh_token_expire_days: Lifetime of refresh tokens.
            audit_backend: Backend handed to ``AuditLogger``.
            database_url: Passed to ``create_db_engine`` when *store* is
                not supplied.
            store: Pre-built ``AgentStore``. When omitted, one is created
                from *database_url*. Token, tool, audit and user stores all
                share this store's engine.
        """
        self._secret = secret_key
        self._algorithm = algorithm
        self._expire_minutes = token_expire_minutes
        self._refresh_expire_days = refresh_token_expire_days

        if store is None:
            store = AgentStore(create_db_engine(database_url))
        self._store = store
        self._store.init()

        engine = self._store.engine
        self._token_store = TokenStore(engine)
        self._tool_store = ToolStore(engine)
        self._audit_store = AuditStore(engine)
        self._user_store = UserStore(engine)

        # In-process cache of agents keyed by agent_id.
        self._agents: dict[str, Agent] = {}
        # Revoked jtis seen by this instance; checked in addition to the
        # token store so revocations take effect immediately in-process.
        self._revoked_jtis: set[str] = set()
        self._tool_registry = ToolRegistry()
        self._audit = AuditLogger(audit_backend)

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _as_utc(value: datetime) -> datetime:
        """Return *value* as an aware datetime, treating naive values as UTC."""
        if value.tzinfo is None:
            return value.replace(tzinfo=timezone.utc)
        return value

    def _mark_revoked(self, jti: str) -> None:
        """Revoke *jti* both in memory and in the token store."""
        self._revoked_jtis.add(jti)
        self._token_store.revoke(jti)

    def _ensure_not_revoked(self, jti: str) -> None:
        """Raise ``TokenRevokedError`` if *jti* is revoked in memory or in the store."""
        if jti in self._revoked_jtis or self._token_store.is_revoked(jti):
            raise TokenRevokedError(jti)

    def _authenticate_call(self, kwargs: dict[str, Any]) -> TokenClaims:
        """Pop ``token`` from a wrapped call's *kwargs* and verify it.

        *kwargs* is mutated in place so the wrapped function never receives
        the raw token.

        Raises:
            TokenMissingError: no ``token`` keyword argument was supplied.
            Any error raised by :meth:`verify`.
        """
        token = kwargs.pop("token", None)
        if token is None:
            raise TokenMissingError()
        return self.verify(token)

    # ------------------------------------------------------------------
    # Agent registry
    # ------------------------------------------------------------------

    def register_agent(
        self,
        agent_id: str,
        scopes: list[str] | None = None,
        metadata: dict | None = None,
        name: str | None = None,
        owner: str | None = None,
        role: str | None = None,
        *,
        allow_update: bool = False,
    ) -> Agent:
        """Register an agent identity.

        Args:
            agent_id: Identifier of the agent.
            scopes: Scopes granted to the agent; ``None`` stores an empty list.
            metadata, name, owner, role: Optional fields stored with the agent.
            allow_update: Overwrite an existing agent instead of raising.

        Returns:
            The ``Agent`` returned by the store.

        Raises:
            DuplicateAgentError: *agent_id* already exists and
                *allow_update* is False.
        """
        existing = self._store.get_agent(agent_id)
        if existing is not None and not allow_update:
            raise DuplicateAgentError(agent_id)

        agent = self._store.upsert_agent(
            agent_id,
            scopes or [],
            name=name,
            owner=owner,
            role=role,
            metadata=metadata,
        )
        self._agents[agent_id] = agent
        self._audit.record("agent_registered", agent_id=agent_id, scopes=agent.scopes)
        return agent

    def get_agent(self, agent_id: str) -> Agent:
        """Fetch a registered agent (cache first, then store).

        Raises:
            AgentNotFoundError: the store has no such agent.
        """
        if agent_id in self._agents:
            return self._agents[agent_id]

        agent = self._store.get_agent(agent_id)
        if agent is None:
            raise AgentNotFoundError(agent_id)
        self._agents[agent_id] = agent
        return agent

    def list_agents(self) -> list[Agent]:
        """Return all registered agents, read from the store."""
        return self._store.list_agents()

    def delete_agent(self, agent_id: str) -> bool:
        """Delete an agent. Returns True if deleted, False if it didn't exist."""
        deleted = self._store.delete_agent(agent_id)
        self._agents.pop(agent_id, None)
        if deleted:
            self._audit.record("agent_deleted", agent_id=agent_id)
            self._audit_store.log("agent_deleted", agent_id=agent_id)
        return deleted

    # ------------------------------------------------------------------
    # Access token lifecycle
    # ------------------------------------------------------------------

    def issue_token(self, agent_id: str) -> str:
        """Issue a signed access JWT for a registered agent.

        The token is recorded in the token store with the expiry decoded
        from the JWT itself.

        Returns:
            The encoded JWT string.

        Raises:
            AgentNotFoundError: *agent_id* is not registered.
        """
        agent = self.get_agent(agent_id)
        token, jti = issue_token(
            agent_id=agent.agent_id,
            scopes=agent.scopes,
            secret_key=self._secret,
            algorithm=self._algorithm,
            expire_minutes=self._expire_minutes,
        )
        claims = verify_token(token, self._secret, self._algorithm)
        self._token_store.save(
            jti=jti,
            agent_id=agent_id,
            scopes=agent.scopes,
            expires_at=claims.expires_at,
            token_type="access",
        )
        self._audit.record("token_issued", agent_id=agent_id, jti=jti)
        self._audit_store.log("token_issued", agent_id=agent_id, details={"jti": jti})
        return token

    def verify(self, token: str) -> TokenClaims:
        """Verify an access JWT and return decoded claims.

        Raises on expiry, bad signature, wrong type, or revoked token.
        """
        claims = verify_token(token, self._secret, self._algorithm)
        if claims.token_type != "access":
            raise TokenTypeMismatchError("access", claims.token_type)
        self._ensure_not_revoked(claims.jti)
        return claims

    def revoke_token(self, token: str) -> str:
        """Revoke a token by its encoded JWT. Returns the revoked jti.

        The token must still pass ``verify_token`` (signature and expiry)
        to be revoked this way; use :meth:`revoke_jti` otherwise.
        """
        claims = verify_token(token, self._secret, self._algorithm)
        self._mark_revoked(claims.jti)
        self._audit.record("token_revoked", agent_id=claims.agent_id, jti=claims.jti)
        self._audit_store.log(
            "token_revoked",
            agent_id=claims.agent_id,
            details={"jti": claims.jti},
        )
        return claims.jti

    def revoke_jti(self, jti: str) -> None:
        """Revoke a token directly by its jti."""
        self._mark_revoked(jti)
        self._audit.record("token_revoked_by_jti", jti=jti)
        self._audit_store.log("token_revoked_by_jti", details={"jti": jti})

    def introspect_token(self, jti: str) -> TokenClaims | None:
        """Return claims for an active non-expired token by its jti, or None.

        Works for both access and refresh tokens. Datetimes read back from
        the store without tzinfo are treated as UTC, so both ``issued_at``
        and ``expires_at`` in the result are timezone-aware.
        """
        row = self._token_store.get(jti)
        if row is None or row.status != "active":
            return None
        expires = self._as_utc(row.expires_at)
        if expires <= datetime.now(timezone.utc):
            return None
        issued = row.issued_at
        if issued is not None:
            issued = self._as_utc(issued)
        scopes = json.loads(row.scopes_json) if row.scopes_json else []
        return TokenClaims(
            agent_id=row.agent_id,
            jti=jti,
            scopes=scopes,
            issued_at=issued,
            expires_at=expires,
            token_type=row.token_type,
        )

    # ------------------------------------------------------------------
    # Refresh token lifecycle
    # ------------------------------------------------------------------

    def issue_refresh_token(self, agent_id: str) -> str:
        """Issue a signed refresh JWT for a registered agent.

        The stored expiry is taken from the decoded token (as in
        :meth:`issue_token`) so it always matches the JWT's ``exp`` claim
        instead of a separately computed clock reading.

        Returns:
            The encoded JWT string.

        Raises:
            AgentNotFoundError: *agent_id* is not registered.
        """
        agent = self.get_agent(agent_id)
        token, jti = issue_refresh_token(
            agent_id=agent.agent_id,
            scopes=agent.scopes,
            secret_key=self._secret,
            algorithm=self._algorithm,
            expire_days=self._refresh_expire_days,
        )
        claims = verify_token(token, self._secret, self._algorithm)
        self._token_store.save(
            jti=jti,
            agent_id=agent_id,
            scopes=agent.scopes,
            expires_at=claims.expires_at,
            token_type="refresh",
        )
        self._audit.record("refresh_token_issued", agent_id=agent_id, jti=jti)
        self._audit_store.log(
            "refresh_token_issued",
            agent_id=agent_id,
            details={"jti": jti},
        )
        return token

    def verify_refresh_token(self, token: str) -> TokenClaims:
        """Verify a refresh JWT and return decoded claims.

        Raises ``TokenTypeMismatchError`` if the token is not a refresh token.
        Raises on expiry, bad signature, or revoked token.
        """
        claims = verify_token(token, self._secret, self._algorithm)
        if claims.token_type != "refresh":
            raise TokenTypeMismatchError("refresh", claims.token_type)
        self._ensure_not_revoked(claims.jti)
        return claims

    def rotate_refresh_token(self, refresh_token: str) -> tuple[str, str]:
        """Validate the refresh token, revoke it, and issue a new access + refresh pair.

        Returns:
            (new_access_token, new_refresh_token)

        The old refresh token is revoked (one-time use) before the new pair
        is issued. Raises on invalid or already-revoked refresh token.
        """
        claims = self.verify_refresh_token(refresh_token)
        self._mark_revoked(claims.jti)
        new_access = self.issue_token(claims.agent_id)
        new_refresh = self.issue_refresh_token(claims.agent_id)
        self._audit.record(
            "refresh_token_rotated",
            agent_id=claims.agent_id,
            old_jti=claims.jti,
        )
        return new_access, new_refresh

    # ------------------------------------------------------------------
    # User management
    # ------------------------------------------------------------------

    def register_user(
        self,
        email: str,
        password: str,
        name: str | None = None,
        role: str = "user",
    ) -> User:
        """Register a new human user.

        Raises ``DuplicateUserError`` if the email is already registered.
        """
        if self._user_store.get_by_email(email) is not None:
            raise DuplicateUserError(email)
        user = self._user_store.create_user(email, password, name=name, role=role)
        self._audit.record("user_registered", user_id=user.id, email=user.email)
        self._audit_store.log("user_registered", details={"user_id": user.id, "email": user.email})
        return user

    def authenticate_user(self, email: str, password: str) -> User:
        """Verify email + password.

        Returns the ``User`` on success.
        Raises ``InvalidCredentialsError`` on failure.
        """
        user = self._user_store.authenticate(email, password)
        if user is None:
            raise InvalidCredentialsError()
        self._audit.record("user_authenticated", user_id=user.id)
        return user

    def get_user(self, user_id: str) -> User:
        """Fetch a user by id or raise ``UserNotFoundError``."""
        user = self._user_store.get_by_id(user_id)
        if user is None:
            raise UserNotFoundError(user_id)
        return user

    def get_user_by_email(self, email: str) -> User:
        """Fetch a user by email or raise ``UserNotFoundError``."""
        user = self._user_store.get_by_email(email)
        if user is None:
            raise UserNotFoundError(email)
        return user

    def issue_user_token(self, user: User) -> str:
        """Issue a signed JWT for a human user.

        The token carries ``type: "user"`` so it cannot be used as an agent
        bearer token. It is not recorded in the token store.

        Returns the encoded JWT string.
        """
        now = datetime.now(timezone.utc)
        exp = now + timedelta(minutes=self._expire_minutes)
        jti = str(uuid4())
        payload = {
            "sub": user.id,
            "email": user.email,
            "role": user.role,
            "type": "user",
            "jti": jti,
            "iat": now,
            "exp": exp,
        }
        token = jwt.encode(payload, self._secret, algorithm=self._algorithm)
        self._audit.record("user_token_issued", user_id=user.id, jti=jti)
        return token

    def verify_user_token(self, token: str) -> UserClaims:
        """Verify a user JWT and return decoded claims.

        Raises:
            TokenExpiredError: the token's ``exp`` has passed.
            TokenInvalidError: bad signature/format, or a correctly signed
                payload that lacks (or has malformed) required claims.
            TokenTypeMismatchError: this is not a user token.
        """
        try:
            payload = jwt.decode(token, self._secret, algorithms=[self._algorithm])
        except ExpiredSignatureError as exc:
            raise TokenExpiredError() from exc
        except JWTError as exc:
            raise TokenInvalidError(str(exc)) from exc

        if payload.get("type") != "user":
            raise TokenTypeMismatchError("user", payload.get("type", "unknown"))

        # A correctly signed payload could still miss a claim; surface that as
        # an invalid token rather than leaking a bare KeyError to callers.
        try:
            user_id = payload["sub"]
            email = payload["email"]
            role = payload["role"]
            jti = payload["jti"]
            issued_at = datetime.fromtimestamp(payload["iat"], tz=timezone.utc)
            expires_at = datetime.fromtimestamp(payload["exp"], tz=timezone.utc)
        except (KeyError, TypeError, ValueError, OverflowError, OSError) as exc:
            raise TokenInvalidError(f"malformed user token: {exc!r}") from exc

        return UserClaims(
            user_id=user_id,
            email=email,
            role=role,
            jti=jti,
            issued_at=issued_at,
            expires_at=expires_at,
            raw=payload,
        )

    def request_password_reset(self, email: str) -> str | None:
        """Generate a password reset token for the given email.

        Returns the reset token string, or None if the email is not registered.
        The caller is responsible for delivering the token (email, SMS, etc.).
        The token is stored with a one-hour expiry.

        Returns None silently for unknown emails to prevent enumeration.
        """
        user = self._user_store.get_by_email(email)
        if user is None:
            return None
        token = secrets.token_urlsafe(32)
        expires_at = datetime.now(timezone.utc) + timedelta(hours=1)
        self._user_store.save_reset_token(token, user.id, expires_at)
        self._audit.record("password_reset_requested", user_id=user.id)
        return token

    def reset_password(self, reset_token: str, new_password: str) -> bool:
        """Apply a new password using a reset token.

        Returns True on success, False if the token is invalid/expired/used.
        """
        user_id = self._user_store.validate_reset_token(reset_token)
        if user_id is None:
            return False
        self._user_store.update_password(user_id, new_password)
        self._user_store.mark_reset_token_used(reset_token)
        self._audit.record("password_reset_completed", user_id=user_id)
        self._audit_store.log("password_reset_completed", details={"user_id": user_id})
        return True

    # ------------------------------------------------------------------
    # Decorators
    # ------------------------------------------------------------------

    def require_scope(self, scope: str) -> Callable:
        """Decorator that secures a tool function behind a scope check.

        The caller **must** pass ``token=<jwt>`` as a keyword argument.
        On success the decorator injects ``claims`` into kwargs.

        Example::

            @auth.require_scope("send_email")
            def send_email(to, body, **ctx):
                caller = ctx["claims"].agent_id
                ...

            send_email(to="a@b.com", body="hi", token=my_jwt)
        """
        def decorator(func: Callable) -> Callable:
            @wraps(func)
            def wrapper(*args: Any, **kwargs: Any) -> Any:
                claims = self._authenticate_call(kwargs)
                check_scope(claims.scopes, scope)
                self._audit.record(
                    "tool_call_authorized",
                    agent_id=claims.agent_id,
                    tool=func.__qualname__,
                    scope=scope,
                )
                kwargs["claims"] = claims
                return func(*args, **kwargs)

            return wrapper

        return decorator

    def require_scopes(self, *scopes: str) -> Callable:
        """Like ``require_scope`` but checks all listed scopes."""
        from agent_auth.policy import check_scopes as _check_scopes

        def decorator(func: Callable) -> Callable:
            @wraps(func)
            def wrapper(*args: Any, **kwargs: Any) -> Any:
                claims = self._authenticate_call(kwargs)
                _check_scopes(claims.scopes, list(scopes))
                self._audit.record(
                    "tool_call_authorized",
                    agent_id=claims.agent_id,
                    tool=func.__qualname__,
                    scopes=list(scopes),
                )
                kwargs["claims"] = claims
                return func(*args, **kwargs)

            return wrapper

        return decorator

    # ------------------------------------------------------------------
    # Tool registry
    # ------------------------------------------------------------------

    def register_tool(
        self,
        name: str,
        scope: str,
        description: str = "",
        *,
        allow_update: bool = False,
    ) -> ToolPolicy:
        """Register a tool and the scope it requires.

        The policy is added to the in-memory registry and upserted into the
        tool store.
        """
        policy = self._tool_registry.register(
            name, scope, description, allow_update=allow_update,
        )
        self._tool_store.upsert(name, scope, description)
        self._audit.record("tool_registered", tool=name, scope=scope)
        self._audit_store.log("tool_registered", tool_name=name, details={"scope": scope})
        return policy

    def authorize_tool_call(self, token: str, tool_name: str) -> TokenClaims:
        """Verify *token* and check that it carries the scope required by *tool_name*.

        Returns ``TokenClaims`` on success. Raises on any failure.
        """
        claims = self.verify(token)
        policy = self._tool_registry.get(tool_name)
        check_scope(claims.scopes, policy.required_scope)
        self._audit.record(
            "tool_call_authorized",
            agent_id=claims.agent_id,
            tool=tool_name,
            scope=policy.required_scope,
        )
        return claims

    def list_tools(self) -> list[ToolPolicy]:
        """Return all registered tool policies."""
        return self._tool_registry.list_tools()

    def list_agent_tools(self, agent_id: str) -> list[ToolPolicy]:
        """Return tools that *agent_id* is allowed to call."""
        agent = self.get_agent(agent_id)
        return self._tool_registry.tools_for_scopes(agent.scopes)

    def grant_tool_to_agent(self, agent_id: str, tool_name: str) -> Agent:
        """Grant an agent the scope required by *tool_name*.

        Raises ``AgentNotFoundError`` if the agent is unknown.
        """
        policy = self._tool_registry.get(tool_name)
        self.get_agent(agent_id)  # existence check only
        stored = self._store.grant_scope(agent_id, policy.required_scope)
        self._agents[agent_id] = stored
        self._audit.record(
            "tool_granted",
            agent_id=agent_id,
            tool=tool_name,
            scope=policy.required_scope,
        )
        return stored

    def revoke_tool_from_agent(self, agent_id: str, tool_name: str) -> Agent:
        """Remove the scope required by *tool_name* from an agent.

        Raises ``AgentNotFoundError`` if the agent is unknown, matching
        :meth:`grant_tool_to_agent` (and preventing a missing agent from
        being cached).
        """
        policy = self._tool_registry.get(tool_name)
        self.get_agent(agent_id)  # existence check only
        stored = self._store.revoke_scope(agent_id, policy.required_scope)
        self._agents[agent_id] = stored
        self._audit.record(
            "tool_revoked",
            agent_id=agent_id,
            tool=tool_name,
            scope=policy.required_scope,
        )
        return stored

    def tool(self, name: str, scope: str, description: str = "") -> Callable:
        """Decorator that registers a tool AND enforces its scope.

        Combines ``register_tool`` + ``require_scope`` in one step::

            @auth.tool("send_email", scope="email:send")
            def send_email(to, body, **ctx): ...

            send_email(to="a@b.com", body="hi", token=token)

        Registration happens immediately (overwriting any existing policy
        for *name*), not when the decorated function is first called.
        """
        self._tool_registry.register(name, scope, description, allow_update=True)
        self._tool_store.upsert(name, scope, description)

        def decorator(func: Callable) -> Callable:
            @wraps(func)
            def wrapper(*args: Any, **kwargs: Any) -> Any:
                claims = self._authenticate_call(kwargs)
                check_scope(claims.scopes, scope)
                self._audit.record(
                    "tool_call_authorized",
                    agent_id=claims.agent_id,
                    tool=name,
                    scope=scope,
                )
                kwargs["claims"] = claims
                return func(*args, **kwargs)

            return wrapper

        return decorator

    def secure(self, func: Callable) -> Callable:
        """Decorator that requires a valid token but no specific scope.

        The caller must pass ``token=<jwt>``; verified ``claims`` are
        injected into kwargs.
        """

        @wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            claims = self._authenticate_call(kwargs)
            self._audit.record(
                "tool_call_authenticated",
                agent_id=claims.agent_id,
                tool=func.__qualname__,
            )
            kwargs["claims"] = claims
            return func(*args, **kwargs)

        return wrapper
|
agent_auth/delegation.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass, field
|
|
4
|
+
from datetime import datetime
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
@dataclass(frozen=True, slots=True)
|
|
8
|
+
class DelegationGrant:
|
|
9
|
+
grant_id: str
|
|
10
|
+
delegator_id: str
|
|
11
|
+
delegatee_id: str
|
|
12
|
+
action: str
|
|
13
|
+
resource: str | None = None
|
|
14
|
+
expires_at: datetime | None = None
|
|
15
|
+
context: dict[str, str] = field(default_factory=dict)
|