agentauthlayer 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agent_auth/__init__.py +48 -0
- agent_auth/__main__.py +4 -0
- agent_auth/agents.py +40 -0
- agent_auth/audit.py +51 -0
- agent_auth/auth.py +36 -0
- agent_auth/cli.py +28 -0
- agent_auth/client.py +107 -0
- agent_auth/context.py +21 -0
- agent_auth/core.py +638 -0
- agent_auth/delegation.py +15 -0
- agent_auth/exceptions.py +72 -0
- agent_auth/models.py +72 -0
- agent_auth/policy.py +296 -0
- agent_auth/policy_service.py +176 -0
- agent_auth/principals.py +44 -0
- agent_auth/registry.py +90 -0
- agent_auth/session.py +135 -0
- agent_auth/storage.py +536 -0
- agent_auth/tokens.py +92 -0
- agent_auth/users.py +173 -0
- agentauthlayer-0.1.0.dist-info/METADATA +131 -0
- agentauthlayer-0.1.0.dist-info/RECORD +24 -0
- agentauthlayer-0.1.0.dist-info/WHEEL +5 -0
- agentauthlayer-0.1.0.dist-info/top_level.txt +1 -0
agent_auth/storage.py
ADDED
|
@@ -0,0 +1,536 @@
|
|
|
1
|
+
"""agent_auth.storage — SQLAlchemy-backed persistence layer.
|
|
2
|
+
|
|
3
|
+
Tables:
|
|
4
|
+
- agents — agent identities (id, name, owner, role, scopes, status)
|
|
5
|
+
- agent_scopes — granted scopes per agent
|
|
6
|
+
- tools — registered tool catalog
|
|
7
|
+
- token_records — issued tokens (access + refresh) for revocation / introspection
|
|
8
|
+
- audit_events — immutable audit trail
|
|
9
|
+
- users — human user accounts
|
|
10
|
+
- password_reset_tokens — one-time password reset tokens
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
from __future__ import annotations
|
|
14
|
+
|
|
15
|
+
import json
|
|
16
|
+
import os
|
|
17
|
+
import uuid
|
|
18
|
+
from dataclasses import dataclass
|
|
19
|
+
from datetime import datetime, timezone
|
|
20
|
+
from typing import Any
|
|
21
|
+
|
|
22
|
+
import bcrypt as _bcrypt
|
|
23
|
+
from sqlalchemy import (
|
|
24
|
+
Boolean,
|
|
25
|
+
DateTime,
|
|
26
|
+
ForeignKey,
|
|
27
|
+
Integer,
|
|
28
|
+
String,
|
|
29
|
+
Text,
|
|
30
|
+
UniqueConstraint,
|
|
31
|
+
create_engine,
|
|
32
|
+
delete,
|
|
33
|
+
select,
|
|
34
|
+
)
|
|
35
|
+
from sqlalchemy.engine import Engine
|
|
36
|
+
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column, relationship
|
|
37
|
+
from sqlalchemy.pool import StaticPool
|
|
38
|
+
|
|
39
|
+
from agent_auth.models import Agent, User
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
# ---------------------------------------------------------------------------
|
|
43
|
+
# ORM base + table definitions
|
|
44
|
+
# ---------------------------------------------------------------------------
|
|
45
|
+
|
|
46
|
+
class Base(DeclarativeBase):
    """Declarative base shared by every ORM table class in this module.

    ``init_db`` creates all tables registered on ``Base.metadata``.
    """
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
class AgentRow(Base):
    """ORM mapping for the ``agents`` table: one row per agent identity.

    Granted scopes live in the child ``agent_scopes`` table and are exposed
    through the ``scopes`` relationship.
    """

    __tablename__ = "agents"

    agent_id: Mapped[str] = mapped_column(String(255), primary_key=True)
    name: Mapped[str | None] = mapped_column(String(255), nullable=True)
    owner: Mapped[str | None] = mapped_column(String(255), nullable=True)
    role: Mapped[str | None] = mapped_column(String(255), nullable=True)
    status: Mapped[str] = mapped_column(String(32), default="active")
    # Free-form agent metadata stored as JSON text (AgentStore writes it with
    # json.dumps and reads it with json.loads). Presumably not called
    # ``metadata`` because Declarative reserves that attribute name.
    metadata_json: Mapped[str | None] = mapped_column(Text, nullable=True)
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=lambda: datetime.now(timezone.utc),
    )
    # Refreshed automatically on ORM UPDATEs via ``onupdate``.
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        default=lambda: datetime.now(timezone.utc),
        onupdate=lambda: datetime.now(timezone.utc),
    )

    # Scope rows are owned by the agent: deleting the agent through the ORM,
    # or removing a row from this collection, deletes the scope row.
    # "selectin" loads the collection eagerly in a second SELECT.
    # NOTE(review): no order_by is set, so scope order is whatever the
    # database returns for that SELECT.
    scopes: Mapped[list["AgentScopeRow"]] = relationship(
        back_populates="agent",
        cascade="all, delete-orphan",
        lazy="selectin",
    )
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
class AgentScopeRow(Base):
    """ORM mapping for ``agent_scopes``: one row per (agent, granted scope).

    The unique constraint means the same scope cannot be stored twice for
    one agent; inserting a duplicate fails at flush time.
    """

    __tablename__ = "agent_scopes"
    __table_args__ = (
        UniqueConstraint("agent_id", "scope", name="uq_agent_scope"),
    )

    id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
    # NOTE(review): ON DELETE CASCADE is a database-level rule; SQLite only
    # enforces foreign keys when PRAGMA foreign_keys=ON, which create_db_engine
    # does not set. ORM deletes still cascade via AgentRow.scopes.
    agent_id: Mapped[str] = mapped_column(ForeignKey("agents.agent_id", ondelete="CASCADE"))
    scope: Mapped[str] = mapped_column(String(255), nullable=False)

    agent: Mapped[AgentRow] = relationship(back_populates="scopes")
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
class ToolRow(Base):
    """ORM mapping for the ``tools`` catalog, keyed by tool name.

    ``required_scope`` names the scope an agent needs to use the tool.
    """

    __tablename__ = "tools"

    name: Mapped[str] = mapped_column(String(255), primary_key=True)
    required_scope: Mapped[str] = mapped_column(String(255), nullable=False)
    description: Mapped[str] = mapped_column(Text, default="")
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=lambda: datetime.now(timezone.utc),
    )
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
class TokenRecordRow(Base):
    """ORM mapping for ``token_records``: one row per issued JWT, keyed by ``jti``.

    Used by TokenStore for revocation checks and introspection.
    """

    __tablename__ = "token_records"

    jti: Mapped[str] = mapped_column(String(255), primary_key=True)
    # Not a foreign key: a record can outlive (or predate) the agent row.
    agent_id: Mapped[str] = mapped_column(String(255), nullable=False)
    # Granted scopes serialized as a JSON list (TokenStore.save uses json.dumps).
    scopes_json: Mapped[str] = mapped_column(Text, default="[]")
    # "access" by default; TokenStore.save accepts other values (e.g. "refresh").
    token_type: Mapped[str] = mapped_column(String(32), default="access")
    # "active" when saved; TokenStore.revoke sets it to "revoked".
    status: Mapped[str] = mapped_column(String(32), default="active")
    issued_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=lambda: datetime.now(timezone.utc),
    )
    expires_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
class AuditEventRow(Base):
    """ORM mapping for ``audit_events``.

    AuditStore only inserts and reads these rows. Nothing here prevents
    updates at the database level.
    """

    __tablename__ = "audit_events"

    # Autoincrementing id also gives insertion order (AuditStore sorts by it).
    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
    event_type: Mapped[str] = mapped_column(String(128), nullable=False)
    agent_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
    tool_name: Mapped[str | None] = mapped_column(String(255), nullable=True)
    # Event details serialized as a JSON object.
    details_json: Mapped[str] = mapped_column(Text, default="{}")
    timestamp: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=lambda: datetime.now(timezone.utc),
    )
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
class UserRow(Base):
    """ORM mapping for the ``users`` table (human user accounts).

    ``email`` is unique and indexed; UserStore lower-cases it before storing
    or querying. ``hashed_password`` holds a bcrypt hash, never plaintext.
    """

    __tablename__ = "users"

    id: Mapped[str] = mapped_column(String(255), primary_key=True,
                                    default=lambda: str(uuid.uuid4()))
    email: Mapped[str] = mapped_column(String(255), unique=True, nullable=False, index=True)
    name: Mapped[str | None] = mapped_column(String(255), nullable=True)
    hashed_password: Mapped[str] = mapped_column(Text, nullable=False)
    role: Mapped[str] = mapped_column(String(64), nullable=False, default="user")
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=lambda: datetime.now(timezone.utc),
    )
    # Refreshed on every ORM UPDATE (matching AgentRow.updated_at) so changes
    # other than UserStore.update_password also bump the timestamp. An
    # explicitly assigned value still takes precedence.
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        default=lambda: datetime.now(timezone.utc),
        onupdate=lambda: datetime.now(timezone.utc),
    )
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
class PasswordResetRow(Base):
    """ORM mapping for ``password_reset_tokens``: one-time reset tokens.

    The token string itself is the primary key. A token is usable while
    ``used`` is false and ``expires_at`` is in the future (see
    UserStore.validate_reset_token).
    """

    __tablename__ = "password_reset_tokens"

    token: Mapped[str] = mapped_column(String(255), primary_key=True)
    # Not declared as a foreign key to users.id.
    user_id: Mapped[str] = mapped_column(String(255), nullable=False)
    # NOTE(review): may come back timezone-naive on some backends (e.g.
    # SQLite); validate_reset_token treats naive values as UTC.
    expires_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
    used: Mapped[bool] = mapped_column(Boolean, default=False)
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
# ---------------------------------------------------------------------------
|
|
153
|
+
# Engine helpers
|
|
154
|
+
# ---------------------------------------------------------------------------
|
|
155
|
+
|
|
156
|
+
def default_database_url() -> str:
    """Return ``$DATABASE_URL``, or a local SQLite file URL when it is unset or empty."""
    fallback = "sqlite:///agent_auth.db"
    return os.getenv("DATABASE_URL") or fallback
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
def create_db_engine(database_url: str | None = None) -> Engine:
    """Create a SQLAlchemy engine for *database_url*.

    Falls back to :func:`default_database_url` when *database_url* is ``None``
    or empty. For SQLite URLs the engine is set up for multi-threaded use:

    * ``check_same_thread=False`` is passed to the sqlite3 driver.
    * In-memory databases (``sqlite://`` with no database, or any database
      name containing ``:memory:``) use :class:`StaticPool`. All sessions then
      share one connection, and therefore one database. Without it, each
      pooled connection would open its own empty in-memory database.
    * Database names starting with ``file:`` enable the driver's URI mode.

    Raises:
        sqlalchemy.exc.ArgumentError: if the URL cannot be parsed.
    """
    # Local import: make_url ships with SQLAlchemy, already a dependency here.
    from sqlalchemy.engine import make_url

    url = database_url or default_database_url()
    parsed = make_url(url)

    connect_args: dict[str, Any] = {}
    engine_kwargs: dict[str, Any] = {"future": True}

    if parsed.get_backend_name() == "sqlite":
        connect_args["check_same_thread"] = False

        # ``sqlite://`` parses with database=None and is also in-memory.
        database = parsed.database or ""
        if not database or ":memory:" in database:
            engine_kwargs["poolclass"] = StaticPool

        # Only the database component selects URI mode; a "file:" substring
        # elsewhere in the URL must not.
        if database.startswith("file:"):
            connect_args["uri"] = True

    return create_engine(url, connect_args=connect_args, **engine_kwargs)
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
def init_db(engine: Engine) -> None:
    """Create every table registered on ``Base.metadata`` using *engine*.

    Relies on ``create_all``'s default ``checkfirst`` behaviour, so tables
    that already exist are left untouched.
    """
    metadata = Base.metadata
    metadata.create_all(bind=engine)
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
# ---------------------------------------------------------------------------
|
|
186
|
+
# AgentStore
|
|
187
|
+
# ---------------------------------------------------------------------------
|
|
188
|
+
|
|
189
|
+
@dataclass(slots=True)
class AgentStore:
    """Persistent storage for agents + scopes.

    Each method opens its own short-lived :class:`Session` on ``engine``. ORM
    rows are converted to :class:`Agent` models before the session closes,
    so returned objects do not depend on an open session.
    """

    engine: Engine

    def init(self) -> None:
        """Create the tables declared on ``Base`` (delegates to :func:`init_db`)."""
        init_db(self.engine)

    def _row_to_agent(self, row: AgentRow) -> Agent:
        """Convert an :class:`AgentRow` and its loaded scope rows into an :class:`Agent`."""
        meta = json.loads(row.metadata_json) if row.metadata_json else {}
        return Agent(
            agent_id=row.agent_id,
            scopes=[s.scope for s in row.scopes],
            metadata=meta,
            name=row.name,
            owner=row.owner,
            role=row.role,
            status=row.status,
            created_at=row.created_at,
        )

    def upsert_agent(
        self,
        agent_id: str,
        scopes: list[str],
        name: str | None = None,
        owner: str | None = None,
        role: str | None = None,
        metadata: dict | None = None,
    ) -> Agent:
        """Create or update an agent and replace its full scope set.

        Args:
            agent_id: Primary key of the agent; a new row is created if absent.
            scopes: The complete new scope list. Scopes not listed here are
                removed. Duplicate entries are ignored (first occurrence wins).
            name, owner, role: Overwritten only when not ``None``.
            metadata: Stored as JSON (non-JSON values via ``str``) only when
                not ``None``.

        Returns:
            The agent as reloaded from the database after commit.
        """
        # Each entry becomes its own row, and (agent_id, scope) is unique, so a
        # repeated scope would make the insert fail with IntegrityError.
        # dict.fromkeys removes duplicates while preserving first-seen order.
        unique_scopes = list(dict.fromkeys(scopes))

        with Session(self.engine) as session:
            row = session.get(AgentRow, agent_id)
            if row is None:
                row = AgentRow(agent_id=agent_id)
                session.add(row)

            if name is not None:
                row.name = name
            if owner is not None:
                row.owner = owner
            if role is not None:
                row.role = role
            if metadata is not None:
                row.metadata_json = json.dumps(metadata, default=str)
            row.updated_at = datetime.now(timezone.utc)
            session.flush()

            # Delete old scope rows in their own flush before inserting the new
            # ones. Re-inserting an unchanged scope would otherwise collide with
            # its still-present row on the unique constraint.
            session.execute(delete(AgentScopeRow).where(AgentScopeRow.agent_id == agent_id))
            session.flush()
            for scope in unique_scopes:
                session.add(AgentScopeRow(agent_id=agent_id, scope=scope))

            session.commit()
            session.refresh(row)
            return self._row_to_agent(row)

    def get_agent(self, agent_id: str) -> Agent | None:
        """Return the agent with *agent_id*, or ``None`` if it does not exist."""
        with Session(self.engine) as session:
            row = session.get(AgentRow, agent_id)
            if row is None:
                return None
            return self._row_to_agent(row)

    def list_agents(self) -> list[Agent]:
        """Return every stored agent."""
        with Session(self.engine) as session:
            rows = session.scalars(select(AgentRow)).all()
            return [self._row_to_agent(r) for r in rows]

    def grant_scope(self, agent_id: str, scope: str) -> Agent:
        """Add *scope* to the agent, creating a bare agent row if none exists.

        Granting a scope the agent already holds is a no-op.
        """
        with Session(self.engine) as session:
            row = session.get(AgentRow, agent_id)
            if row is None:
                row = AgentRow(agent_id=agent_id)
                session.add(row)
                session.flush()

            existing = session.scalar(
                select(AgentScopeRow).where(
                    AgentScopeRow.agent_id == agent_id,
                    AgentScopeRow.scope == scope,
                )
            )
            if existing is None:
                session.add(AgentScopeRow(agent_id=agent_id, scope=scope))
            session.commit()
            session.refresh(row)
            return self._row_to_agent(row)

    def revoke_scope(self, agent_id: str, scope: str) -> Agent:
        """Remove *scope* from the agent.

        Returns the updated agent. If the agent does not exist, returns a
        placeholder ``Agent`` with no scopes and empty metadata.
        """
        with Session(self.engine) as session:
            session.execute(
                delete(AgentScopeRow).where(
                    AgentScopeRow.agent_id == agent_id,
                    AgentScopeRow.scope == scope,
                )
            )
            session.commit()

            row = session.get(AgentRow, agent_id)
            if row is None:
                return Agent(agent_id=agent_id, scopes=[], metadata={})
            return self._row_to_agent(row)

    def delete_agent(self, agent_id: str) -> bool:
        """Delete the agent and its scope rows. Returns ``False`` if it did not exist."""
        with Session(self.engine) as session:
            row = session.get(AgentRow, agent_id)
            if row is None:
                return False
            # ORM delete cascades to scope rows via AgentRow.scopes.
            session.delete(row)
            session.commit()
            return True
|
|
301
|
+
|
|
302
|
+
|
|
303
|
+
# ---------------------------------------------------------------------------
|
|
304
|
+
# ToolStore
|
|
305
|
+
# ---------------------------------------------------------------------------
|
|
306
|
+
|
|
307
|
+
@dataclass(slots=True)
class ToolStore:
    """Persistent storage for the tool catalog.

    Rows returned by ``get``/``list_all`` are detached once their session
    closes, so only already-loaded column attributes should be read.
    """

    engine: Engine

    def upsert(self, name: str, required_scope: str, description: str = "") -> None:
        """Insert tool *name*, or overwrite its ``required_scope`` and ``description``."""
        with Session(self.engine) as session:
            existing = session.get(ToolRow, name)
            if existing is not None:
                existing.required_scope = required_scope
                existing.description = description
            else:
                session.add(
                    ToolRow(name=name, required_scope=required_scope, description=description)
                )
            session.commit()

    def get(self, name: str) -> ToolRow | None:
        """Return the catalog row for *name*, or ``None`` if unknown."""
        with Session(self.engine) as session:
            tool = session.get(ToolRow, name)
        return tool

    def list_all(self) -> list[ToolRow]:
        """Return every catalog row."""
        query = select(ToolRow)
        with Session(self.engine) as session:
            tools = session.scalars(query).all()
        return list(tools)
|
|
331
|
+
|
|
332
|
+
|
|
333
|
+
# ---------------------------------------------------------------------------
|
|
334
|
+
# TokenStore
|
|
335
|
+
# ---------------------------------------------------------------------------
|
|
336
|
+
|
|
337
|
+
@dataclass(slots=True)
class TokenStore:
    """Persistent storage for token records (revocation + introspection).

    Records are keyed by the JWT ``jti``; status is ``"active"`` until
    :meth:`revoke` flips it to ``"revoked"``.
    """

    engine: Engine

    def save(
        self,
        jti: str,
        agent_id: str,
        scopes: list[str],
        expires_at: datetime,
        token_type: str = "access",
    ) -> None:
        """Persist a newly issued token as ``active`` with its scopes serialized as JSON."""
        record = TokenRecordRow(
            jti=jti,
            agent_id=agent_id,
            scopes_json=json.dumps(scopes),
            token_type=token_type,
            status="active",
            expires_at=expires_at,
        )
        with Session(self.engine) as session:
            session.add(record)
            session.commit()

    def revoke(self, jti: str) -> bool:
        """Mark token *jti* as revoked. Returns ``False`` if no such record exists."""
        with Session(self.engine) as session:
            record = session.get(TokenRecordRow, jti)
            if record is None:
                return False
            record.status = "revoked"
            session.commit()
        return True

    def get(self, jti: str) -> TokenRecordRow | None:
        """Return the (detached) record for *jti*, or ``None``."""
        with Session(self.engine) as session:
            record = session.get(TokenRecordRow, jti)
        return record

    def is_revoked(self, jti: str) -> bool:
        """True only if a record for *jti* exists and is revoked; unknown ids are not revoked."""
        with Session(self.engine) as session:
            record = session.get(TokenRecordRow, jti)
            return record is not None and record.status == "revoked"
|
|
382
|
+
|
|
383
|
+
|
|
384
|
+
# ---------------------------------------------------------------------------
|
|
385
|
+
# AuditStore
|
|
386
|
+
# ---------------------------------------------------------------------------
|
|
387
|
+
|
|
388
|
+
@dataclass(slots=True)
class AuditStore:
    """Persistent audit event log (insert and read only)."""

    engine: Engine

    def log(self, event_type: str, agent_id: str | None = None,
            tool_name: str | None = None, details: dict[str, Any] | None = None) -> None:
        """Append one event. *details* is stored as JSON; non-JSON values go through ``str``."""
        serialized = json.dumps(details or {}, default=str)
        with Session(self.engine) as session:
            session.add(
                AuditEventRow(
                    event_type=event_type,
                    agent_id=agent_id,
                    tool_name=tool_name,
                    details_json=serialized,
                )
            )
            session.commit()

    def list_events(self, limit: int = 100) -> list[dict[str, Any]]:
        """Return the *limit* most recent events as dicts, ordered oldest first."""

        def to_dict(event: AuditEventRow) -> dict[str, Any]:
            return {
                "id": event.id,
                "event_type": event.event_type,
                "agent_id": event.agent_id,
                "tool_name": event.tool_name,
                "details": json.loads(event.details_json) if event.details_json else {},
                "timestamp": event.timestamp.isoformat() if event.timestamp else None,
            }

        # Fetch newest-first so LIMIT keeps the latest events, then flip the order.
        newest_first = (
            select(AuditEventRow)
            .order_by(AuditEventRow.id.desc())
            .limit(limit)
        )
        with Session(self.engine) as session:
            rows = session.scalars(newest_first).all()
            return [to_dict(event) for event in reversed(rows)]
|
|
422
|
+
|
|
423
|
+
|
|
424
|
+
# ---------------------------------------------------------------------------
|
|
425
|
+
# UserStore
|
|
426
|
+
# ---------------------------------------------------------------------------
|
|
427
|
+
|
|
428
|
+
def _hash_password(password: str) -> str:
    """Return the bcrypt hash of *password* (UTF-8, fresh default-cost salt) as text."""
    salt = _bcrypt.gensalt()
    digest = _bcrypt.hashpw(password.encode("utf-8"), salt)
    return digest.decode("utf-8")
|
|
430
|
+
|
|
431
|
+
|
|
432
|
+
def _verify_password(password: str, hashed: str) -> bool:
    """Check plaintext *password* against a stored bcrypt hash string."""
    candidate = password.encode("utf-8")
    stored = hashed.encode("utf-8")
    return _bcrypt.checkpw(candidate, stored)
|
|
434
|
+
|
|
435
|
+
|
|
436
|
+
@dataclass(slots=True)
class UserStore:
    """Persistent storage for human user accounts and password reset tokens.

    Password hashing is handled internally — callers pass plaintext passwords.
    Emails are lower-cased before they are stored or looked up.
    """

    engine: Engine

    def _row_to_user(self, row: UserRow) -> User:
        """Convert a :class:`UserRow` to the public :class:`User` model (hash omitted)."""
        return User(
            id=row.id,
            email=row.email,
            name=row.name,
            role=row.role,
            created_at=row.created_at,
        )

    def create_user(
        self,
        email: str,
        password: str,
        name: str | None = None,
        role: str = "user",
    ) -> User:
        """Create a user with a bcrypt-hashed password and return it.

        NOTE(review): a duplicate email violates the unique constraint on
        ``users.email`` and surfaces as a SQLAlchemy integrity error on commit.
        """
        hashed = _hash_password(password)
        with Session(self.engine) as session:
            row = UserRow(
                id=str(uuid.uuid4()),
                email=email.lower(),
                name=name,
                hashed_password=hashed,
                role=role,
            )
            session.add(row)
            session.commit()
            session.refresh(row)
            return self._row_to_user(row)

    def authenticate(self, email: str, password: str) -> User | None:
        """Verify email + password. Returns User on success, None on failure.

        Unknown emails cost about the same bcrypt work as a wrong password, so
        response time does not reveal which emails are registered.
        """
        with Session(self.engine) as session:
            row = session.scalar(
                select(UserRow).where(UserRow.email == email.lower())
            )
            if row is None:
                # Hashing with a fresh default-cost salt runs the same bcrypt
                # key derivation that checkpw performs for a real user (all
                # stored hashes are created with the default cost here).
                _hash_password(password)
                return None
            if not _verify_password(password, row.hashed_password):
                return None
            return self._row_to_user(row)

    def get_by_id(self, user_id: str) -> User | None:
        """Return the user with primary key *user_id*, or ``None``."""
        with Session(self.engine) as session:
            row = session.get(UserRow, user_id)
            if row is None:
                return None
            return self._row_to_user(row)

    def get_by_email(self, email: str) -> User | None:
        """Return the user with *email* (case-insensitive), or ``None``."""
        with Session(self.engine) as session:
            row = session.scalar(
                select(UserRow).where(UserRow.email == email.lower())
            )
            if row is None:
                return None
            return self._row_to_user(row)

    def update_password(self, user_id: str, new_password: str) -> None:
        """Replace the user's password hash. Silently does nothing for unknown ids."""
        hashed = _hash_password(new_password)
        with Session(self.engine) as session:
            row = session.get(UserRow, user_id)
            if row:
                row.hashed_password = hashed
                row.updated_at = datetime.now(timezone.utc)
                session.commit()

    def save_reset_token(self, token: str, user_id: str, expires_at: datetime) -> None:
        """Store a new, unused password reset token for *user_id*."""
        with Session(self.engine) as session:
            row = PasswordResetRow(token=token, user_id=user_id, expires_at=expires_at)
            session.add(row)
            session.commit()

    def validate_reset_token(self, token: str) -> str | None:
        """Return user_id if the token is valid and unused, else None.

        Does not consume the token; call :meth:`mark_reset_token_used` after
        a successful reset.
        """
        with Session(self.engine) as session:
            row = session.get(PasswordResetRow, token)
            if row is None or row.used:
                return None
            expires = row.expires_at
            # Some backends (e.g. SQLite) return naive datetimes; treat as UTC
            # so the comparison with an aware "now" does not raise.
            if expires.tzinfo is None:
                expires = expires.replace(tzinfo=timezone.utc)
            if datetime.now(timezone.utc) > expires:
                return None
            return row.user_id

    def mark_reset_token_used(self, token: str) -> None:
        """Mark *token* as consumed. Silently ignores unknown tokens."""
        with Session(self.engine) as session:
            row = session.get(PasswordResetRow, token)
            if row:
                row.used = True
                session.commit()
|
agent_auth/tokens.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
1
|
+
"""agent_auth JWT token operations — issue & verify for agents and users."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from datetime import datetime, timedelta, timezone
|
|
6
|
+
from uuid import uuid4
|
|
7
|
+
|
|
8
|
+
from jose import ExpiredSignatureError, JWTError, jwt
|
|
9
|
+
|
|
10
|
+
from agent_auth.exceptions import TokenExpiredError, TokenInvalidError
|
|
11
|
+
from agent_auth.models import TokenClaims
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def issue_token(
    agent_id: str,
    scopes: list[str],
    secret_key: str,
    algorithm: str = "HS256",
    expire_minutes: int = 30,
) -> tuple[str, str]:
    """Create a signed access JWT for *agent_id* carrying *scopes*.

    The token has ``type="access"``, a random UUID4 ``jti``, and expires
    *expire_minutes* after issuance.

    Returns:
        (encoded_jwt, jti)
    """
    issued_at = datetime.now(timezone.utc)
    token_id = str(uuid4())
    claims = dict(
        sub=agent_id,
        jti=token_id,
        scopes=scopes,
        type="access",
        iat=issued_at,
        exp=issued_at + timedelta(minutes=expire_minutes),
    )
    return jwt.encode(claims, secret_key, algorithm=algorithm), token_id
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def issue_refresh_token(
    agent_id: str,
    scopes: list[str],
    secret_key: str,
    algorithm: str = "HS256",
    expire_days: int = 7,
) -> tuple[str, str]:
    """Create a signed refresh JWT for *agent_id* carrying *scopes*.

    The token has ``type="refresh"``, a random UUID4 ``jti``, and expires
    *expire_days* after issuance.

    Returns:
        (encoded_jwt, jti)
    """
    issued_at = datetime.now(timezone.utc)
    token_id = str(uuid4())
    claims = dict(
        sub=agent_id,
        jti=token_id,
        scopes=scopes,
        type="refresh",
        iat=issued_at,
        exp=issued_at + timedelta(days=expire_days),
    )
    return jwt.encode(claims, secret_key, algorithm=algorithm), token_id
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def verify_token(
    token: str,
    secret_key: str,
    algorithm: str = "HS256",
) -> TokenClaims:
    """Decode and validate a JWT (access or refresh).

    Only *algorithm* is accepted during decoding. ``scopes`` defaults to ``[]``
    and ``token_type`` to ``"access"`` when those claims are absent.

    Raises:
        TokenExpiredError: if the token has passed its ``exp``.
        TokenInvalidError: if the signature or structure is invalid, or if a
            required claim (``sub``, ``jti``, ``iat``, ``exp``) is missing or
            is not a usable timestamp.
    """
    try:
        payload = jwt.decode(token, secret_key, algorithms=[algorithm])
    except ExpiredSignatureError as exc:
        # Must come before JWTError: ExpiredSignatureError is a subclass of it.
        raise TokenExpiredError() from exc
    except JWTError as exc:
        raise TokenInvalidError(str(exc)) from exc

    # A correctly signed token may still omit claims (jwt.decode only checks
    # time claims that are present). Report such tokens as invalid rather
    # than leaking KeyError/TypeError to callers.
    try:
        agent_id = payload["sub"]
        jti = payload["jti"]
        issued_at = datetime.fromtimestamp(payload["iat"], tz=timezone.utc)
        expires_at = datetime.fromtimestamp(payload["exp"], tz=timezone.utc)
    except (KeyError, TypeError, ValueError, OverflowError, OSError) as exc:
        raise TokenInvalidError(f"missing or malformed claim: {exc}") from exc

    return TokenClaims(
        agent_id=agent_id,
        jti=jti,
        scopes=payload.get("scopes", []),
        issued_at=issued_at,
        expires_at=expires_at,
        raw=payload,
        token_type=payload.get("type", "access"),
    )
|