agentauthlayer 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
agent_auth/registry.py ADDED
@@ -0,0 +1,90 @@
1
+ """agent_auth.registry — central tool registry mapping tools → required scopes."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import dataclass
6
+
7
+ from agent_auth.exceptions import AgentAuthError
8
+
9
+
10
class ToolNotRegisteredError(AgentAuthError):
    """Signal that a requested tool name is absent from the registry.

    Attributes:
        tool_name: The name that could not be found.
    """

    def __init__(self, tool_name: str) -> None:
        message = f"Tool not registered: {tool_name}"
        # Keep the offending name on the instance so handlers can inspect it
        # without parsing the message.
        self.tool_name = tool_name
        super().__init__(message)
16
+
17
+
18
class DuplicateToolError(AgentAuthError):
    """Signal an attempt to register a tool name that is already present.

    Attributes:
        tool_name: The name that was registered twice.
    """

    def __init__(self, tool_name: str) -> None:
        message = f"Tool already registered: {tool_name}"
        # Keep the offending name on the instance so handlers can inspect it
        # without parsing the message.
        self.tool_name = tool_name
        super().__init__(message)
24
+
25
+
26
+ @dataclass(frozen=True, slots=True)
27
+ class ToolPolicy:
28
+ """Policy for a single tool."""
29
+
30
+ name: str
31
+ required_scope: str
32
+ description: str = ""
33
+
34
+
35
class ToolRegistry:
    """Central registry: tool_name → required scope.

    Usage::

        registry = ToolRegistry()
        registry.register("send_email", scope="email:send")
        registry.register("create_ticket", scope="ticket:create")

        policy = registry.get("send_email")
        # policy.required_scope == "email:send"
    """

    def __init__(self) -> None:
        # Tool name -> policy; dict ordering means listings follow
        # registration order.
        self._tools: dict[str, ToolPolicy] = {}

    def register(
        self,
        name: str,
        scope: str,
        description: str = "",
        *,
        allow_update: bool = False,
    ) -> ToolPolicy:
        """Register a tool with its required scope.

        Args:
            name: Tool name used as the registry key.
            scope: Scope required to use the tool.
            description: Optional free-form description.
            allow_update: When true, an existing entry with the same name is
                replaced instead of raising.

        Returns:
            The newly created ``ToolPolicy``.

        Raises:
            DuplicateToolError: If ``name`` is already registered and
                ``allow_update`` is false.
        """
        if name in self._tools and not allow_update:
            raise DuplicateToolError(name)
        policy = ToolPolicy(name=name, required_scope=scope, description=description)
        self._tools[name] = policy
        return policy

    def get(self, name: str) -> ToolPolicy:
        """Get tool policy or raise ``ToolNotRegisteredError``."""
        try:
            return self._tools[name]
        except KeyError:
            # The KeyError is an implementation detail of the backing dict;
            # suppress it so tracebacks show only the domain error.
            raise ToolNotRegisteredError(name) from None

    def list_tools(self) -> list[ToolPolicy]:
        """Return all registered tool policies as a new list."""
        return list(self._tools.values())

    def tools_for_scopes(self, scopes: list[str]) -> list[ToolPolicy]:
        """Return tools that the given scopes grant access to.

        Supports wildcard ``"*"`` scope (grants all tools).

        ``scopes`` may be any iterable of scope strings (including a one-shot
        generator). A bare string is treated as a single scope rather than
        being searched by substring.
        """
        # ``"email:send" in "email:send_all"`` is True for strings, which
        # would silently over-grant; wrap a lone string instead.
        if isinstance(scopes, str):
            scopes = [scopes]
        # Materialise once: safe for generators and O(1) per membership test.
        granted = frozenset(scopes)
        if "*" in granted:
            return self.list_tools()
        return [tool for tool in self._tools.values() if tool.required_scope in granted]

    def __contains__(self, name: str) -> bool:
        """Return whether ``name`` is a registered tool."""
        return name in self._tools

    def __len__(self) -> int:
        """Return the number of registered tools."""
        return len(self._tools)
agent_auth/session.py ADDED
@@ -0,0 +1,135 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from datetime import datetime, timedelta, timezone
5
+ from typing import Protocol
6
+ from uuid import uuid4
7
+
8
+ from jose import JWTError, jwt
9
+
10
+ from agent_auth.models import TokenRecord
11
+ from agent_auth.principals import Principal
12
+
13
+
14
class TokenRecordRepository(Protocol):
    """Structural interface for the store that holds ``TokenRecord`` objects."""

    def save(self, record: TokenRecord) -> TokenRecord:
        """Persist ``record`` and return a ``TokenRecord``."""

    def get(self, jti: str) -> TokenRecord | None:
        """Return the record stored under ``jti``, or ``None`` if there is none."""

    def revoke_all_for_subject(self, subject_id: str) -> int:
        """Revoke the records belonging to ``subject_id``.

        Returns an ``int`` — presumably the number of records revoked;
        confirm against the concrete implementation.
        """
23
+
24
+
25
+ @dataclass(frozen=True, slots=True)
26
+ class TokenSettings:
27
+ secret_key: str
28
+ algorithm: str = "HS256"
29
+ access_token_expire_minutes: int = 60
30
+ refresh_token_expire_days: int = 7
31
+
32
+
33
class PrincipalTokenService:
    """Core token/session lifecycle for user, agent, and system principals.

    Tokens are JWTs signed with ``settings.secret_key``. Every issued token
    also gets a ``TokenRecord`` saved through ``repo`` under its ``jti``;
    introspection and revocation operate on that record.
    """

    def __init__(self, settings: TokenSettings, repo: TokenRecordRepository):
        self.settings = settings
        self.repo = repo

    def issue_access_token(self, principal: Principal) -> tuple[str, TokenRecord]:
        """Issue an access token for ``principal``.

        Returns:
            ``(encoded_jwt, record)``. The lifetime comes from
            ``settings.access_token_expire_minutes``; no ``type`` claim is set.
        """
        lifetime = timedelta(minutes=self.settings.access_token_expire_minutes)
        return self._issue_token(principal, expires_delta=lifetime)

    def issue_refresh_token(self, principal: Principal) -> tuple[str, TokenRecord]:
        """Issue a refresh token (``type == "refresh"``) for ``principal``.

        Returns:
            ``(encoded_jwt, record)``. The lifetime comes from
            ``settings.refresh_token_expire_days``.
        """
        lifetime = timedelta(days=self.settings.refresh_token_expire_days)
        return self._issue_token(principal, expires_delta=lifetime, token_type="refresh")

    def rotate_refresh_token(
        self, raw_refresh: str, principal: Principal
    ) -> tuple[str, TokenRecord, str, TokenRecord] | None:
        """Exchange a valid refresh token for a new access/refresh pair.

        The presented refresh token is revoked before the new pair is issued.

        Returns:
            ``(access_token, access_record, refresh_token, refresh_record)``,
            or ``None`` when the token does not decode, is not of type
            ``"refresh"``, was issued to a subject other than ``principal``,
            has no ``jti``, or its record is not active and unexpired.
        """
        payload = self.decode_token(raw_refresh)
        if payload is None or payload.get("type") != "refresh":
            return None
        # Never mint tokens for one principal on the strength of a refresh
        # token that was issued to a different subject.
        if payload.get("sub") != principal.principal_id:
            return None
        old_jti = payload.get("jti")
        if not old_jti or self.introspect(old_jti) is None:
            return None
        self.revoke_token(old_jti)
        access_token, access_record = self.issue_access_token(principal)
        refresh_token, refresh_record = self.issue_refresh_token(principal)
        return access_token, access_record, refresh_token, refresh_record

    def revoke_token(self, jti: str) -> TokenRecord | None:
        """Mark the record for ``jti`` as ``"revoked"`` and save it.

        Returns:
            Whatever ``repo.get`` returned: the updated record, or ``None``
            when no record exists for ``jti``.
        """
        record = self.repo.get(jti)
        if record:
            record.status = "revoked"
            self.repo.save(record)
        return record

    def revoke_all_for_subject(self, subject_id: str) -> int:
        """Delegate to ``repo.revoke_all_for_subject`` and return its result."""
        return self.repo.revoke_all_for_subject(subject_id)

    def introspect(self, jti: str) -> TokenRecord | None:
        """Return the record for ``jti`` if it is active and not yet expired.

        Returns ``None`` when the record is missing, its status is not
        ``"active"``, or ``expires_at`` is at or before the current UTC time.
        """
        record = self.repo.get(jti)
        if not record or record.status != "active":
            return None
        expires_at = record.expires_at
        # A naive timestamp is interpreted as UTC so that it can be compared
        # with the timezone-aware "now" below.
        if expires_at.tzinfo is None:
            expires_at = expires_at.replace(tzinfo=timezone.utc)
        if expires_at <= datetime.now(timezone.utc):
            return None
        return record

    def introspect_by_token(self, raw_token: str) -> TokenRecord | None:
        """Decode ``raw_token`` and introspect the record named by its ``jti``.

        Returns ``None`` when decoding fails or the payload has no ``jti``.
        """
        payload = self.decode_token(raw_token)
        if payload is None:
            return None
        jti = payload.get("jti")
        if not jti:
            return None
        return self.introspect(jti)

    def authenticate_bearer(self, raw_token: str) -> tuple[dict, TokenRecord] | None:
        """Validate a bearer token and return its payload and live record.

        Returns:
            ``(payload, record)``, or ``None`` when decoding fails, ``sub`` or
            ``jti`` is missing, or the record is not active and unexpired.
        """
        # NOTE(review): the ``type`` claim is not inspected here, so a refresh
        # token is accepted as a bearer credential too — confirm that this is
        # intended by the callers.
        payload = self.decode_token(raw_token)
        if payload is None:
            return None
        principal_id = payload.get("sub")
        jti = payload.get("jti")
        if not principal_id or not jti:
            return None
        record = self.introspect(jti)
        if record is None:
            return None
        return payload, record

    def decode_token(self, raw_token: str) -> dict | None:
        """Decode ``raw_token`` with the configured key and algorithm.

        Returns the claims dict, or ``None`` when ``jwt.decode`` raises
        ``JWTError``.
        """
        try:
            return jwt.decode(
                raw_token,
                self.settings.secret_key,
                algorithms=[self.settings.algorithm],
            )
        except JWTError:
            return None

    def _issue_token(
        self,
        principal: Principal,
        *,
        expires_delta: timedelta,
        token_type: str | None = None,
    ) -> tuple[str, TokenRecord]:
        """Build, sign, and record one token for ``principal``.

        Args:
            principal: Source of the ``sub``, ``sub_type``, ``scopes``,
                ``role`` and ``email`` claims.
            expires_delta: Lifetime added to the current UTC time for ``exp``.
            token_type: When truthy, stored in the payload's ``type`` claim.

        Returns:
            ``(encoded_jwt, record)`` where ``record`` has been passed to
            ``repo.save``.
        """
        now = datetime.now(timezone.utc)
        expires_at = now + expires_delta
        jti = str(uuid4())
        payload = {
            "sub": principal.principal_id,
            "sub_type": principal.principal_type,
            "jti": jti,
            "scopes": list(principal.scopes),
            "role": principal.role,
            "email": principal.email,
            "exp": expires_at,
        }
        if token_type:
            payload["type"] = token_type
        token = jwt.encode(payload, self.settings.secret_key, algorithm=self.settings.algorithm)
        record = TokenRecord(
            jti=jti,
            principal_id=principal.principal_id,
            principal_type=principal.principal_type,
            scopes=list(principal.scopes),
            expires_at=expires_at,
        )
        self.repo.save(record)
        return token, record