paskia 0.7.1__py3-none-any.whl → 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40) hide show
  1. paskia/_version.py +2 -2
  2. paskia/authsession.py +12 -49
  3. paskia/bootstrap.py +30 -25
  4. paskia/db/__init__.py +163 -401
  5. paskia/db/background.py +128 -0
  6. paskia/db/jsonl.py +132 -0
  7. paskia/db/operations.py +1241 -0
  8. paskia/db/structs.py +148 -0
  9. paskia/fastapi/admin.py +456 -215
  10. paskia/fastapi/api.py +16 -15
  11. paskia/fastapi/authz.py +7 -2
  12. paskia/fastapi/mainapp.py +2 -1
  13. paskia/fastapi/remote.py +20 -20
  14. paskia/fastapi/reset.py +9 -10
  15. paskia/fastapi/user.py +10 -18
  16. paskia/fastapi/ws.py +22 -19
  17. paskia/frontend-build/auth/admin/index.html +3 -3
  18. paskia/frontend-build/auth/assets/AccessDenied-aTdCvz9k.js +8 -0
  19. paskia/frontend-build/auth/assets/admin-BeNu48FR.css +1 -0
  20. paskia/frontend-build/auth/assets/admin-tVs8oyLv.js +1 -0
  21. paskia/frontend-build/auth/assets/{auth-BU_O38k2.css → auth-BKX7shEe.css} +1 -1
  22. paskia/frontend-build/auth/assets/auth-Dk3q4pNS.js +1 -0
  23. paskia/frontend-build/auth/index.html +3 -3
  24. paskia/globals.py +7 -10
  25. paskia/migrate/__init__.py +274 -0
  26. paskia/migrate/sql.py +381 -0
  27. paskia/util/permutil.py +16 -5
  28. paskia/util/sessionutil.py +3 -2
  29. paskia/util/userinfo.py +12 -26
  30. paskia-0.8.0.dist-info/METADATA +94 -0
  31. {paskia-0.7.1.dist-info → paskia-0.8.0.dist-info}/RECORD +33 -29
  32. {paskia-0.7.1.dist-info → paskia-0.8.0.dist-info}/entry_points.txt +1 -0
  33. paskia/db/sql.py +0 -1424
  34. paskia/frontend-build/auth/assets/AccessDenied-C-lL9vbN.js +0 -8
  35. paskia/frontend-build/auth/assets/admin-Cs6Mg773.css +0 -1
  36. paskia/frontend-build/auth/assets/admin-Df5_Damp.js +0 -1
  37. paskia/frontend-build/auth/assets/auth-Df3pjeSS.js +0 -1
  38. paskia/util/tokens.py +0 -44
  39. paskia-0.7.1.dist-info/METADATA +0 -22
  40. {paskia-0.7.1.dist-info → paskia-0.8.0.dist-info}/WHEEL +0 -0
paskia/db/sql.py DELETED
@@ -1,1424 +0,0 @@
1
- """
2
- Async database implementation for WebAuthn passkey authentication.
3
-
4
- This module provides an async database layer using SQLAlchemy async mode
5
- for managing users and credentials in a WebAuthn authentication system.
6
- """
7
-
8
- import os
9
- from contextlib import asynccontextmanager
10
- from datetime import datetime, timezone
11
- from uuid import UUID
12
-
13
- from sqlalchemy import (
14
- DateTime,
15
- ForeignKey,
16
- Integer,
17
- LargeBinary,
18
- String,
19
- delete,
20
- event,
21
- insert,
22
- select,
23
- text,
24
- update,
25
- )
26
- from sqlalchemy.dialects.sqlite import BLOB
27
- from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
28
- from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
29
-
30
- from paskia.config import SESSION_LIFETIME
31
- from paskia.db import (
32
- Credential,
33
- DatabaseInterface,
34
- Org,
35
- Permission,
36
- ResetToken,
37
- Role,
38
- Session,
39
- SessionContext,
40
- User,
41
- )
42
- from paskia.globals import db
43
-
44
- DB_PATH_DEFAULT = "sqlite+aiosqlite:///paskia.sqlite"
45
-
46
-
47
- def _normalize_dt(value: datetime | None) -> datetime | None:
48
- if value is None:
49
- return None
50
- if value.tzinfo is None:
51
- return value.replace(tzinfo=timezone.utc)
52
- return value.astimezone(timezone.utc)
53
-
54
-
55
- async def init(*args, **kwargs):
56
- db_path = os.environ.get("PASKIA_DB", DB_PATH_DEFAULT)
57
- db.instance = DB(db_path)
58
- await db.instance.init_db()
59
-
60
-
61
- class Base(DeclarativeBase):
62
- pass
63
-
64
-
65
- class OrgModel(Base):
66
- __tablename__ = "orgs"
67
-
68
- uuid: Mapped[bytes] = mapped_column(LargeBinary(16), primary_key=True)
69
- display_name: Mapped[str] = mapped_column(String, nullable=False)
70
-
71
- def as_dataclass(self):
72
- # Base Org without permissions/roles (filled by data accessors)
73
- return Org(UUID(bytes=self.uuid), self.display_name)
74
-
75
- @staticmethod
76
- def from_dataclass(org: Org):
77
- return OrgModel(uuid=org.uuid.bytes, display_name=org.display_name)
78
-
79
-
80
- class RoleModel(Base):
81
- __tablename__ = "roles"
82
-
83
- uuid: Mapped[bytes] = mapped_column(LargeBinary(16), primary_key=True)
84
- org_uuid: Mapped[bytes] = mapped_column(
85
- LargeBinary(16), ForeignKey("orgs.uuid", ondelete="CASCADE"), nullable=False
86
- )
87
- display_name: Mapped[str] = mapped_column(String, nullable=False)
88
-
89
- def as_dataclass(self):
90
- # Base Role without permissions (filled by data accessors)
91
- return Role(
92
- uuid=UUID(bytes=self.uuid),
93
- org_uuid=UUID(bytes=self.org_uuid),
94
- display_name=self.display_name,
95
- )
96
-
97
- @staticmethod
98
- def from_dataclass(role: Role):
99
- return RoleModel(
100
- uuid=role.uuid.bytes,
101
- org_uuid=role.org_uuid.bytes,
102
- display_name=role.display_name,
103
- )
104
-
105
-
106
- class UserModel(Base):
107
- __tablename__ = "users"
108
-
109
- uuid: Mapped[bytes] = mapped_column(LargeBinary(16), primary_key=True)
110
- display_name: Mapped[str] = mapped_column(String, nullable=False)
111
- role_uuid: Mapped[bytes] = mapped_column(
112
- LargeBinary(16), ForeignKey("roles.uuid", ondelete="CASCADE"), nullable=False
113
- )
114
- created_at: Mapped[datetime] = mapped_column(
115
- DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)
116
- )
117
- last_seen: Mapped[datetime | None] = mapped_column(
118
- DateTime(timezone=True), nullable=True
119
- )
120
- visits: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
121
-
122
- def as_dataclass(self) -> User:
123
- return User(
124
- uuid=UUID(bytes=self.uuid),
125
- display_name=self.display_name,
126
- role_uuid=UUID(bytes=self.role_uuid),
127
- created_at=_normalize_dt(self.created_at) or self.created_at,
128
- last_seen=_normalize_dt(self.last_seen) or self.last_seen,
129
- visits=self.visits,
130
- )
131
-
132
- @staticmethod
133
- def from_dataclass(user: User):
134
- return UserModel(
135
- uuid=user.uuid.bytes,
136
- display_name=user.display_name,
137
- role_uuid=user.role_uuid.bytes,
138
- created_at=user.created_at or datetime.now(timezone.utc),
139
- last_seen=user.last_seen,
140
- visits=user.visits,
141
- )
142
-
143
-
144
- class CredentialModel(Base):
145
- __tablename__ = "credentials"
146
-
147
- uuid: Mapped[bytes] = mapped_column(LargeBinary(16), primary_key=True)
148
- credential_id: Mapped[bytes] = mapped_column(
149
- LargeBinary(64), unique=True, index=True
150
- )
151
- user_uuid: Mapped[bytes] = mapped_column(
152
- LargeBinary(16), ForeignKey("users.uuid", ondelete="CASCADE")
153
- )
154
- aaguid: Mapped[bytes] = mapped_column(LargeBinary(16), nullable=False)
155
- public_key: Mapped[bytes] = mapped_column(BLOB, nullable=False)
156
- sign_count: Mapped[int] = mapped_column(Integer, nullable=False)
157
- created_at: Mapped[datetime] = mapped_column(
158
- DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)
159
- )
160
- # Columns declared timezone-aware going forward; legacy rows may still be naive in storage
161
- last_used: Mapped[datetime | None] = mapped_column(
162
- DateTime(timezone=True), nullable=True
163
- )
164
- last_verified: Mapped[datetime | None] = mapped_column(
165
- DateTime(timezone=True), nullable=True
166
- )
167
-
168
- def as_dataclass(self): # type: ignore[override]
169
- return Credential(
170
- uuid=UUID(bytes=self.uuid),
171
- credential_id=self.credential_id,
172
- user_uuid=UUID(bytes=self.user_uuid),
173
- aaguid=UUID(bytes=self.aaguid),
174
- public_key=self.public_key,
175
- sign_count=self.sign_count,
176
- created_at=_normalize_dt(self.created_at) or self.created_at,
177
- last_used=_normalize_dt(self.last_used) or self.last_used,
178
- last_verified=_normalize_dt(self.last_verified) or self.last_verified,
179
- )
180
-
181
-
182
- class SessionModel(Base):
183
- __tablename__ = "sessions"
184
-
185
- key: Mapped[bytes] = mapped_column(LargeBinary(16), primary_key=True)
186
- user_uuid: Mapped[bytes] = mapped_column(
187
- LargeBinary(16), ForeignKey("users.uuid", ondelete="CASCADE"), nullable=False
188
- )
189
- credential_uuid: Mapped[bytes] = mapped_column(
190
- LargeBinary(16),
191
- ForeignKey("credentials.uuid", ondelete="CASCADE"),
192
- nullable=False,
193
- )
194
- host: Mapped[str] = mapped_column(String, nullable=False)
195
- ip: Mapped[str] = mapped_column(String(64), nullable=False)
196
- user_agent: Mapped[str] = mapped_column(String(512), nullable=False)
197
- renewed: Mapped[datetime] = mapped_column(
198
- DateTime(timezone=True),
199
- default=lambda: datetime.now(timezone.utc),
200
- nullable=False,
201
- )
202
-
203
- def as_dataclass(self):
204
- return Session(
205
- key=self.key,
206
- user_uuid=UUID(bytes=self.user_uuid),
207
- credential_uuid=UUID(bytes=self.credential_uuid),
208
- host=self.host,
209
- ip=self.ip,
210
- user_agent=self.user_agent,
211
- renewed=_normalize_dt(self.renewed) or self.renewed,
212
- )
213
-
214
- @staticmethod
215
- def from_dataclass(session: Session):
216
- return SessionModel(
217
- key=session.key,
218
- user_uuid=session.user_uuid.bytes,
219
- credential_uuid=session.credential_uuid.bytes,
220
- host=session.host,
221
- ip=session.ip,
222
- user_agent=session.user_agent,
223
- renewed=session.renewed,
224
- )
225
-
226
-
227
- class ResetTokenModel(Base):
228
- __tablename__ = "reset_tokens"
229
-
230
- key: Mapped[bytes] = mapped_column(LargeBinary(16), primary_key=True)
231
- user_uuid: Mapped[bytes] = mapped_column(
232
- LargeBinary(16), ForeignKey("users.uuid", ondelete="CASCADE"), nullable=False
233
- )
234
- token_type: Mapped[str] = mapped_column(String, nullable=False)
235
- expiry: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
236
-
237
- def as_dataclass(self) -> ResetToken:
238
- return ResetToken(
239
- key=self.key,
240
- user_uuid=UUID(bytes=self.user_uuid),
241
- token_type=self.token_type,
242
- expiry=_normalize_dt(self.expiry) or self.expiry,
243
- )
244
-
245
-
246
- class PermissionModel(Base):
247
- __tablename__ = "permissions"
248
-
249
- id: Mapped[str] = mapped_column(String(64), primary_key=True)
250
- display_name: Mapped[str] = mapped_column(String, nullable=False)
251
-
252
- def as_dataclass(self):
253
- return Permission(self.id, self.display_name)
254
-
255
- @staticmethod
256
- def from_dataclass(permission: Permission):
257
- return PermissionModel(id=permission.id, display_name=permission.display_name)
258
-
259
-
260
- ## Join tables (no dataclass equivalents)
261
-
262
-
263
- class OrgPermission(Base):
264
- """Permissions each organization is allowed to grant to its roles."""
265
-
266
- __tablename__ = "org_permissions"
267
-
268
- id: Mapped[int] = mapped_column(Integer, primary_key=True) # Not used
269
- org_uuid: Mapped[bytes] = mapped_column(
270
- LargeBinary(16), ForeignKey("orgs.uuid", ondelete="CASCADE")
271
- )
272
- permission_id: Mapped[str] = mapped_column(
273
- String(64), ForeignKey("permissions.id", ondelete="CASCADE")
274
- )
275
-
276
-
277
- class RolePermission(Base):
278
- """Permissions that each role grants to its members."""
279
-
280
- __tablename__ = "role_permissions"
281
-
282
- id: Mapped[int] = mapped_column(Integer, primary_key=True) # Not used
283
- role_uuid: Mapped[bytes] = mapped_column(
284
- LargeBinary(16), ForeignKey("roles.uuid", ondelete="CASCADE")
285
- )
286
- permission_id: Mapped[str] = mapped_column(
287
- String(64), ForeignKey("permissions.id", ondelete="CASCADE")
288
- )
289
-
290
-
291
- class DB(DatabaseInterface):
292
- """Database class that handles its own connections."""
293
-
294
- def __init__(self, db_path: str = DB_PATH_DEFAULT):
295
- """Initialize with database path."""
296
- self.engine = create_async_engine(db_path, echo=False)
297
- # Ensure SQLite foreign key enforcement is ON for every new connection
298
- if db_path.startswith("sqlite"):
299
-
300
- @event.listens_for(self.engine.sync_engine, "connect")
301
- def _fk_on(dbapi_connection, connection_record): # type: ignore
302
- try:
303
- cursor = dbapi_connection.cursor()
304
- cursor.execute("PRAGMA foreign_keys=ON;")
305
- cursor.close()
306
- except Exception:
307
- pass
308
-
309
- self.async_session_factory = async_sessionmaker(
310
- self.engine, expire_on_commit=False
311
- )
312
-
313
- @asynccontextmanager
314
- async def session(self):
315
- """Async context manager that provides a database session with transaction."""
316
- async with self.async_session_factory() as session:
317
- async with session.begin():
318
- yield session
319
- await session.flush()
320
- await session.commit()
321
-
322
- async def init_db(self) -> None:
323
- """Initialize database tables."""
324
- async with self.engine.begin() as conn:
325
- await conn.run_sync(Base.metadata.create_all)
326
- result = await conn.execute(text("PRAGMA table_info('sessions')"))
327
- columns = {row[1] for row in result}
328
- expected = {
329
- "key",
330
- "user_uuid",
331
- "credential_uuid",
332
- "host",
333
- "ip",
334
- "user_agent",
335
- "renewed",
336
- }
337
- needs_recreate = False
338
- if columns and columns != expected:
339
- await conn.execute(text("DROP TABLE sessions"))
340
- needs_recreate = True
341
- result = await conn.execute(text("PRAGMA table_info('reset_tokens')"))
342
- if not list(result):
343
- needs_recreate = True
344
- if needs_recreate:
345
- await conn.run_sync(Base.metadata.create_all)
346
- # Run one-time migration to add UTC tzinfo to any naive datetimes
347
- await self._migrate_naive_datetimes()
348
-
349
- async def _migrate_naive_datetimes(self) -> None:
350
- """Attach UTC tzinfo to any legacy naive datetime rows.
351
-
352
- SQLite stores datetimes as text; older rows may have been inserted naive.
353
- We treat naive timestamps as already UTC and rewrite them in ISO8601 with Z.
354
- """
355
- # Helper SQL fragment for detecting naive (no timezone offset) for ISO strings
356
- # We only update rows whose textual representation lacks a 'Z' or '+' sign.
357
- async with self.session() as session:
358
- # Users
359
- for model, fields in [
360
- (UserModel, ["created_at", "last_seen"]),
361
- (CredentialModel, ["created_at", "last_used", "last_verified"]),
362
- (SessionModel, ["renewed"]),
363
- (ResetTokenModel, ["expiry"]),
364
- ]:
365
- stmt = select(model)
366
- result = await session.execute(stmt)
367
- rows = result.scalars().all()
368
- dirty = False
369
- for row in rows:
370
- for fname in fields:
371
- value = getattr(row, fname, None)
372
- if isinstance(value, datetime) and value.tzinfo is None:
373
- setattr(row, fname, value.replace(tzinfo=timezone.utc))
374
- dirty = True
375
- if dirty:
376
- # SQLAlchemy autoflush/commit in context manager will persist
377
- pass
378
-
379
- async def get_user_by_uuid(self, user_uuid: UUID) -> User:
380
- async with self.session() as session:
381
- stmt = select(UserModel).where(UserModel.uuid == user_uuid.bytes)
382
- result = await session.execute(stmt)
383
- user_model = result.scalar_one_or_none()
384
-
385
- if user_model:
386
- return user_model.as_dataclass()
387
- raise ValueError("User not found")
388
-
389
- async def create_user(self, user: User) -> None:
390
- async with self.session() as session:
391
- session.add(UserModel.from_dataclass(user))
392
-
393
- async def update_user_display_name(
394
- self, user_uuid: UUID, display_name: str
395
- ) -> None:
396
- async with self.session() as session:
397
- stmt = (
398
- update(UserModel)
399
- .where(UserModel.uuid == user_uuid.bytes)
400
- .values(display_name=display_name)
401
- )
402
- result = await session.execute(stmt)
403
- if result.rowcount == 0: # type: ignore[attr-defined]
404
- raise ValueError("User not found")
405
-
406
- async def create_role(self, role: Role) -> None:
407
- async with self.session() as session:
408
- # Create role record
409
- session.add(RoleModel.from_dataclass(role))
410
- # Persist role permissions
411
- if role.permissions:
412
- for perm_id in role.permissions:
413
- session.add(
414
- RolePermission(
415
- role_uuid=role.uuid.bytes,
416
- permission_id=perm_id,
417
- )
418
- )
419
-
420
- async def create_credential(self, credential: Credential) -> None:
421
- async with self.session() as session:
422
- credential_model = CredentialModel(
423
- uuid=credential.uuid.bytes,
424
- credential_id=credential.credential_id,
425
- user_uuid=credential.user_uuid.bytes,
426
- aaguid=credential.aaguid.bytes,
427
- public_key=credential.public_key,
428
- sign_count=credential.sign_count,
429
- created_at=credential.created_at,
430
- last_used=credential.last_used,
431
- last_verified=credential.last_verified,
432
- )
433
- session.add(credential_model)
434
-
435
- async def get_credential_by_id(self, credential_id: bytes) -> Credential:
436
- async with self.session() as session:
437
- stmt = select(CredentialModel).where(
438
- CredentialModel.credential_id == credential_id
439
- )
440
- result = await session.execute(stmt)
441
- credential_model = result.scalar_one_or_none()
442
-
443
- if not credential_model:
444
- raise ValueError("Credential not found")
445
- return Credential(
446
- uuid=UUID(bytes=credential_model.uuid),
447
- credential_id=credential_model.credential_id,
448
- user_uuid=UUID(bytes=credential_model.user_uuid),
449
- aaguid=UUID(bytes=credential_model.aaguid),
450
- public_key=credential_model.public_key,
451
- sign_count=credential_model.sign_count,
452
- created_at=credential_model.created_at,
453
- last_used=credential_model.last_used,
454
- last_verified=credential_model.last_verified,
455
- )
456
-
457
- async def get_credentials_by_user_uuid(self, user_uuid: UUID) -> list[bytes]:
458
- async with self.session() as session:
459
- stmt = select(CredentialModel.credential_id).where(
460
- CredentialModel.user_uuid == user_uuid.bytes
461
- )
462
- result = await session.execute(stmt)
463
- return [row[0] for row in result.fetchall()]
464
-
465
- async def update_credential(self, credential: Credential) -> None:
466
- async with self.session() as session:
467
- stmt = (
468
- update(CredentialModel)
469
- .where(CredentialModel.credential_id == credential.credential_id)
470
- .values(
471
- sign_count=credential.sign_count,
472
- created_at=credential.created_at,
473
- last_used=credential.last_used,
474
- last_verified=credential.last_verified,
475
- )
476
- )
477
- await session.execute(stmt)
478
-
479
- async def login(self, user_uuid: UUID, credential: Credential) -> None:
480
- async with self.session() as session:
481
- # Update credential
482
- stmt = (
483
- update(CredentialModel)
484
- .where(CredentialModel.credential_id == credential.credential_id)
485
- .values(
486
- sign_count=credential.sign_count,
487
- created_at=credential.created_at,
488
- last_used=credential.last_used,
489
- last_verified=credential.last_verified,
490
- )
491
- )
492
- await session.execute(stmt)
493
-
494
- # Update user's last_seen and increment visits
495
- stmt = (
496
- update(UserModel)
497
- .where(UserModel.uuid == user_uuid.bytes)
498
- .values(last_seen=credential.last_used, visits=UserModel.visits + 1)
499
- )
500
- await session.execute(stmt)
501
-
502
- async def create_user_and_credential(
503
- self, user: User, credential: Credential
504
- ) -> None:
505
- async with self.session() as session:
506
- # Create user
507
- user_model = UserModel.from_dataclass(user)
508
- session.add(user_model)
509
-
510
- # Create credential
511
- credential_model = CredentialModel(
512
- uuid=credential.uuid.bytes,
513
- credential_id=credential.credential_id,
514
- user_uuid=credential.user_uuid.bytes,
515
- aaguid=credential.aaguid.bytes,
516
- public_key=credential.public_key,
517
- sign_count=credential.sign_count,
518
- created_at=credential.created_at,
519
- last_used=credential.last_used,
520
- last_verified=credential.last_verified,
521
- )
522
- session.add(credential_model)
523
-
524
- async def create_credential_session(
525
- self,
526
- user_uuid: UUID,
527
- credential: Credential,
528
- reset_key: bytes | None,
529
- session_key: bytes,
530
- *,
531
- display_name: str | None = None,
532
- host: str | None = None,
533
- ip: str | None = None,
534
- user_agent: str | None = None,
535
- ) -> None:
536
- """Atomic credential + (optional old session delete) + (optional rename) + new session."""
537
- async with self.session() as session:
538
- # Ensure credential has last_used / last_verified for immediate login semantics
539
- if credential.last_used is None:
540
- credential.last_used = credential.created_at
541
- if credential.last_verified is None:
542
- credential.last_verified = credential.last_used
543
- # Insert credential
544
- session.add(
545
- CredentialModel(
546
- uuid=credential.uuid.bytes,
547
- credential_id=credential.credential_id,
548
- user_uuid=credential.user_uuid.bytes,
549
- aaguid=credential.aaguid.bytes,
550
- public_key=credential.public_key,
551
- sign_count=credential.sign_count,
552
- created_at=credential.created_at,
553
- last_used=credential.last_used,
554
- last_verified=credential.last_verified,
555
- )
556
- )
557
- # Delete old reset token if provided
558
- if reset_key:
559
- await session.execute(
560
- delete(ResetTokenModel).where(ResetTokenModel.key == reset_key)
561
- )
562
- # Optional rename
563
- if display_name:
564
- await session.execute(
565
- update(UserModel)
566
- .where(UserModel.uuid == user_uuid.bytes)
567
- .values(display_name=display_name)
568
- )
569
- # New session
570
- session.add(
571
- SessionModel(
572
- key=session_key,
573
- user_uuid=user_uuid.bytes,
574
- credential_uuid=credential.uuid.bytes,
575
- host=host,
576
- ip=ip,
577
- user_agent=user_agent,
578
- )
579
- )
580
- # Login side-effects: update user analytics (last_seen + visits increment)
581
- await session.execute(
582
- update(UserModel)
583
- .where(UserModel.uuid == user_uuid.bytes)
584
- .values(last_seen=credential.last_used, visits=UserModel.visits + 1)
585
- )
586
-
587
- async def delete_credential(self, uuid: UUID, user_uuid: UUID) -> None:
588
- async with self.session() as session:
589
- stmt = (
590
- delete(CredentialModel)
591
- .where(CredentialModel.uuid == uuid.bytes)
592
- .where(CredentialModel.user_uuid == user_uuid.bytes)
593
- )
594
- await session.execute(stmt)
595
-
596
- async def create_session(
597
- self,
598
- user_uuid: UUID,
599
- key: bytes,
600
- credential_uuid: UUID,
601
- host: str,
602
- ip: str,
603
- user_agent: str,
604
- renewed: datetime,
605
- ) -> None:
606
- async with self.session() as session:
607
- session_model = SessionModel(
608
- key=key,
609
- user_uuid=user_uuid.bytes,
610
- credential_uuid=credential_uuid.bytes,
611
- host=host,
612
- ip=ip,
613
- user_agent=user_agent,
614
- renewed=renewed,
615
- )
616
- session.add(session_model)
617
-
618
- async def get_session(self, key: bytes) -> Session | None:
619
- async with self.session() as session:
620
- stmt = select(SessionModel).where(SessionModel.key == key)
621
- result = await session.execute(stmt)
622
- session_model = result.scalar_one_or_none()
623
-
624
- if session_model:
625
- return session_model.as_dataclass()
626
- return None
627
-
628
- async def delete_session(self, key: bytes) -> None:
629
- async with self.session() as session:
630
- await session.execute(delete(SessionModel).where(SessionModel.key == key))
631
-
632
- async def delete_sessions_for_user(self, user_uuid: UUID) -> None:
633
- async with self.session() as session:
634
- await session.execute(
635
- delete(SessionModel).where(SessionModel.user_uuid == user_uuid.bytes)
636
- )
637
-
638
- async def create_reset_token(
639
- self,
640
- user_uuid: UUID,
641
- key: bytes,
642
- expiry: datetime,
643
- token_type: str,
644
- ) -> None:
645
- async with self.session() as session:
646
- model = ResetTokenModel(
647
- key=key,
648
- user_uuid=user_uuid.bytes,
649
- token_type=token_type,
650
- expiry=expiry,
651
- )
652
- session.add(model)
653
-
654
- async def get_reset_token(self, key: bytes) -> ResetToken | None:
655
- async with self.session() as session:
656
- stmt = select(ResetTokenModel).where(ResetTokenModel.key == key)
657
- result = await session.execute(stmt)
658
- model = result.scalar_one_or_none()
659
- return model.as_dataclass() if model else None
660
-
661
- async def delete_reset_token(self, key: bytes) -> None:
662
- async with self.session() as session:
663
- await session.execute(
664
- delete(ResetTokenModel).where(ResetTokenModel.key == key)
665
- )
666
-
667
- async def update_session(
668
- self,
669
- key: bytes,
670
- *,
671
- ip: str,
672
- user_agent: str,
673
- renewed: datetime,
674
- ) -> Session | None:
675
- async with self.session() as session:
676
- model = await session.get(SessionModel, key)
677
- if not model:
678
- return None
679
- model.ip = ip
680
- model.user_agent = user_agent
681
- model.renewed = renewed
682
- await session.flush()
683
- return model.as_dataclass()
684
-
685
- async def set_session_host(self, key: bytes, host: str) -> None:
686
- async with self.session() as session:
687
- model = await session.get(SessionModel, key)
688
- if model and model.host is None:
689
- model.host = host
690
- await session.flush()
691
-
692
- async def list_sessions_for_user(self, user_uuid: UUID) -> list[Session]:
693
- async with self.session() as session:
694
- stmt = (
695
- select(SessionModel)
696
- .where(SessionModel.user_uuid == user_uuid.bytes)
697
- .order_by(SessionModel.renewed.desc())
698
- )
699
- result = await session.execute(stmt)
700
- session_models = [
701
- model
702
- for model in result.scalars().all()
703
- if model.key.startswith(b"sess")
704
- ]
705
- return [model.as_dataclass() for model in session_models]
706
-
707
- # Organization operations
708
- async def create_organization(self, org: Org) -> None:
709
- async with self.session() as session:
710
- org_model = OrgModel(
711
- uuid=org.uuid.bytes,
712
- display_name=org.display_name,
713
- )
714
- session.add(org_model)
715
- # Persist any explicitly provided org grantable permissions
716
- if org.permissions:
717
- for perm_id in set(org.permissions):
718
- session.add(
719
- OrgPermission(org_uuid=org.uuid.bytes, permission_id=perm_id)
720
- )
721
-
722
- # Automatically create an organization admin permission if not present.
723
- auto_perm_id = f"auth:org:{org.uuid}"
724
- # Only create if it does not already exist (in case caller passed it)
725
- existing_perm = await session.execute(
726
- select(PermissionModel).where(PermissionModel.id == auto_perm_id)
727
- )
728
- if not existing_perm.scalar_one_or_none():
729
- session.add(
730
- PermissionModel(
731
- id=auto_perm_id,
732
- display_name=f"{org.display_name} Admin",
733
- )
734
- )
735
- # Ensure org is allowed to grant its own admin permission (insert if missing)
736
- existing_org_perm = await session.execute(
737
- select(OrgPermission).where(
738
- OrgPermission.org_uuid == org.uuid.bytes,
739
- OrgPermission.permission_id == auto_perm_id,
740
- )
741
- )
742
- if not existing_org_perm.scalar_one_or_none():
743
- session.add(
744
- OrgPermission(org_uuid=org.uuid.bytes, permission_id=auto_perm_id)
745
- )
746
- # Reflect the automatically added permission in the dataclass instance
747
- if auto_perm_id not in org.permissions:
748
- org.permissions.append(auto_perm_id)
749
-
750
- async def get_organization(self, org_id: str) -> Org:
751
- async with self.session() as session:
752
- # Convert string ID to UUID bytes for lookup
753
- org_uuid = UUID(org_id)
754
- stmt = select(OrgModel).where(OrgModel.uuid == org_uuid.bytes)
755
- result = await session.execute(stmt)
756
- org_model = result.scalar_one_or_none()
757
-
758
- if not org_model:
759
- raise ValueError("Organization not found")
760
-
761
- # Build Org with permissions and roles
762
- org_dc = org_model.as_dataclass()
763
-
764
- # Load org permission IDs
765
- perm_stmt = select(OrgPermission.permission_id).where(
766
- OrgPermission.org_uuid == org_uuid.bytes
767
- )
768
- perm_result = await session.execute(perm_stmt)
769
- org_dc.permissions = [row[0] for row in perm_result.fetchall()]
770
-
771
- # Load roles for org
772
- roles_stmt = select(RoleModel).where(RoleModel.org_uuid == org_uuid.bytes)
773
- roles_result = await session.execute(roles_stmt)
774
- roles_models = roles_result.scalars().all()
775
- roles: list[Role] = []
776
- if roles_models:
777
- # For each role, load permission IDs
778
- for r_model in roles_models:
779
- r_dc = r_model.as_dataclass()
780
- r_perm_stmt = select(RolePermission.permission_id).where(
781
- RolePermission.role_uuid == r_model.uuid
782
- )
783
- r_perm_result = await session.execute(r_perm_stmt)
784
- r_dc.permissions = [row[0] for row in r_perm_result.fetchall()]
785
- roles.append(r_dc)
786
- org_dc.roles = roles
787
-
788
- return org_dc
789
-
790
- async def list_organizations(self) -> list[Org]:
791
- async with self.session() as session:
792
- # Load all orgs
793
- orgs_result = await session.execute(select(OrgModel))
794
- org_models = orgs_result.scalars().all()
795
- if not org_models:
796
- return []
797
-
798
- # Preload org permissions mapping
799
- org_perms_result = await session.execute(select(OrgPermission))
800
- org_perms = org_perms_result.scalars().all()
801
- perms_by_org: dict[bytes, list[str]] = {}
802
- for op in org_perms:
803
- perms_by_org.setdefault(op.org_uuid, []).append(op.permission_id)
804
-
805
- # Preload roles
806
- roles_result = await session.execute(select(RoleModel))
807
- role_models = roles_result.scalars().all()
808
-
809
- # Preload role permissions mapping
810
- rp_result = await session.execute(select(RolePermission))
811
- rps = rp_result.scalars().all()
812
- perms_by_role: dict[bytes, list[str]] = {}
813
- for rp in rps:
814
- perms_by_role.setdefault(rp.role_uuid, []).append(rp.permission_id)
815
-
816
- # Build org dataclasses with roles and permission IDs
817
- roles_by_org: dict[bytes, list[Role]] = {}
818
- for rm in role_models:
819
- r_dc = rm.as_dataclass()
820
- r_dc.permissions = perms_by_role.get(rm.uuid, [])
821
- roles_by_org.setdefault(rm.org_uuid, []).append(r_dc)
822
-
823
- orgs: list[Org] = []
824
- for om in org_models:
825
- o_dc = om.as_dataclass()
826
- o_dc.permissions = perms_by_org.get(om.uuid, [])
827
- o_dc.roles = roles_by_org.get(om.uuid, [])
828
- orgs.append(o_dc)
829
-
830
- return orgs
831
-
832
- async def update_organization(self, org: Org) -> None:
833
- async with self.session() as session:
834
- stmt = (
835
- update(OrgModel)
836
- .where(OrgModel.uuid == org.uuid.bytes)
837
- .values(display_name=org.display_name)
838
- )
839
- await session.execute(stmt)
840
- # Synchronize org permissions join table to match org.permissions
841
- # Delete existing rows for this org
842
- await session.execute(
843
- delete(OrgPermission).where(OrgPermission.org_uuid == org.uuid.bytes)
844
- )
845
- # Insert new rows
846
- if org.permissions:
847
- for perm_id in org.permissions:
848
- await session.merge(
849
- OrgPermission(org_uuid=org.uuid.bytes, permission_id=perm_id)
850
- )
851
-
852
- async def delete_organization(self, org_uuid: UUID) -> None:
853
- async with self.session() as session:
854
- # Convert string ID to UUID bytes for lookup
855
- stmt = delete(OrgModel).where(OrgModel.uuid == org_uuid.bytes)
856
- await session.execute(stmt)
857
-
858
- async def add_user_to_organization(
859
- self, user_uuid: UUID, org_id: str, role: str
860
- ) -> None:
861
- async with self.session() as session:
862
- org_uuid = UUID(org_id)
863
- # Get user and organization models
864
- user_stmt = select(UserModel).where(UserModel.uuid == user_uuid.bytes)
865
- user_result = await session.execute(user_stmt)
866
- user_model = user_result.scalar_one_or_none()
867
-
868
- # Convert string ID to UUID bytes for lookup
869
- org_stmt = select(OrgModel).where(OrgModel.uuid == org_uuid.bytes)
870
- org_result = await session.execute(org_stmt)
871
- org_model = org_result.scalar_one_or_none()
872
-
873
- if not user_model:
874
- raise ValueError("User not found")
875
- if not org_model:
876
- raise ValueError("Organization not found")
877
-
878
- # Find the role within this organization by display_name
879
- role_stmt = select(RoleModel).where(
880
- RoleModel.org_uuid == org_uuid.bytes,
881
- RoleModel.display_name == role,
882
- )
883
- role_result = await session.execute(role_stmt)
884
- role_model = role_result.scalar_one_or_none()
885
- if not role_model:
886
- raise ValueError("Role not found in organization")
887
-
888
- # Update the user's role assignment
889
- stmt = (
890
- update(UserModel)
891
- .where(UserModel.uuid == user_uuid.bytes)
892
- .values(role_uuid=role_model.uuid)
893
- )
894
- await session.execute(stmt)
895
-
896
- async def transfer_user_to_organization(
897
- self, user_uuid: UUID, new_org_id: str, new_role: str | None = None
898
- ) -> None:
899
- # Users are members of an org that never changes after creation.
900
- # Disallow transfers across organizations to enforce invariant.
901
- raise ValueError("Users cannot be transferred to a different organization")
902
-
903
- async def get_user_organization(self, user_uuid: UUID) -> tuple[Org, str]:
904
- async with self.session() as session:
905
- stmt = select(UserModel).where(UserModel.uuid == user_uuid.bytes)
906
- result = await session.execute(stmt)
907
- user_model = result.scalar_one_or_none()
908
-
909
- if not user_model:
910
- raise ValueError("User not found")
911
-
912
- # Find user's role to get org
913
- role_stmt = select(RoleModel).where(RoleModel.uuid == user_model.role_uuid)
914
- role_result = await session.execute(role_stmt)
915
- role_model = role_result.scalar_one()
916
-
917
- # Fetch the organization details
918
- org_stmt = select(OrgModel).where(OrgModel.uuid == role_model.org_uuid)
919
- org_result = await session.execute(org_stmt)
920
- org_model = org_result.scalar_one()
921
-
922
- # Convert UUID bytes back to string for the interface
923
- return org_model.as_dataclass(), role_model.display_name
924
-
925
- async def get_organization_users(self, org_id: str) -> list[tuple[User, str]]:
926
- async with self.session() as session:
927
- org_uuid = UUID(org_id)
928
- # Join users with roles to filter by org and return role names
929
- stmt = (
930
- select(UserModel, RoleModel.display_name)
931
- .join(RoleModel, UserModel.role_uuid == RoleModel.uuid)
932
- .where(RoleModel.org_uuid == org_uuid.bytes)
933
- )
934
- result = await session.execute(stmt)
935
- rows = result.fetchall()
936
- return [(u.as_dataclass(), role_name) for (u, role_name) in rows]
937
-
938
- async def get_user_role_in_organization(
939
- self, user_uuid: UUID, org_id: str
940
- ) -> str | None:
941
- """Get a user's role in a specific organization."""
942
- async with self.session() as session:
943
- # Convert string ID to UUID bytes for lookup
944
- org_uuid = UUID(org_id)
945
- stmt = (
946
- select(RoleModel.display_name)
947
- .select_from(UserModel)
948
- .join(RoleModel, UserModel.role_uuid == RoleModel.uuid)
949
- .where(
950
- UserModel.uuid == user_uuid.bytes,
951
- RoleModel.org_uuid == org_uuid.bytes,
952
- )
953
- )
954
- result = await session.execute(stmt)
955
- return result.scalar_one_or_none()
956
-
957
- async def update_user_role_in_organization(
958
- self, user_uuid: UUID, new_role: str
959
- ) -> None:
960
- """Update a user's role in their organization."""
961
- async with self.session() as session:
962
- # Find user's current org via their role
963
- user_stmt = select(UserModel).where(UserModel.uuid == user_uuid.bytes)
964
- user_result = await session.execute(user_stmt)
965
- user_model = user_result.scalar_one_or_none()
966
- if not user_model:
967
- raise ValueError("User not found")
968
-
969
- current_role_stmt = select(RoleModel).where(
970
- RoleModel.uuid == user_model.role_uuid
971
- )
972
- current_role_result = await session.execute(current_role_stmt)
973
- current_role = current_role_result.scalar_one()
974
-
975
- # Find the new role within the same organization
976
- role_stmt = select(RoleModel).where(
977
- RoleModel.org_uuid == current_role.org_uuid,
978
- RoleModel.display_name == new_role,
979
- )
980
- role_result = await session.execute(role_stmt)
981
- role_model = role_result.scalar_one_or_none()
982
- if not role_model:
983
- raise ValueError("Role not found in user's organization")
984
-
985
- stmt = (
986
- update(UserModel)
987
- .where(UserModel.uuid == user_uuid.bytes)
988
- .values(role_uuid=role_model.uuid)
989
- )
990
- await session.execute(stmt)
991
-
992
- # Permission operations
993
- async def create_permission(self, permission: Permission) -> None:
994
- async with self.session() as session:
995
- permission_model = PermissionModel(
996
- id=permission.id,
997
- display_name=permission.display_name,
998
- )
999
- session.add(permission_model)
1000
-
1001
- async def get_permission(self, permission_id: str) -> Permission:
1002
- async with self.session() as session:
1003
- stmt = select(PermissionModel).where(PermissionModel.id == permission_id)
1004
- result = await session.execute(stmt)
1005
- permission_model = result.scalar_one_or_none()
1006
-
1007
- if permission_model:
1008
- return Permission(
1009
- id=permission_model.id,
1010
- display_name=permission_model.display_name,
1011
- )
1012
- raise ValueError("Permission not found")
1013
-
1014
- async def update_permission(self, permission: Permission) -> None:
1015
- async with self.session() as session:
1016
- stmt = (
1017
- update(PermissionModel)
1018
- .where(PermissionModel.id == permission.id)
1019
- .values(display_name=permission.display_name)
1020
- )
1021
- await session.execute(stmt)
1022
-
1023
- async def rename_permission(
1024
- self, old_id: str, new_id: str, display_name: str
1025
- ) -> None:
1026
- """Rename a permission's primary key and update referencing tables.
1027
-
1028
- Approach: insert new row (if id changes), update FKs, delete old row.
1029
- Wrapped in a transaction; will raise on conflict.
1030
- """
1031
- if old_id == new_id:
1032
- # Just update display name
1033
- async with self.session() as session:
1034
- stmt = (
1035
- update(PermissionModel)
1036
- .where(PermissionModel.id == old_id)
1037
- .values(display_name=display_name)
1038
- )
1039
- await session.execute(stmt)
1040
- return
1041
- async with self.session() as session:
1042
- # Ensure old exists
1043
- existing_old = await session.execute(
1044
- select(PermissionModel).where(PermissionModel.id == old_id)
1045
- )
1046
- if not existing_old.scalar_one_or_none():
1047
- raise ValueError("Original permission not found")
1048
-
1049
- # Check new not taken
1050
- existing_new = await session.execute(
1051
- select(PermissionModel).where(PermissionModel.id == new_id)
1052
- )
1053
- if existing_new.scalar_one_or_none():
1054
- raise ValueError("New permission id already exists")
1055
-
1056
- # Create new permission row first
1057
- session.add(PermissionModel(id=new_id, display_name=display_name))
1058
- await session.flush()
1059
-
1060
- # Update org_permissions
1061
- await session.execute(
1062
- update(OrgPermission)
1063
- .where(OrgPermission.permission_id == old_id)
1064
- .values(permission_id=new_id)
1065
- )
1066
- await session.flush()
1067
- # Update role_permissions
1068
- await session.execute(
1069
- update(RolePermission)
1070
- .where(RolePermission.permission_id == old_id)
1071
- .values(permission_id=new_id)
1072
- )
1073
- await session.flush()
1074
- # Delete old permission row
1075
- await session.execute(
1076
- delete(PermissionModel).where(PermissionModel.id == old_id)
1077
- )
1078
- await session.flush()
1079
-
1080
- async def delete_permission(self, permission_id: str) -> None:
1081
- async with self.session() as session:
1082
- stmt = delete(PermissionModel).where(PermissionModel.id == permission_id)
1083
- await session.execute(stmt)
1084
-
1085
- async def list_permissions(self) -> list[Permission]:
1086
- async with self.session() as session:
1087
- result = await session.execute(select(PermissionModel))
1088
- return [p.as_dataclass() for p in result.scalars().all()]
1089
-
1090
- async def add_permission_to_role(self, role_uuid: UUID, permission_id: str) -> None:
1091
- async with self.session() as session:
1092
- # Ensure role exists
1093
- role_stmt = select(RoleModel).where(RoleModel.uuid == role_uuid.bytes)
1094
- role_result = await session.execute(role_stmt)
1095
- role_model = role_result.scalar_one_or_none()
1096
- if not role_model:
1097
- raise ValueError("Role not found")
1098
-
1099
- # Ensure permission exists
1100
- perm_stmt = select(PermissionModel).where(
1101
- PermissionModel.id == permission_id
1102
- )
1103
- perm_result = await session.execute(perm_stmt)
1104
- if not perm_result.scalar_one_or_none():
1105
- raise ValueError("Permission not found")
1106
-
1107
- session.add(
1108
- RolePermission(role_uuid=role_uuid.bytes, permission_id=permission_id)
1109
- )
1110
-
1111
- async def remove_permission_from_role(
1112
- self, role_uuid: UUID, permission_id: str
1113
- ) -> None:
1114
- async with self.session() as session:
1115
- await session.execute(
1116
- delete(RolePermission)
1117
- .where(RolePermission.role_uuid == role_uuid.bytes)
1118
- .where(RolePermission.permission_id == permission_id)
1119
- )
1120
-
1121
- async def get_role_permissions(self, role_uuid: UUID) -> list[Permission]:
1122
- async with self.session() as session:
1123
- stmt = (
1124
- select(PermissionModel)
1125
- .join(
1126
- RolePermission, PermissionModel.id == RolePermission.permission_id
1127
- )
1128
- .where(RolePermission.role_uuid == role_uuid.bytes)
1129
- )
1130
- result = await session.execute(stmt)
1131
- return [p.as_dataclass() for p in result.scalars().all()]
1132
-
1133
- async def get_permission_roles(self, permission_id: str) -> list[Role]:
1134
- async with self.session() as session:
1135
- stmt = (
1136
- select(RoleModel)
1137
- .join(RolePermission, RoleModel.uuid == RolePermission.role_uuid)
1138
- .where(RolePermission.permission_id == permission_id)
1139
- )
1140
- result = await session.execute(stmt)
1141
- return [r.as_dataclass() for r in result.scalars().all()]
1142
-
1143
- async def update_role(self, role: Role) -> None:
1144
- async with self.session() as session:
1145
- # Update role display_name
1146
- await session.execute(
1147
- update(RoleModel)
1148
- .where(RoleModel.uuid == role.uuid.bytes)
1149
- .values(display_name=role.display_name)
1150
- )
1151
- # Sync role permissions: delete all then insert current set
1152
- await session.execute(
1153
- delete(RolePermission).where(
1154
- RolePermission.role_uuid == role.uuid.bytes
1155
- )
1156
- )
1157
- if role.permissions:
1158
- for perm_id in set(role.permissions):
1159
- await session.execute(
1160
- insert(RolePermission).values(
1161
- role_uuid=role.uuid.bytes, permission_id=perm_id
1162
- )
1163
- )
1164
-
1165
- async def delete_role(self, role_uuid: UUID) -> None:
1166
- async with self.session() as session:
1167
- # Prevent deleting a role that still has users
1168
- # Quick existence check for users assigned to the role
1169
- existing_user = await session.execute(
1170
- select(UserModel.uuid).where(UserModel.role_uuid == role_uuid.bytes)
1171
- )
1172
- if existing_user.first() is not None:
1173
- raise ValueError("Cannot delete role with assigned users")
1174
-
1175
- await session.execute(
1176
- delete(RoleModel).where(RoleModel.uuid == role_uuid.bytes)
1177
- )
1178
-
1179
- async def get_role(self, role_uuid: UUID) -> Role:
1180
- async with self.session() as session:
1181
- result = await session.execute(
1182
- select(RoleModel).where(RoleModel.uuid == role_uuid.bytes)
1183
- )
1184
- role_model = result.scalar_one_or_none()
1185
- if not role_model:
1186
- raise ValueError("Role not found")
1187
- r_dc = role_model.as_dataclass()
1188
- perms_result = await session.execute(
1189
- select(RolePermission.permission_id).where(
1190
- RolePermission.role_uuid == role_uuid.bytes
1191
- )
1192
- )
1193
- r_dc.permissions = [row[0] for row in perms_result.fetchall()]
1194
- return r_dc
1195
-
1196
- async def get_roles_by_organization(self, org_id: str) -> list[Role]:
1197
- async with self.session() as session:
1198
- org_uuid = UUID(org_id)
1199
- result = await session.execute(
1200
- select(RoleModel).where(RoleModel.org_uuid == org_uuid.bytes)
1201
- )
1202
- role_models = result.scalars().all()
1203
- roles: list[Role] = []
1204
- for rm in role_models:
1205
- r_dc = rm.as_dataclass()
1206
- perms_result = await session.execute(
1207
- select(RolePermission.permission_id).where(
1208
- RolePermission.role_uuid == rm.uuid
1209
- )
1210
- )
1211
- r_dc.permissions = [row[0] for row in perms_result.fetchall()]
1212
- roles.append(r_dc)
1213
- return roles
1214
-
1215
- async def add_permission_to_organization(
1216
- self, org_id: str, permission_id: str
1217
- ) -> None:
1218
- async with self.session() as session:
1219
- # Get organization and permission models
1220
- org_uuid = UUID(org_id)
1221
- org_stmt = select(OrgModel).where(OrgModel.uuid == org_uuid.bytes)
1222
- org_result = await session.execute(org_stmt)
1223
- org_model = org_result.scalar_one_or_none()
1224
-
1225
- permission_stmt = select(PermissionModel).where(
1226
- PermissionModel.id == permission_id
1227
- )
1228
- permission_result = await session.execute(permission_stmt)
1229
- permission_model = permission_result.scalar_one_or_none()
1230
-
1231
- if not org_model:
1232
- raise ValueError("Organization not found")
1233
- if not permission_model:
1234
- raise ValueError("Permission not found")
1235
-
1236
- # Create the org-permission relationship
1237
- org_permission = OrgPermission(
1238
- org_uuid=org_uuid.bytes, permission_id=permission_id
1239
- )
1240
- session.add(org_permission)
1241
-
1242
- async def remove_permission_from_organization(
1243
- self, org_id: str, permission_id: str
1244
- ) -> None:
1245
- async with self.session() as session:
1246
- # Convert string ID to UUID bytes for lookup
1247
- org_uuid = UUID(org_id)
1248
- # Delete the org-permission relationship
1249
- stmt = delete(OrgPermission).where(
1250
- OrgPermission.org_uuid == org_uuid.bytes,
1251
- OrgPermission.permission_id == permission_id,
1252
- )
1253
- await session.execute(stmt)
1254
-
1255
- async def get_organization_permissions(self, org_id: str) -> list[Permission]:
1256
- async with self.session() as session:
1257
- # Convert string ID to UUID bytes for lookup
1258
- org_uuid = UUID(org_id)
1259
- stmt = select(OrgPermission).where(OrgPermission.org_uuid == org_uuid.bytes)
1260
- result = await session.execute(stmt)
1261
- org_permission_models = result.scalars().all()
1262
-
1263
- # Fetch the permission details for each org-permission relationship
1264
- permissions = []
1265
- for org_permission in org_permission_models:
1266
- permission_stmt = select(PermissionModel).where(
1267
- PermissionModel.id == org_permission.permission_id
1268
- )
1269
- permission_result = await session.execute(permission_stmt)
1270
- permission_model = permission_result.scalar_one()
1271
-
1272
- permission = Permission(
1273
- id=permission_model.id,
1274
- display_name=permission_model.display_name,
1275
- )
1276
- permissions.append(permission)
1277
-
1278
- return permissions
1279
-
1280
- async def get_permission_organizations(self, permission_id: str) -> list[Org]:
1281
- async with self.session() as session:
1282
- stmt = select(OrgPermission).where(
1283
- OrgPermission.permission_id == permission_id
1284
- )
1285
- result = await session.execute(stmt)
1286
- org_permission_models = result.scalars().all()
1287
-
1288
- # Fetch the organization details for each org-permission relationship
1289
- organizations = []
1290
- for org_permission in org_permission_models:
1291
- org_stmt = select(OrgModel).where(
1292
- OrgModel.uuid == org_permission.org_uuid
1293
- )
1294
- org_result = await session.execute(org_stmt)
1295
- org_model = org_result.scalar_one()
1296
- organizations.append(org_model.as_dataclass())
1297
-
1298
- return organizations
1299
-
1300
- async def cleanup(self) -> None:
1301
- async with self.session() as session:
1302
- current_time = datetime.now(timezone.utc)
1303
- session_threshold = current_time - SESSION_LIFETIME
1304
- await session.execute(
1305
- delete(SessionModel).where(SessionModel.renewed < session_threshold)
1306
- )
1307
- await session.execute(
1308
- delete(ResetTokenModel).where(ResetTokenModel.expiry < current_time)
1309
- )
1310
-
1311
- async def get_session_context(
1312
- self, session_key: bytes, host: str | None = None
1313
- ) -> SessionContext | None:
1314
- """Get complete session context including user, organization, role, and permissions.
1315
-
1316
- Uses efficient JOINs to retrieve all related data in a single database query.
1317
- """
1318
- async with self.session() as session:
1319
- # Build a query that joins sessions, users, roles, organizations, credentials and role_permissions
1320
- stmt = (
1321
- select(
1322
- SessionModel,
1323
- UserModel,
1324
- RoleModel,
1325
- OrgModel,
1326
- CredentialModel,
1327
- PermissionModel,
1328
- )
1329
- .select_from(SessionModel)
1330
- .join(UserModel, SessionModel.user_uuid == UserModel.uuid)
1331
- .join(RoleModel, UserModel.role_uuid == RoleModel.uuid)
1332
- .join(OrgModel, RoleModel.org_uuid == OrgModel.uuid)
1333
- .outerjoin(
1334
- CredentialModel,
1335
- SessionModel.credential_uuid == CredentialModel.uuid,
1336
- )
1337
- .outerjoin(RolePermission, RoleModel.uuid == RolePermission.role_uuid)
1338
- .outerjoin(
1339
- PermissionModel, RolePermission.permission_id == PermissionModel.id
1340
- )
1341
- .where(SessionModel.key == session_key)
1342
- )
1343
-
1344
- result = await session.execute(stmt)
1345
- rows = result.fetchall()
1346
-
1347
- if not rows:
1348
- return None
1349
-
1350
- # Extract the first row to get session and user data
1351
- first_row = rows[0]
1352
- session_model, user_model, role_model, org_model, credential_model, _ = (
1353
- first_row
1354
- )
1355
-
1356
- # Create the session object
1357
- if host is not None:
1358
- if session_model.host is None:
1359
- await session.execute(
1360
- update(SessionModel)
1361
- .where(SessionModel.key == session_key)
1362
- .values(host=host)
1363
- )
1364
- session_model.host = host
1365
- elif session_model.host != host:
1366
- return None
1367
-
1368
- session_obj = session_model.as_dataclass()
1369
-
1370
- # Create the user object
1371
- user_obj = user_model.as_dataclass()
1372
-
1373
- # Create organization object (fill permissions later if needed)
1374
- organization = Org(UUID(bytes=org_model.uuid), org_model.display_name)
1375
-
1376
- # Create role object
1377
- role = Role(
1378
- uuid=UUID(bytes=role_model.uuid),
1379
- org_uuid=UUID(bytes=role_model.org_uuid),
1380
- display_name=role_model.display_name,
1381
- )
1382
-
1383
- # Create credential object if available
1384
- credential_obj = (
1385
- credential_model.as_dataclass() if credential_model else None
1386
- )
1387
-
1388
- # Collect all unique permissions for the role
1389
- permissions = []
1390
- seen_permission_ids = set()
1391
- for row in rows:
1392
- _, _, _, _, _, permission_model = row
1393
- if permission_model and permission_model.id not in seen_permission_ids:
1394
- permissions.append(
1395
- Permission(
1396
- id=permission_model.id,
1397
- display_name=permission_model.display_name,
1398
- )
1399
- )
1400
- seen_permission_ids.add(permission_model.id)
1401
-
1402
- # Attach permission IDs to role
1403
- role.permissions = list(seen_permission_ids)
1404
-
1405
- # Load org permission IDs as well
1406
- org_perm_stmt = select(OrgPermission.permission_id).where(
1407
- OrgPermission.org_uuid == org_model.uuid
1408
- )
1409
- org_perm_result = await session.execute(org_perm_stmt)
1410
- organization.permissions = [row[0] for row in org_perm_result.fetchall()]
1411
-
1412
- # Filter effective permissions: only include permissions that the org can grant
1413
- effective_permissions = [
1414
- p for p in permissions if p.id in organization.permissions
1415
- ]
1416
-
1417
- return SessionContext(
1418
- session=session_obj,
1419
- user=user_obj,
1420
- org=organization,
1421
- role=role,
1422
- credential=credential_obj,
1423
- permissions=effective_permissions if effective_permissions else None,
1424
- )