agentauthlayer 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
agent_auth/storage.py ADDED
@@ -0,0 +1,536 @@
1
+ """agent_auth.storage — SQLAlchemy-backed persistence layer.
2
+
3
+ Tables:
4
+ - agents — agent identities (id, name, owner, role, status, metadata)
5
+ - agent_scopes — granted scopes per agent
6
+ - tools — registered tool catalog
7
+ - token_records — issued tokens (access + refresh) for revocation / introspection
8
+ - audit_events — immutable audit trail
9
+ - users — human user accounts
10
+ - password_reset_tokens — one-time password reset tokens
11
+ """
12
+
13
+ from __future__ import annotations
14
+
15
+ import json
16
+ import os
17
+ import uuid
18
+ from dataclasses import dataclass
19
+ from datetime import datetime, timezone
20
+ from typing import Any
21
+
22
+ import bcrypt as _bcrypt
23
+ from sqlalchemy import (
24
+ Boolean,
25
+ DateTime,
26
+ ForeignKey,
27
+ Integer,
28
+ String,
29
+ Text,
30
+ UniqueConstraint,
31
+ create_engine,
32
+ delete,
33
+ select,
34
+ )
35
+ from sqlalchemy.engine import Engine
36
+ from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column, relationship
37
+ from sqlalchemy.pool import StaticPool
38
+
39
+ from agent_auth.models import Agent, User
40
+
41
+
42
+ # ---------------------------------------------------------------------------
43
+ # ORM base + table definitions
44
+ # ---------------------------------------------------------------------------
45
+
46
class Base(DeclarativeBase):
    """Declarative base shared by every ORM row class in this module.

    ``init_db`` creates the tables registered on ``Base.metadata``.
    """

    pass
48
+
49
+
50
class AgentRow(Base):
    """ORM row for the ``agents`` table: one record per agent identity."""

    __tablename__ = "agents"

    # Primary key: the agent identifier supplied by the caller.
    agent_id: Mapped[str] = mapped_column(String(255), primary_key=True)
    # Optional descriptive fields; all three may be NULL.
    name: Mapped[str | None] = mapped_column(String(255), nullable=True)
    owner: Mapped[str | None] = mapped_column(String(255), nullable=True)
    role: Mapped[str | None] = mapped_column(String(255), nullable=True)
    # Free-form status string; new rows default to "active".
    status: Mapped[str] = mapped_column(String(32), default="active")
    # JSON text; AgentStore writes it with json.dumps and reads it with json.loads.
    metadata_json: Mapped[str | None] = mapped_column(Text, nullable=True)
    # Timezone-aware UTC timestamp filled in on insert.
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=lambda: datetime.now(timezone.utc),
    )
    # Timezone-aware UTC timestamp filled in on insert and again on every UPDATE.
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        default=lambda: datetime.now(timezone.utc),
        onupdate=lambda: datetime.now(timezone.utc),
    )

    # Granted scopes. Loaded eagerly ("selectin") so they are still readable
    # after the session closes; removed together with the agent (delete-orphan).
    scopes: Mapped[list["AgentScopeRow"]] = relationship(
        back_populates="agent",
        cascade="all, delete-orphan",
        lazy="selectin",
    )
73
+
74
+
75
class AgentScopeRow(Base):
    """ORM row for the ``agent_scopes`` table: one granted scope for one agent."""

    __tablename__ = "agent_scopes"
    # An agent can hold a given scope at most once.
    __table_args__ = (
        UniqueConstraint("agent_id", "scope", name="uq_agent_scope"),
    )

    # Surrogate auto-incrementing primary key.
    id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
    # Owning agent; the foreign key is declared with ON DELETE CASCADE.
    agent_id: Mapped[str] = mapped_column(ForeignKey("agents.agent_id", ondelete="CASCADE"))
    # The scope string itself.
    scope: Mapped[str] = mapped_column(String(255), nullable=False)

    # Back-reference to the owning AgentRow (other side of AgentRow.scopes).
    agent: Mapped[AgentRow] = relationship(back_populates="scopes")
86
+
87
+
88
class ToolRow(Base):
    """ORM row for the ``tools`` table: one registered tool."""

    __tablename__ = "tools"

    # Primary key: the tool name.
    name: Mapped[str] = mapped_column(String(255), primary_key=True)
    # Scope string stored alongside the tool; required (NOT NULL).
    required_scope: Mapped[str] = mapped_column(String(255), nullable=False)
    # Free-text description; defaults to the empty string.
    description: Mapped[str] = mapped_column(Text, default="")
    # Timezone-aware UTC timestamp filled in on insert.
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=lambda: datetime.now(timezone.utc),
    )
97
+
98
+
99
class TokenRecordRow(Base):
    """ORM row for the ``token_records`` table: one issued token, keyed by ``jti``."""

    __tablename__ = "token_records"

    # Primary key: the token's unique identifier.
    jti: Mapped[str] = mapped_column(String(255), primary_key=True)
    # Agent the token was issued to (plain string column, no foreign key).
    agent_id: Mapped[str] = mapped_column(String(255), nullable=False)
    # JSON text holding the list of scopes; defaults to an empty JSON list.
    scopes_json: Mapped[str] = mapped_column(Text, default="[]")
    # Token kind; defaults to "access".
    token_type: Mapped[str] = mapped_column(String(32), default="access")
    # Defaults to "active"; TokenStore.revoke sets it to "revoked".
    status: Mapped[str] = mapped_column(String(32), default="active")
    # Timezone-aware UTC timestamp filled in on insert.
    issued_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=lambda: datetime.now(timezone.utc),
    )
    # Expiry timestamp supplied by the caller; required (NOT NULL).
    expires_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
111
+
112
+
113
class AuditEventRow(Base):
    """ORM row for the ``audit_events`` table: one recorded audit event."""

    __tablename__ = "audit_events"

    # Auto-incrementing primary key; AuditStore.list_events orders by it.
    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
    # Event name; required (NOT NULL).
    event_type: Mapped[str] = mapped_column(String(128), nullable=False)
    # Optional agent and tool the event refers to.
    agent_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
    tool_name: Mapped[str | None] = mapped_column(String(255), nullable=True)
    # JSON text with event details; defaults to an empty JSON object.
    details_json: Mapped[str] = mapped_column(Text, default="{}")
    # Timezone-aware UTC timestamp filled in on insert.
    timestamp: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=lambda: datetime.now(timezone.utc),
    )
124
+
125
+
126
class UserRow(Base):
    """ORM row for the ``users`` table: one human user account."""

    __tablename__ = "users"

    # Primary key: a random UUID4 string generated on insert when not supplied.
    id: Mapped[str] = mapped_column(String(255), primary_key=True,
                                    default=lambda: str(uuid.uuid4()))
    # Unique, indexed login e-mail; UserStore lower-cases it before storing.
    email: Mapped[str] = mapped_column(String(255), unique=True, nullable=False, index=True)
    # Optional display name.
    name: Mapped[str | None] = mapped_column(String(255), nullable=True)
    # Password hash as text (UserStore produces it with bcrypt); never plaintext.
    hashed_password: Mapped[str] = mapped_column(Text, nullable=False)
    # Role string; defaults to "user".
    role: Mapped[str] = mapped_column(String(64), nullable=False, default="user")
    # Timezone-aware UTC timestamp filled in on insert.
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=lambda: datetime.now(timezone.utc),
    )
    # Timezone-aware UTC timestamp filled in on insert and refreshed on every
    # UPDATE, matching AgentRow.updated_at. Without ``onupdate`` the column
    # only moved when a caller remembered to assign it explicitly.
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        default=lambda: datetime.now(timezone.utc),
        onupdate=lambda: datetime.now(timezone.utc),
    )
141
+
142
+
143
class PasswordResetRow(Base):
    """ORM row for the ``password_reset_tokens`` table: one reset token."""

    __tablename__ = "password_reset_tokens"

    # Primary key: the reset token string itself.
    token: Mapped[str] = mapped_column(String(255), primary_key=True)
    # User the token was issued for (plain string column, no foreign key).
    user_id: Mapped[str] = mapped_column(String(255), nullable=False)
    # Expiry timestamp supplied by the caller; required (NOT NULL).
    expires_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
    # Set to True by UserStore.mark_reset_token_used; defaults to False.
    used: Mapped[bool] = mapped_column(Boolean, default=False)
150
+
151
+
152
+ # ---------------------------------------------------------------------------
153
+ # Engine helpers
154
+ # ---------------------------------------------------------------------------
155
+
156
def default_database_url() -> str:
    """Return the database URL to use when the caller supplies none.

    The ``DATABASE_URL`` environment variable wins when it is set to a
    non-empty value; otherwise a local SQLite file is used.
    """
    return os.getenv("DATABASE_URL") or "sqlite:///agent_auth.db"
161
+
162
+
163
def create_db_engine(database_url: str | None = None) -> Engine:
    """Build a SQLAlchemy engine for ``database_url``.

    Args:
        database_url: Target URL; when empty or ``None`` the value of
            ``default_database_url()`` is used.

    Returns:
        The configured ``Engine``.
    """
    target = database_url if database_url else default_database_url()

    driver_args: dict = {}
    options: dict = {"future": True}

    # NOTE(review): the three SQLite tweaks below are assumed to apply only
    # to sqlite URLs — confirm the nesting against the upstream file.
    if target.startswith("sqlite"):
        # Allow connections to be used from threads other than their creator.
        driver_args["check_same_thread"] = False
        if ":memory:" in target:
            # Reuse one connection so the in-memory database is shared.
            options["poolclass"] = StaticPool
        if "file:" in target:
            # Let the sqlite3 driver interpret the database name as a URI.
            driver_args["uri"] = True

    return create_engine(target, connect_args=driver_args, **options)
179
+
180
+
181
def init_db(engine: Engine) -> None:
    """Create every table registered on ``Base.metadata`` using ``engine``."""
    Base.metadata.create_all(bind=engine)
183
+
184
+
185
+ # ---------------------------------------------------------------------------
186
+ # AgentStore
187
+ # ---------------------------------------------------------------------------
188
+
189
@dataclass(slots=True)
class AgentStore:
    """Persistent storage for agents and their granted scopes.

    Each public method opens its own short-lived ``Session`` on ``engine``
    and returns ``Agent`` models built from the ORM rows, so callers never
    hold a live ORM object.
    """

    engine: Engine

    def init(self) -> None:
        """Create the database tables via ``init_db``."""
        init_db(self.engine)

    def _row_to_agent(self, row: AgentRow) -> Agent:
        """Build an ``Agent`` model from an ``AgentRow`` and its scope rows."""
        meta = json.loads(row.metadata_json) if row.metadata_json else {}
        return Agent(
            agent_id=row.agent_id,
            scopes=[s.scope for s in row.scopes],
            metadata=meta,
            name=row.name,
            owner=row.owner,
            role=row.role,
            status=row.status,
            created_at=row.created_at,
        )

    def upsert_agent(
        self,
        agent_id: str,
        scopes: list[str],
        name: str | None = None,
        owner: str | None = None,
        role: str | None = None,
        metadata: dict | None = None,
    ) -> Agent:
        """Create or update an agent and replace its scope set.

        Args:
            agent_id: Identifier of the agent to create or update.
            scopes: Complete list of scopes the agent should hold afterwards;
                any previously granted scope not listed is removed.
                Duplicate entries are collapsed (first occurrence wins).
            name: New name, or ``None`` to leave the stored value untouched.
            owner: New owner, or ``None`` to leave the stored value untouched.
            role: New role, or ``None`` to leave the stored value untouched.
            metadata: New metadata dict (stored as JSON), or ``None`` to
                leave the stored value untouched.

        Returns:
            The agent as stored after the commit.
        """
        # A repeated scope would otherwise violate the (agent_id, scope)
        # unique constraint and abort the whole upsert.
        unique_scopes = list(dict.fromkeys(scopes))

        with Session(self.engine) as session:
            row = session.get(AgentRow, agent_id)
            if row is None:
                row = AgentRow(agent_id=agent_id)
                session.add(row)

            if name is not None:
                row.name = name
            if owner is not None:
                row.owner = owner
            if role is not None:
                row.role = role
            if metadata is not None:
                row.metadata_json = json.dumps(metadata, default=str)
            row.updated_at = datetime.now(timezone.utc)
            # Make sure the agent row exists before scope rows reference it.
            session.flush()

            # Replace the scope set wholesale: delete everything, then re-add.
            session.execute(delete(AgentScopeRow).where(AgentScopeRow.agent_id == agent_id))
            session.flush()
            for scope in unique_scopes:
                session.add(AgentScopeRow(agent_id=agent_id, scope=scope))

            session.commit()
            # Reload so the returned model reflects the rows just written.
            session.refresh(row)
            return self._row_to_agent(row)

    def get_agent(self, agent_id: str) -> Agent | None:
        """Return the agent with ``agent_id``, or ``None`` if it is not stored."""
        with Session(self.engine) as session:
            row = session.get(AgentRow, agent_id)
            if row is None:
                return None
            return self._row_to_agent(row)

    def list_agents(self) -> list[Agent]:
        """Return every stored agent."""
        with Session(self.engine) as session:
            rows = session.scalars(select(AgentRow)).all()
            return [self._row_to_agent(r) for r in rows]

    def grant_scope(self, agent_id: str, scope: str) -> Agent:
        """Grant ``scope`` to the agent, creating the agent row if missing.

        Granting a scope the agent already holds is a no-op.

        Returns:
            The agent as stored after the commit.
        """
        with Session(self.engine) as session:
            row = session.get(AgentRow, agent_id)
            if row is None:
                row = AgentRow(agent_id=agent_id)
                session.add(row)
                session.flush()

            existing = session.scalar(
                select(AgentScopeRow).where(
                    AgentScopeRow.agent_id == agent_id,
                    AgentScopeRow.scope == scope,
                )
            )
            if existing is None:
                session.add(AgentScopeRow(agent_id=agent_id, scope=scope))
            session.commit()
            session.refresh(row)
            return self._row_to_agent(row)

    def revoke_scope(self, agent_id: str, scope: str) -> Agent:
        """Remove ``scope`` from the agent if present.

        Returns:
            The agent as stored after the commit. When no agent row exists,
            an ``Agent`` with the given id and no scopes or metadata is
            returned instead.
        """
        with Session(self.engine) as session:
            session.execute(
                delete(AgentScopeRow).where(
                    AgentScopeRow.agent_id == agent_id,
                    AgentScopeRow.scope == scope,
                )
            )
            session.commit()

            row = session.get(AgentRow, agent_id)
            if row is None:
                return Agent(agent_id=agent_id, scopes=[], metadata={})
            return self._row_to_agent(row)

    def delete_agent(self, agent_id: str) -> bool:
        """Delete the agent (and, via cascade, its scopes).

        Returns:
            ``True`` if a row was deleted, ``False`` if no such agent existed.
        """
        with Session(self.engine) as session:
            row = session.get(AgentRow, agent_id)
            if row is None:
                return False
            session.delete(row)
            session.commit()
            return True
301
+
302
+
303
+ # ---------------------------------------------------------------------------
304
+ # ToolStore
305
+ # ---------------------------------------------------------------------------
306
+
307
@dataclass(slots=True)
class ToolStore:
    """Tool catalog persisted in the ``tools`` table."""

    engine: Engine

    def upsert(self, name: str, required_scope: str, description: str = "") -> None:
        """Insert the tool ``name`` or overwrite its scope and description."""
        with Session(self.engine) as db:
            tool = db.get(ToolRow, name)
            if tool is not None:
                tool.required_scope = required_scope
                tool.description = description
            else:
                tool = ToolRow(
                    name=name,
                    required_scope=required_scope,
                    description=description,
                )
                db.add(tool)
            db.commit()

    def get(self, name: str) -> ToolRow | None:
        """Return the ``ToolRow`` named ``name``, or ``None`` if absent."""
        with Session(self.engine) as db:
            tool = db.get(ToolRow, name)
            return tool

    def list_all(self) -> list[ToolRow]:
        """Return every ``ToolRow`` in the catalog."""
        with Session(self.engine) as db:
            query = select(ToolRow)
            found = db.scalars(query).all()
            return list(found)
331
+
332
+
333
+ # ---------------------------------------------------------------------------
334
+ # TokenStore
335
+ # ---------------------------------------------------------------------------
336
+
337
@dataclass(slots=True)
class TokenStore:
    """Token records kept for revocation and introspection."""

    engine: Engine

    def save(
        self,
        jti: str,
        agent_id: str,
        scopes: list[str],
        expires_at: datetime,
        token_type: str = "access",
    ) -> None:
        """Insert a new record with status ``"active"`` for an issued token."""
        record = TokenRecordRow(
            jti=jti,
            agent_id=agent_id,
            scopes_json=json.dumps(scopes),
            token_type=token_type,
            status="active",
            expires_at=expires_at,
        )
        with Session(self.engine) as db:
            db.add(record)
            db.commit()

    def revoke(self, jti: str) -> bool:
        """Mark the token ``jti`` as revoked.

        Returns:
            ``True`` when a record was found and updated, ``False`` when no
            record with that ``jti`` exists.
        """
        with Session(self.engine) as db:
            record = db.get(TokenRecordRow, jti)
            if record is not None:
                record.status = "revoked"
                db.commit()
                return True
            return False

    def get(self, jti: str) -> TokenRecordRow | None:
        """Return the stored record for ``jti``, or ``None`` if absent."""
        with Session(self.engine) as db:
            record = db.get(TokenRecordRow, jti)
            return record

    def is_revoked(self, jti: str) -> bool:
        """Report whether ``jti`` is recorded with status ``"revoked"``.

        A ``jti`` with no stored record is reported as not revoked.
        """
        with Session(self.engine) as db:
            record = db.get(TokenRecordRow, jti)
            return record is not None and record.status == "revoked"
382
+
383
+
384
+ # ---------------------------------------------------------------------------
385
+ # AuditStore
386
+ # ---------------------------------------------------------------------------
387
+
388
@dataclass(slots=True)
class AuditStore:
    """Audit event log persisted in the ``audit_events`` table."""

    engine: Engine

    def log(self, event_type: str, agent_id: str | None = None,
            tool_name: str | None = None, details: dict[str, Any] | None = None) -> None:
        """Append one audit event.

        ``details`` is stored as JSON; values that JSON cannot represent are
        converted with ``str``. A missing or empty ``details`` is stored as
        an empty object.
        """
        encoded_details = json.dumps(details or {}, default=str)
        event = AuditEventRow(
            event_type=event_type,
            agent_id=agent_id,
            tool_name=tool_name,
            details_json=encoded_details,
        )
        with Session(self.engine) as db:
            db.add(event)
            db.commit()

    def list_events(self, limit: int = 100) -> list[dict[str, Any]]:
        """Return the ``limit`` most recent events, oldest of them first."""
        newest_first = select(AuditEventRow).order_by(AuditEventRow.id.desc()).limit(limit)
        with Session(self.engine) as db:
            fetched = db.scalars(newest_first).all()
            events: list[dict[str, Any]] = []
            # The query returns newest first; walk it backwards so the
            # result reads in chronological order.
            for entry in reversed(fetched):
                if entry.details_json:
                    details = json.loads(entry.details_json)
                else:
                    details = {}
                if entry.timestamp:
                    stamp = entry.timestamp.isoformat()
                else:
                    stamp = None
                events.append(
                    {
                        "id": entry.id,
                        "event_type": entry.event_type,
                        "agent_id": entry.agent_id,
                        "tool_name": entry.tool_name,
                        "details": details,
                        "timestamp": stamp,
                    }
                )
            return events
422
+
423
+
424
+ # ---------------------------------------------------------------------------
425
+ # UserStore
426
+ # ---------------------------------------------------------------------------
427
+
428
def _hash_password(password: str) -> str:
    """Return a bcrypt hash of ``password`` (freshly salted) as text."""
    salt = _bcrypt.gensalt()
    digest = _bcrypt.hashpw(password.encode(), salt)
    return digest.decode()
430
+
431
+
432
def _verify_password(password: str, hashed: str) -> bool:
    """Return whether ``password`` matches the bcrypt hash ``hashed``."""
    candidate = password.encode()
    stored = hashed.encode()
    return _bcrypt.checkpw(candidate, stored)
434
+
435
+
436
@dataclass(slots=True)
class UserStore:
    """Persistent storage for human user accounts and password reset tokens.

    Password hashing is handled internally — callers pass plaintext passwords.
    E-mail addresses are lower-cased before they are stored or looked up.
    """

    engine: Engine

    def _row_to_user(self, row: UserRow) -> User:
        """Build a ``User`` model from a ``UserRow`` (the hash is not copied)."""
        return User(
            id=row.id,
            email=row.email,
            name=row.name,
            role=row.role,
            created_at=row.created_at,
        )

    def create_user(
        self,
        email: str,
        password: str,
        name: str | None = None,
        role: str = "user",
    ) -> User:
        """Create a user account with a freshly hashed password.

        Args:
            email: Login address; stored lower-cased. The ``users.email``
                column is unique, so a duplicate address makes the commit fail.
            password: Plaintext password; only its bcrypt hash is stored.
            name: Optional display name.
            role: Role string, ``"user"`` by default.

        Returns:
            The newly stored user.
        """
        hashed = _hash_password(password)
        with Session(self.engine) as session:
            row = UserRow(
                id=str(uuid.uuid4()),
                email=email.lower(),
                name=name,
                hashed_password=hashed,
                role=role,
            )
            session.add(row)
            session.commit()
            session.refresh(row)
            return self._row_to_user(row)

    def authenticate(self, email: str, password: str) -> User | None:
        """Verify email + password. Returns User on success, None on failure."""
        with Session(self.engine) as session:
            row = session.scalar(
                select(UserRow).where(UserRow.email == email.lower())
            )
            if row is None:
                # Spend comparable bcrypt work for unknown addresses so that
                # response time does not reveal which e-mails are registered.
                try:
                    _hash_password(password)
                except ValueError:
                    # Some bcrypt releases reject over-long input; the
                    # answer for an unknown user is still "no match".
                    pass
                return None
            if not _verify_password(password, row.hashed_password):
                return None
            return self._row_to_user(row)

    def get_by_id(self, user_id: str) -> User | None:
        """Return the user with primary key ``user_id``, or ``None``."""
        with Session(self.engine) as session:
            row = session.get(UserRow, user_id)
            if row is None:
                return None
            return self._row_to_user(row)

    def get_by_email(self, email: str) -> User | None:
        """Return the user whose e-mail matches (case-insensitively), or ``None``."""
        with Session(self.engine) as session:
            row = session.scalar(
                select(UserRow).where(UserRow.email == email.lower())
            )
            if row is None:
                return None
            return self._row_to_user(row)

    def update_password(self, user_id: str, new_password: str) -> None:
        """Replace the stored password hash; does nothing for an unknown user."""
        hashed = _hash_password(new_password)
        with Session(self.engine) as session:
            row = session.get(UserRow, user_id)
            if row:
                row.hashed_password = hashed
                row.updated_at = datetime.now(timezone.utc)
                session.commit()

    def save_reset_token(self, token: str, user_id: str, expires_at: datetime) -> None:
        """Store a new, unused password reset token for ``user_id``."""
        with Session(self.engine) as session:
            row = PasswordResetRow(token=token, user_id=user_id, expires_at=expires_at)
            session.add(row)
            session.commit()

    def validate_reset_token(self, token: str) -> str | None:
        """Return user_id if the token is valid and unused, else None."""
        with Session(self.engine) as session:
            row = session.get(PasswordResetRow, token)
            if row is None or row.used:
                return None
            expires = row.expires_at
            # A timestamp read back without tzinfo is treated as UTC so it
            # can be compared with the aware "now".
            if expires.tzinfo is None:
                expires = expires.replace(tzinfo=timezone.utc)
            if datetime.now(timezone.utc) > expires:
                return None
            return row.user_id

    def mark_reset_token_used(self, token: str) -> None:
        """Flag the reset token as used; does nothing for an unknown token."""
        with Session(self.engine) as session:
            row = session.get(PasswordResetRow, token)
            if row:
                row.used = True
                session.commit()
agent_auth/tokens.py ADDED
@@ -0,0 +1,92 @@
1
+ """agent_auth JWT token operations — issue & verify for agents and users."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from datetime import datetime, timedelta, timezone
6
+ from uuid import uuid4
7
+
8
+ from jose import ExpiredSignatureError, JWTError, jwt
9
+
10
+ from agent_auth.exceptions import TokenExpiredError, TokenInvalidError
11
+ from agent_auth.models import TokenClaims
12
+
13
+
14
def issue_token(
    agent_id: str,
    scopes: list[str],
    secret_key: str,
    algorithm: str = "HS256",
    expire_minutes: int = 30,
) -> tuple[str, str]:
    """Issue a signed access JWT for ``agent_id``.

    Args:
        agent_id: Stored in the ``sub`` claim.
        scopes: Stored in the ``scopes`` claim.
        secret_key: Key used to sign the token.
        algorithm: Signing algorithm name.
        expire_minutes: Lifetime of the token, counted from now.

    Returns:
        (encoded_jwt, jti)
    """
    issued_at = datetime.now(timezone.utc)
    expires_at = issued_at + timedelta(minutes=expire_minutes)
    token_id = str(uuid4())

    claims = {
        "sub": agent_id,
        "jti": token_id,
        "scopes": scopes,
        "type": "access",
        "iat": issued_at,
        "exp": expires_at,
    }
    return jwt.encode(claims, secret_key, algorithm=algorithm), token_id
38
+
39
+
40
def issue_refresh_token(
    agent_id: str,
    scopes: list[str],
    secret_key: str,
    algorithm: str = "HS256",
    expire_days: int = 7,
) -> tuple[str, str]:
    """Issue a signed refresh JWT for ``agent_id``.

    Args:
        agent_id: Stored in the ``sub`` claim.
        scopes: Stored in the ``scopes`` claim.
        secret_key: Key used to sign the token.
        algorithm: Signing algorithm name.
        expire_days: Lifetime of the token, counted from now.

    Returns:
        (encoded_jwt, jti)
    """
    issued_at = datetime.now(timezone.utc)
    expires_at = issued_at + timedelta(days=expire_days)
    token_id = str(uuid4())

    claims = {
        "sub": agent_id,
        "jti": token_id,
        "scopes": scopes,
        "type": "refresh",
        "iat": issued_at,
        "exp": expires_at,
    }
    return jwt.encode(claims, secret_key, algorithm=algorithm), token_id
64
+
65
+
66
def verify_token(
    token: str,
    secret_key: str,
    algorithm: str = "HS256",
) -> TokenClaims:
    """Decode and validate a JWT (access or refresh).

    Args:
        token: The encoded JWT.
        secret_key: Key the token is expected to be signed with.
        algorithm: The only signing algorithm that is accepted.

    Returns:
        The token's claims. ``scopes`` defaults to an empty list and the
        token type to ``"access"`` when the respective claim is absent.

    Raises:
        TokenExpiredError: if the token has passed its ``exp``.
        TokenInvalidError: if the signature or structure is invalid,
            including a missing ``sub``, ``jti``, ``iat`` or ``exp`` claim.
    """
    try:
        payload = jwt.decode(token, secret_key, algorithms=[algorithm])
    except ExpiredSignatureError as exc:
        # Must be caught before JWTError, which is its base class.
        raise TokenExpiredError() from exc
    except JWTError as exc:
        raise TokenInvalidError(str(exc)) from exc

    # A correctly signed token can still lack the claims this package relies
    # on; report that as an invalid token rather than leaking a KeyError.
    try:
        agent_id = payload["sub"]
        jti = payload["jti"]
        issued_at = datetime.fromtimestamp(payload["iat"], tz=timezone.utc)
        expires_at = datetime.fromtimestamp(payload["exp"], tz=timezone.utc)
    except KeyError as exc:
        raise TokenInvalidError(f"Missing required claim: {exc.args[0]}") from exc
    except (TypeError, ValueError, OverflowError, OSError) as exc:
        raise TokenInvalidError(f"Malformed timestamp claim: {exc}") from exc

    return TokenClaims(
        agent_id=agent_id,
        jti=jti,
        scopes=payload.get("scopes", []),
        issued_at=issued_at,
        expires_at=expires_at,
        raw=payload,
        token_type=payload.get("type", "access"),
    )