paskia 0.9.0__py3-none-any.whl → 0.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (58)
  1. paskia/_version.py +2 -2
  2. paskia/aaguid/__init__.py +5 -4
  3. paskia/authsession.py +4 -19
  4. paskia/db/__init__.py +2 -4
  5. paskia/db/background.py +3 -3
  6. paskia/db/jsonl.py +99 -111
  7. paskia/db/logging.py +318 -0
  8. paskia/db/migrations.py +19 -20
  9. paskia/db/operations.py +107 -196
  10. paskia/db/structs.py +236 -46
  11. paskia/fastapi/__main__.py +13 -6
  12. paskia/fastapi/admin.py +72 -195
  13. paskia/fastapi/api.py +56 -58
  14. paskia/fastapi/authz.py +3 -8
  15. paskia/fastapi/logging.py +261 -0
  16. paskia/fastapi/mainapp.py +14 -3
  17. paskia/fastapi/remote.py +11 -37
  18. paskia/fastapi/reset.py +0 -2
  19. paskia/fastapi/response.py +22 -0
  20. paskia/fastapi/user.py +7 -7
  21. paskia/fastapi/ws.py +14 -37
  22. paskia/fastapi/wschat.py +55 -2
  23. paskia/fastapi/wsutil.py +10 -2
  24. paskia/frontend-build/auth/admin/index.html +6 -6
  25. paskia/frontend-build/auth/assets/AccessDenied-C29NZI95.css +1 -0
  26. paskia/frontend-build/auth/assets/AccessDenied-DAdzg_MJ.js +12 -0
  27. paskia/frontend-build/auth/assets/{RestrictedAuth-CvR33_Z0.css → RestrictedAuth-BOdNrlQB.css} +1 -1
  28. paskia/frontend-build/auth/assets/{RestrictedAuth-DsJXicIw.js → RestrictedAuth-BSusdAfp.js} +1 -1
  29. paskia/frontend-build/auth/assets/_plugin-vue_export-helper-D2l53SUz.js +49 -0
  30. paskia/frontend-build/auth/assets/_plugin-vue_export-helper-DYJ24FZK.css +1 -0
  31. paskia/frontend-build/auth/assets/admin-BeFvGyD6.js +1 -0
  32. paskia/frontend-build/auth/assets/{admin-DzzjSg72.css → admin-CmNtuH3s.css} +1 -1
  33. paskia/frontend-build/auth/assets/{auth-C7k64Wad.css → auth-BKq4T2K2.css} +1 -1
  34. paskia/frontend-build/auth/assets/auth-DvHf8hgy.js +1 -0
  35. paskia/frontend-build/auth/assets/{forward-DmqVHZ7e.js → forward-C86Jm_Uq.js} +1 -1
  36. paskia/frontend-build/auth/assets/reset-B8PlNXuP.css +1 -0
  37. paskia/frontend-build/auth/assets/reset-D71FG0VL.js +1 -0
  38. paskia/frontend-build/auth/assets/{restricted-D3AJx3_6.js → restricted-CW0drE_k.js} +1 -1
  39. paskia/frontend-build/auth/index.html +6 -6
  40. paskia/frontend-build/auth/restricted/index.html +5 -5
  41. paskia/frontend-build/int/forward/index.html +5 -5
  42. paskia/frontend-build/int/reset/index.html +4 -4
  43. paskia/migrate/__init__.py +9 -9
  44. paskia/migrate/sql.py +26 -19
  45. paskia/remoteauth.py +6 -6
  46. {paskia-0.9.0.dist-info → paskia-0.10.0.dist-info}/METADATA +1 -1
  47. paskia-0.10.0.dist-info/RECORD +60 -0
  48. paskia/frontend-build/auth/assets/AccessDenied-DPkUS8LZ.css +0 -1
  49. paskia/frontend-build/auth/assets/AccessDenied-Fmeb6EtF.js +0 -8
  50. paskia/frontend-build/auth/assets/_plugin-vue_export-helper-BTzJAQlS.css +0 -1
  51. paskia/frontend-build/auth/assets/_plugin-vue_export-helper-nhjnO_bd.js +0 -2
  52. paskia/frontend-build/auth/assets/admin-CPE1pLMm.js +0 -1
  53. paskia/frontend-build/auth/assets/auth-YIZvPlW_.js +0 -1
  54. paskia/frontend-build/auth/assets/reset-Chtv69AT.css +0 -1
  55. paskia/frontend-build/auth/assets/reset-s20PATTN.js +0 -1
  56. paskia-0.9.0.dist-info/RECORD +0 -57
  57. {paskia-0.9.0.dist-info → paskia-0.10.0.dist-info}/WHEEL +0 -0
  58. {paskia-0.9.0.dist-info → paskia-0.10.0.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,261 @@
1
+ """Custom access logging middleware for FastAPI/Uvicorn."""
2
+
3
+ import logging
4
+ import sys
5
+ import time
6
+ from ipaddress import IPv6Address
7
+ from typing import TYPE_CHECKING
8
+
9
+ from starlette.middleware.base import BaseHTTPMiddleware
10
+
11
+ if TYPE_CHECKING:
12
+ from paskia.db.structs import SessionContext
13
+ from starlette.requests import Request
14
+ from starlette.responses import Response
15
+
16
logger = logging.getLogger("paskia.access")

# ANSI SGR escape sequences for colorized terminal output. Each sequence is
# "<CSI><params>m"; deriving them from one prefix keeps the codes readable.
_CSI = "\033["  # Control Sequence Introducer (ESC [)

_RESET = f"{_CSI}0m"

# HTTP status classes
_STATUS_INFO = f"{_CSI}32m"  # 1xx (green)
_STATUS_OK = f"{_CSI}1;92m"  # 2xx (bright green)
_STATUS_REDIRECT = f"{_CSI}32m"  # 3xx (green)
_STATUS_CLIENT_ERR = f"{_CSI}0;31m"  # 4xx (red)
_STATUS_SERVER_ERR = f"{_CSI}1;91m"  # 5xx (bold bright red)

# HTTP methods
_METHOD_READ = f"{_CSI}0;34m"  # GET, HEAD, OPTIONS (blue)
_METHOD_WRITE = f"{_CSI}1;94m"  # POST, PUT, DELETE, PATCH (bold bright blue)

# Request line fields
_HOST = f"{_CSI}38;5;242m"  # hostname (dark grey)
_PATH = f"{_CSI}38;5;250m"  # path (light grey)
_TIMING = f"{_CSI}38;5;242m"  # timing/devmode (dark grey)

# WebSocket lifecycle
_WS_OPEN = f"{_CSI}1;93m"  # WebSocket connect (bold bright yellow)
_WS_CLOSE = f"{_CSI}33m"  # WebSocket disconnect (yellow)
_WS_STATUS = f"{_CSI}38;5;242m"  # WebSocket close status (dark grey)

# Authorization failures
_AUTHZ_DENIED = f"{_CSI}0;31m"  # "Permission denied" (red)
_AUTHZ_USER = f"{_CSI}1;34m"  # user name (bold blue)
_AUTHZ_ORG = f"{_CSI}34m"  # org and role (blue)
_AUTHZ_NEEDS = f"{_CSI}1;38;5;231m"  # "needs" label (brightest white)
_AUTHZ_MISSING = f"{_CSI}1;31m"  # missing scope (bold red)
_AUTHZ_GRANTED = f"{_CSI}0;32m"  # granted scope (green)
38
+
39
+
40
def format_ipv6_network(ip: str) -> str:
    """Format an IPv6 address as its /64 network prefix for compact logging.

    The interface identifier (low 64 bits) is dropped and the prefix is
    printed in compressed form without the trailing ``::``, e.g.
    ``2001:db8:1:2:aaaa:bbbb:cccc:dddd`` -> ``2001:db8:1:2``.

    Addresses whose /64 prefix is all zeros would otherwise collapse to an
    empty string. These are the loopback ``::1``, the unspecified ``::`` and
    IPv4-mapped addresses. IPv4-mapped addresses (``::ffff:a.b.c.d``) are
    shown as the embedded IPv4 address, and the rest are shown in full.

    Returns ``ip`` unchanged if it is not a valid IPv6 address.
    """
    try:
        addr = IPv6Address(ip)
    except ValueError:  # AddressValueError is a ValueError subclass
        return ip

    mapped = addr.ipv4_mapped
    if mapped is not None:
        return str(mapped)

    network_int = int(addr) >> 64
    if network_int == 0:
        return str(addr)

    # Re-embed the prefix as the high 64 bits. The 4 zero groups of the
    # interface part are always the longest zero run, so the compressed
    # form ends in "::", which is stripped.
    return str(IPv6Address(network_int << 64)).removesuffix("::")
58
+
59
+
60
def format_client_ip(ip: str) -> str:
    """Return a log-friendly client address.

    Missing addresses (empty, ``None`` or ``"-"``) become ``"-"``. Anything
    containing a colon is treated as IPv6 and shortened with
    ``format_ipv6_network``. All other values are returned unchanged.
    """
    if ip and ip != "-":
        return format_ipv6_network(ip) if ":" in ip else ip
    return "-"
67
+
68
+
69
def status_color(status: int) -> str:
    """Pick the ANSI color for an HTTP status code based on its class.

    Anything below 200 is treated as 1xx, and anything 500 or above as 5xx.
    """
    # (exclusive upper bound, color) pairs checked in ascending order
    for upper_bound, color in (
        (200, _STATUS_INFO),
        (300, _STATUS_OK),
        (400, _STATUS_REDIRECT),
        (500, _STATUS_CLIENT_ERR),
    ):
        if status < upper_bound:
            return color
    return _STATUS_SERVER_ERR
80
+
81
+
82
def method_color(method: str) -> str:
    """Return the ANSI color for an HTTP method.

    Read-only methods (GET, HEAD, OPTIONS) get the "read" color. Every
    other method gets the "write" color.
    """
    is_read_only = method in {"GET", "HEAD", "OPTIONS"}
    return _METHOD_READ if is_read_only else _METHOD_WRITE
87
+
88
+
89
def _printable(text: str) -> str:
    """Escape non-printable characters in client-supplied text.

    This prevents a request from injecting terminal control sequences (such
    as ANSI ESC) into the access log. Printable text, including non-ASCII
    letters, is returned unchanged. Other characters are replaced by their
    ``repr`` escape, e.g. ``"\\x1b"``.
    """
    if text.isprintable():
        return text
    return "".join(ch if ch.isprintable() else repr(ch)[1:-1] for ch in text)


def format_access_log(
    client: str, status: int, method: str, host: str, path: str, duration_ms: float
) -> str:
    """Format an access log line with aligned fields and optional ANSI colors.

    Layout: ``IP STATUS METHOD hostpath TIMING``. The client IP is padded to
    19 columns (the width of a /64 IPv6 prefix) and the method to 7 columns
    (``OPTIONS``). Colors are used only when stderr is a TTY.

    ``host`` and ``path`` are taken from the request, so they are
    attacker-controlled (the path may contain percent-decoded bytes).
    Non-printable characters in them are escaped before logging.
    """
    use_color = sys.stderr.isatty()

    # Format components with fixed widths for alignment
    ip = format_client_ip(client).ljust(19)  # IPv6 network max 19 chars
    timing = f"{duration_ms:.0f}ms"
    method_padded = method.ljust(7)  # Longest method is OPTIONS (7)
    host = _printable(host)
    path = _printable(path)

    if use_color:
        status_str = f"{status_color(status)}{status}{_RESET}"
        timing_str = f"{_TIMING}{timing}{_RESET}"
        method_str = f"{method_color(method)}{method_padded}{_RESET}"
        host_str = f"{_HOST}{host}{_RESET}"
        path_str = f"{_PATH}{path}{_RESET}"
    else:
        status_str = str(status)
        timing_str = timing
        method_str = method_padded
        host_str = host
        path_str = path

    # Format: "IP STATUS METHOD host path TIMING"
    return f"{ip} {status_str} {method_str} {host_str}{path_str} {timing_str}"
115
+
116
+
117
# Rolling WebSocket connection counter; IDs wrap back to 0 after 99.
_ws_counter = 0


def _next_ws_id() -> int:
    """Hand out the current WebSocket ID (0-99) and advance the counter."""
    global _ws_counter
    current, _ws_counter = _ws_counter, (_ws_counter + 1) % 100
    return current
127
+
128
+
129
def log_ws_open(ws) -> int:
    """Log a WebSocket connection opening and return its short ID.

    The returned ID (0-99, from a rolling counter) is meant to be passed to
    ``log_ws_close`` so the open and close lines can be matched up. The
    Origin is shown only when its host part differs from the Host header.

    Host, path and Origin are client-supplied. Non-printable characters in
    them are escaped so a request cannot inject terminal control sequences
    into the log.
    """

    def clean(text: str) -> str:
        # Escape control characters (e.g. ANSI ESC); printable text is kept.
        return "".join(ch if ch.isprintable() else repr(ch)[1:-1] for ch in text)

    use_color = sys.stderr.isatty()
    ws_id = _next_ws_id()

    client = ws.client.host if ws.client else "-"
    raw_host = ws.headers.get("host", "-")
    origin = ws.headers.get("origin")

    ip = format_client_ip(client).ljust(19)
    id_str = f"{ws_id:02d}".ljust(7)  # Align with method field (7 chars)

    # Origin header includes the scheme (e.g. "https://example.com"); compare
    # only the host part, on the raw values, and omit it when equal to Host.
    origin_host = origin.split("://", 1)[-1] if origin else None
    show_origin = bool(origin_host) and origin_host != raw_host

    host = clean(raw_host)
    path = clean(ws.url.path)
    shown_origin = clean(origin_host) if show_origin else ""

    if use_color:
        # 🔌 aligned with status (takes ~2 char width), ID aligned with method
        prefix = f"🔌 {_WS_OPEN}{id_str}{_RESET}"
        host_str = f"{_HOST}{host}{_RESET}"
        path_str = f"{_PATH}{path}{_RESET}"
        origin_str = (
            f" {_RESET}from {_HOST}{shown_origin}{_RESET}" if show_origin else ""
        )
    else:
        prefix = f"WS+ {id_str}"
        host_str = host
        path_str = path
        origin_str = f" from {shown_origin}" if show_origin else ""

    logger.info(f"{ip} {prefix} {host_str}{path_str}{origin_str}")
    return ws_id
163
+
164
+
165
+ # WebSocket close codes to human-readable status
166
+ WS_CLOSE_CODES = {
167
+ 1000: "ok",
168
+ 1001: "going away",
169
+ 1002: "protocol error",
170
+ 1003: "unsupported",
171
+ 1005: "no status",
172
+ 1006: "abnormal",
173
+ 1007: "invalid data",
174
+ 1008: "policy violation",
175
+ 1009: "too large",
176
+ 1010: "extension required",
177
+ 1011: "server error",
178
+ 1012: "restarting",
179
+ 1013: "try again",
180
+ 1014: "bad gateway",
181
+ 1015: "tls error",
182
+ }
183
+
184
+
185
+ def log_ws_close(ws_id: int, close_code: int | None, duration: float) -> None:
186
+ """Log WebSocket connection close with duration and status."""
187
+ use_color = sys.stderr.isatty()
188
+
189
+ id_str = f"{ws_id:02d}".ljust(7) # Align with method field (7 chars)
190
+ timing = f"{duration * 1000:.0f}ms"
191
+
192
+ # Convert close code to status text
193
+ if close_code is None:
194
+ status = "closed"
195
+ else:
196
+ status = WS_CLOSE_CODES.get(close_code, f"code {close_code}")
197
+
198
+ if use_color:
199
+ # 🔌 aligned with status, ID aligned with method
200
+ prefix = f"🔌 {_WS_CLOSE}{id_str}{_RESET}"
201
+ status_str = f"{_WS_STATUS}{status}{_RESET}"
202
+ timing_str = f"{_TIMING}{timing}{_RESET}"
203
+ else:
204
+ prefix = f"WS- {id_str}"
205
+ status_str = status
206
+ timing_str = timing
207
+
208
+ logger.info(f"{' ' * 19} {prefix} {status_str} {timing_str}")
209
+
210
+
211
def log_permission_denied(
    ctx: "SessionContext", required: list[str], missing: list[str], *, require_all: bool
) -> None:
    """Log a permission-denied event for the user in ``ctx`` as a warning.

    The line names the user, their organization and role, and lists every
    required scope, marked ✓ (granted) or ✗ (listed in ``missing``). The
    "needs all"/"needs any" qualifier reflects ``require_all`` and is omitted
    when exactly one scope is required.

    ANSI colors are applied only when stderr is a TTY, like the other log
    formatters in this module, so redirected logs contain no escape codes.
    """
    use_color = sys.stderr.isatty()

    def paint(color: str, text: str) -> str:
        return f"{color}{text}{_RESET}" if use_color else text

    missing_set = set(missing)
    scopes = " ".join(
        paint(_AUTHZ_MISSING, f"{s}✗")
        if s in missing_set
        else paint(_AUTHZ_GRANTED, f"{s}✓")
        for s in required
    )
    n = "" if len(required) == 1 else " all" if require_all else " any"
    org_role = f"({ctx.org.display_name} {ctx.role.display_name})"
    logger.warning(
        f"{paint(_AUTHZ_DENIED, 'Permission denied')} "
        f"{paint(_AUTHZ_USER, ctx.user.display_name)} "
        f"{paint(_AUTHZ_ORG, org_role)} "
        f"{paint(_AUTHZ_NEEDS, f'needs{n}:')} {scopes}"
    )
229
+
230
+
231
class AccessLogMiddleware(BaseHTTPMiddleware):
    """Middleware that writes one access log line per HTTP request.

    The logged duration runs from entry until ``call_next`` returns the
    response object. For streaming responses this presumably excludes body
    transmission. If the downstream app raises, the request is still logged
    with status 500 before the exception propagates.
    """

    # Annotations are quoted: Request/Response are imported only under
    # TYPE_CHECKING, and unquoted annotations are evaluated at definition
    # time on Python < 3.14.
    async def dispatch(self, request: "Request", call_next) -> "Response":
        start = time.perf_counter()
        try:
            response = await call_next(request)
        except Exception:
            self._log(request, 500, start)
            raise
        self._log(request, response.status_code, start)
        return response

    @staticmethod
    def _log(request: "Request", status: int, start: float) -> None:
        """Emit the access log line for ``request`` completed with ``status``."""
        duration_ms = (time.perf_counter() - start) * 1000

        client = request.client.host if request.client else "-"
        host = request.headers.get("host", "-")
        path = request.url.path
        if request.url.query:
            path = f"{path}?{request.url.query}"

        logger.info(
            format_access_log(client, status, request.method, host, path, duration_ms)
        )
251
+
252
+
253
def configure_access_logging():
    """Route the access logger to stderr with message-only formatting.

    The formatters in this module build complete lines, so the handler
    prints ``%(message)s`` as-is. Propagation is disabled so the root
    logger's handlers don't print each line a second time.

    Idempotent: the handler is tagged with a name and is attached only
    once. A repeated call (e.g. after a module reload) therefore doesn't
    duplicate every log line.
    """
    handler_name = "paskia.access.stderr"
    if not any(h.get_name() == handler_name for h in logger.handlers):
        handler = logging.StreamHandler(sys.stderr)
        handler.set_name(handler_name)
        handler.setFormatter(logging.Formatter("%(message)s"))
        logger.addHandler(handler)
    logger.setLevel(logging.INFO)
    logger.propagate = False
    # Suppress watchfiles "X changes detected" INFO messages (keep WARNING for reload notification)
    logging.getLogger("watchfiles.main").setLevel(logging.WARNING)
paskia/fastapi/mainapp.py CHANGED
@@ -10,10 +10,18 @@ from fastapi_vue import Frontend
10
10
 
11
11
  from paskia import globals
12
12
  from paskia.db import start_background, stop_background
13
+ from paskia.db.logging import configure_db_logging
13
14
  from paskia.fastapi import admin, api, auth_host, ws
15
+ from paskia.fastapi.logging import AccessLogMiddleware, configure_access_logging
14
16
  from paskia.fastapi.session import AUTH_COOKIE
15
17
  from paskia.util import hostutil, passphrase, vitedev
16
18
 
19
+ # Configure custom logging
20
+ configure_access_logging()
21
+ configure_db_logging()
22
+
23
+ _access_logger = logging.getLogger("paskia.access")
24
+
17
25
  # Vue Frontend static files
18
26
  frontend = Frontend(
19
27
  Path(__file__).parent.parent / "frontend-build",
@@ -48,11 +56,11 @@ async def lifespan(app: FastAPI): # pragma: no cover - startup path
48
56
  # Re-raise to fail fast
49
57
  raise
50
58
 
51
- # Restore info level logging after startup (suppressed during uvicorn init in dev mode)
59
+ # Restore uvicorn info logging (suppressed during startup in dev mode)
60
+ # Keep uvicorn.error at WARNING to suppress WebSocket "connection open/closed" messages
52
61
  if frontend.devmode:
53
62
  logging.getLogger("uvicorn").setLevel(logging.INFO)
54
- logging.getLogger("uvicorn.access").setLevel(logging.INFO)
55
-
63
+ logging.getLogger("uvicorn.error").setLevel(logging.WARNING)
56
64
  await frontend.load()
57
65
  await start_background()
58
66
  yield
@@ -67,6 +75,9 @@ app = FastAPI(
67
75
  openapi_url=None,
68
76
  )
69
77
 
78
+ # Custom access logging (uvicorn's access_log is disabled)
79
+ app.add_middleware(AccessLogMiddleware)
80
+
70
81
  # Apply redirections to auth-host if configured (deny access to restricted endpoints, remove /auth/)
71
82
  app.middleware("http")(auth_host.redirect_middleware)
72
83
 
paskia/fastapi/remote.py CHANGED
@@ -17,10 +17,10 @@ from fastapi import FastAPI, WebSocket, WebSocketDisconnect
17
17
 
18
18
  from paskia import db, remoteauth
19
19
  from paskia.authsession import expires
20
- from paskia.fastapi.session import infodict
21
- from paskia.fastapi.wschat import authenticate_chat
20
+ from paskia.fastapi.session import AUTH_COOKIE, infodict
21
+ from paskia.fastapi.wschat import authenticate_and_login
22
22
  from paskia.fastapi.wsutil import validate_origin, websocket_error_handler
23
- from paskia.util import hostutil, passphrase, pow, useragent
23
+ from paskia.util import passphrase, pow, useragent
24
24
 
25
25
  # Create a FastAPI subapp for remote auth WebSocket endpoints
26
26
  app = FastAPI(docs_url=None, redoc_url=None, openapi_url=None)
@@ -252,7 +252,7 @@ async def websocket_remote_auth_request(ws: WebSocket):
252
252
 
253
253
  @app.websocket("/permit")
254
254
  @websocket_error_handler
255
- async def websocket_remote_auth_permit(ws: WebSocket):
255
+ async def websocket_remote_auth_permit(ws: WebSocket, auth=AUTH_COOKIE):
256
256
  """Complete a remote authentication request using a 3-word pairing code.
257
257
 
258
258
  This endpoint is called from the user's profile on the authenticating device.
@@ -270,7 +270,7 @@ async def websocket_remote_auth_permit(ws: WebSocket):
270
270
  7. Server sends {status: "success", message: "..."}
271
271
  """
272
272
 
273
- origin = validate_origin(ws)
273
+ validate_origin(ws)
274
274
 
275
275
  if remoteauth.instance is None:
276
276
  raise ValueError("Remote authentication is not available")
@@ -310,56 +310,30 @@ async def websocket_remote_auth_permit(ws: WebSocket):
310
310
 
311
311
  # Handle authenticate request (no PoW needed - already validated during lookup)
312
312
  if msg.get("authenticate") and request is not None:
313
- cred, new_sign_count = await authenticate_chat(ws, origin)
314
-
315
- # Create a session for the REQUESTING device
316
- assert cred.uuid is not None
313
+ ctx = await authenticate_and_login(ws, auth)
317
314
 
318
- session_token = None
315
+ session_token = ctx.session.key
319
316
  reset_token = None
320
317
 
321
318
  if request.action == "register":
322
319
  # For registration, create a reset token for device addition
323
-
324
320
  token_str = passphrase.generate()
325
321
  expiry = expires()
326
322
  db.create_reset_token(
327
- user_uuid=cred.user,
323
+ user_uuid=ctx.user.uuid,
328
324
  passphrase=token_str,
329
325
  expiry=expiry,
330
326
  token_type="device addition",
327
+ user=str(ctx.user.uuid),
331
328
  )
332
329
  reset_token = token_str
333
- # Also create a session so the device is logged in
334
- normalized_host = hostutil.normalize_host(request.host)
335
- session_token = db.login(
336
- user_uuid=cred.user,
337
- credential_uuid=cred.uuid,
338
- sign_count=new_sign_count,
339
- host=normalized_host,
340
- ip=request.ip,
341
- user_agent=request.user_agent,
342
- expiry=expires(),
343
- )
344
- else:
345
- # Default login action
346
-
347
- normalized_host = hostutil.normalize_host(request.host)
348
- session_token = db.login(
349
- user_uuid=cred.user,
350
- credential_uuid=cred.uuid,
351
- sign_count=new_sign_count,
352
- host=normalized_host,
353
- ip=request.ip,
354
- user_agent=request.user_agent,
355
- expiry=expires(),
356
- )
357
330
 
358
331
  # Complete the remote auth request (notifies the waiting device)
332
+ cred = db.data().credentials[ctx.session.credential_uuid]
359
333
  completed = await remoteauth.instance.complete_request(
360
334
  token=request.key,
361
335
  session_token=session_token,
362
- user_uuid=cred.user,
336
+ user_uuid=ctx.user.uuid,
363
337
  credential_uuid=cred.uuid,
364
338
  reset_token=reset_token,
365
339
  )
paskia/fastapi/reset.py CHANGED
@@ -10,8 +10,6 @@ display name. If multiple users match, they are listed and the command
10
10
  aborts. A new one-time reset link is always created.
11
11
  """
12
12
 
13
- from __future__ import annotations
14
-
15
13
  import asyncio
16
14
  from uuid import UUID
17
15
 
@@ -0,0 +1,22 @@
1
+ """FastAPI response utilities for msgspec.Struct serialization."""
2
+
3
+ import msgspec
4
+ from fastapi import Response
5
+
6
+
7
class MsgspecResponse(Response):
    """Response that serializes its content with msgspec's JSON encoder.

    Use this for returning msgspec.Struct, dict, or list with proper
    serialization. The media type defaults to ``application/json`` and can
    be overridden per response (e.g. ``application/problem+json``).
    """

    media_type = "application/json"

    def __init__(
        self,
        content: msgspec.Struct | dict | list,
        status_code: int = 200,
        headers: dict | None = None,
        media_type: str | None = None,
        background=None,  # optional starlette BackgroundTask run after sending
    ):
        # msgspec produces bytes, which the base Response uses as the body.
        body = msgspec.json.encode(content)
        super().__init__(
            content=body,
            status_code=status_code,
            headers=headers,
            # None keeps the class-level default (application/json)
            media_type=media_type,
            background=background,
        )
paskia/fastapi/user.py CHANGED
@@ -1,4 +1,4 @@
1
- from datetime import timezone
1
+ from datetime import UTC
2
2
  from uuid import UUID
3
3
 
4
4
  from fastapi import (
@@ -43,7 +43,7 @@ async def user_update_display_name(
43
43
  status_code=401, detail="Authentication Required", mode="login"
44
44
  )
45
45
  host = request.headers.get("host")
46
- ctx = db.get_session_context(auth, host)
46
+ ctx = db.data().session_ctx(auth, host)
47
47
  if not ctx:
48
48
  raise authz.AuthException(
49
49
  status_code=401, detail="Session expired", mode="login"
@@ -62,7 +62,7 @@ async def api_logout_all(request: Request, response: Response, auth=AUTH_COOKIE)
62
62
  if not auth:
63
63
  return {"message": "Already logged out"}
64
64
  host = request.headers.get("host")
65
- ctx = db.get_session_context(auth, host)
65
+ ctx = db.data().session_ctx(auth, host)
66
66
  if not ctx:
67
67
  raise authz.AuthException(
68
68
  status_code=401, detail="Session expired", mode="login"
@@ -84,14 +84,14 @@ async def api_delete_session(
84
84
  status_code=401, detail="Authentication Required", mode="login"
85
85
  )
86
86
  host = request.headers.get("host")
87
- ctx = db.get_session_context(auth, host)
87
+ ctx = db.data().session_ctx(auth, host)
88
88
  if not ctx:
89
89
  raise authz.AuthException(
90
90
  status_code=401, detail="Session expired", mode="login"
91
91
  )
92
92
 
93
93
  target_session = db.data().sessions.get(session_id)
94
- if not target_session or target_session.user != ctx.user.uuid:
94
+ if not target_session or target_session.user_uuid != ctx.user.uuid:
95
95
  raise HTTPException(status_code=404, detail="Session not found")
96
96
 
97
97
  db.delete_session(session_id, ctx=ctx)
@@ -141,8 +141,8 @@ async def api_create_link(
141
141
  "message": "Registration link generated successfully",
142
142
  "url": url,
143
143
  "expires": (
144
- expiry.astimezone(timezone.utc).isoformat().replace("+00:00", "Z")
144
+ expiry.astimezone(UTC).isoformat().replace("+00:00", "Z")
145
145
  if expiry.tzinfo
146
- else expiry.replace(tzinfo=timezone.utc).isoformat().replace("+00:00", "Z")
146
+ else expiry.replace(tzinfo=UTC).isoformat().replace("+00:00", "Z")
147
147
  ),
148
148
  }
paskia/fastapi/ws.py CHANGED
@@ -1,13 +1,13 @@
1
1
  from fastapi import FastAPI, WebSocket
2
2
 
3
3
  from paskia import db
4
- from paskia.authsession import expires, get_reset
4
+ from paskia.authsession import get_reset
5
5
  from paskia.fastapi import authz, remote
6
6
  from paskia.fastapi.session import AUTH_COOKIE, infodict
7
- from paskia.fastapi.wschat import authenticate_chat, register_chat
7
+ from paskia.fastapi.wschat import authenticate_and_login, register_chat
8
8
  from paskia.fastapi.wsutil import validate_origin, websocket_error_handler
9
9
  from paskia.globals import passkey
10
- from paskia.util import hostutil, passphrase
10
+ from paskia.util import passphrase
11
11
 
12
12
  # Create a FastAPI subapp for WebSocket endpoints
13
13
  app = FastAPI(docs_url=None, redoc_url=None, openapi_url=None)
@@ -38,15 +38,15 @@ async def websocket_register_add(
38
38
  f"The reset link for {passkey.instance.rp_name} is invalid or has expired"
39
39
  )
40
40
  s = get_reset(reset)
41
- user_uuid = s.user
41
+ user_uuid = s.user_uuid
42
42
  else:
43
43
  # Require recent authentication for adding a new passkey
44
44
  ctx = await authz.verify(auth, perm=[], host=host, max_age="5m")
45
- user_uuid = ctx.session.user
45
+ user_uuid = ctx.session.user_uuid
46
46
  s = ctx.session
47
47
 
48
48
  # Get user information and determine effective user_name for this registration
49
- user = db.data().users.get(user_uuid)
49
+ user = db.data().users[user_uuid]
50
50
  user_name = user.display_name
51
51
  if name is not None:
52
52
  stripped = name.strip()
@@ -59,7 +59,7 @@ async def websocket_register_add(
59
59
 
60
60
  # Create a new session and store everything in database
61
61
  metadata = infodict(ws, "authenticated")
62
- token = db.create_credential_session( # type: ignore[attr-defined]
62
+ token = db.create_credential_session(
63
63
  user_uuid=user_uuid,
64
64
  credential=credential,
65
65
  reset_key=(s.key if reset is not None else None),
@@ -89,43 +89,20 @@ async def websocket_authenticate(ws: WebSocket, auth=AUTH_COOKIE):
89
89
 
90
90
  # If there's an existing session, restrict to that user's credentials (reauth)
91
91
  session_user_uuid = None
92
- credential_ids = None
93
92
  if auth:
94
- ctx = db.get_session_context(auth, host)
95
- if ctx:
96
- session_user_uuid = ctx.user.uuid
97
- credential_ids = db.get_user_credential_ids(session_user_uuid) or None
93
+ existing_ctx = db.data().session_ctx(auth, host)
94
+ if existing_ctx:
95
+ session_user_uuid = existing_ctx.user.uuid
98
96
 
99
- cred, new_sign_count = await authenticate_chat(ws, origin, credential_ids)
97
+ ctx = await authenticate_and_login(ws, auth)
100
98
 
101
99
  # If reauth mode, verify the credential belongs to the session's user
102
- if session_user_uuid and cred.user != session_user_uuid:
100
+ if session_user_uuid and ctx.user.uuid != session_user_uuid:
103
101
  raise ValueError("This passkey belongs to a different account")
104
102
 
105
- # Create session and update user/credential in a single transaction
106
- assert cred.uuid is not None
107
- metadata = infodict(ws, "auth")
108
- normalized_host = hostutil.normalize_host(host)
109
- if not normalized_host:
110
- raise ValueError("Host required for session creation")
111
- hostname = normalized_host.split(":")[0]
112
- rp_id = passkey.instance.rp_id
113
- if not (hostname == rp_id or hostname.endswith(f".{rp_id}")):
114
- raise ValueError(f"Host must be the same as or a subdomain of {rp_id}")
115
-
116
- token = db.login(
117
- user_uuid=cred.user,
118
- credential_uuid=cred.uuid,
119
- sign_count=new_sign_count,
120
- host=normalized_host,
121
- ip=metadata["ip"],
122
- user_agent=metadata["user_agent"],
123
- expiry=expires(),
124
- )
125
-
126
103
  await ws.send_json(
127
104
  {
128
- "user": str(cred.user),
129
- "session_token": token,
105
+ "user": str(ctx.user.uuid),
106
+ "session_token": ctx.session.key,
130
107
  }
131
108
  )
paskia/fastapi/wschat.py CHANGED
@@ -7,8 +7,12 @@ from uuid import UUID
7
7
  from fastapi import WebSocket
8
8
 
9
9
  from paskia import db
10
- from paskia.db import Credential
10
+ from paskia.authsession import expires
11
+ from paskia.db import Credential, SessionContext
12
+ from paskia.fastapi.session import infodict
13
+ from paskia.fastapi.wsutil import validate_origin
11
14
  from paskia.globals import passkey
15
+ from paskia.util import hostutil
12
16
 
13
17
 
14
18
  async def register_chat(
@@ -31,7 +35,6 @@ async def register_chat(
31
35
 
32
36
  async def authenticate_chat(
33
37
  ws: WebSocket,
34
- origin: str,
35
38
  credential_ids: list[bytes] | None = None,
36
39
  ) -> tuple[Credential, int]:
37
40
  """Run WebAuthn authentication flow and return the credential and new sign count.
@@ -39,6 +42,7 @@ async def authenticate_chat(
39
42
  Returns:
40
43
  tuple of (credential, new_sign_count) where new_sign_count comes from WebAuthn verification
41
44
  """
45
+ origin = validate_origin(ws)
42
46
  options, challenge = passkey.instance.auth_generate_options(
43
47
  credential_ids=credential_ids
44
48
  )
@@ -60,3 +64,52 @@ async def authenticate_chat(
60
64
 
61
65
  verification = passkey.instance.auth_verify(authcred, challenge, cred, origin)
62
66
  return cred, verification.new_sign_count
67
+
68
+
69
async def authenticate_and_login(
    ws: WebSocket,
    auth: str | None = None,
) -> SessionContext:
    """Run WebAuthn authentication, create a session, and return its context.

    The session host comes from the validated Origin of ``ws``. It must be
    the relying party ID or a subdomain of it.

    If ``auth`` resolves to an existing session, authentication is
    restricted to that session's user:
    - The options offer only that user's credentials.
    - The restriction is also enforced server-side, because the options are
      merely a hint to the client. A passkey belonging to another user is
      rejected before any new session is created.

    Returns:
        SessionContext for the newly created session

    Raises:
        ValueError: if the host is missing or not within the RP ID, if the
            passkey belongs to a different account than ``auth``'s session,
            or if the new session's context cannot be loaded.
    """
    origin = validate_origin(ws)
    host = origin.split("://", 1)[1]
    normalized_host = hostutil.normalize_host(host)
    if not normalized_host:
        raise ValueError("Host required for session creation")
    hostname = normalized_host.split(":")[0]
    rp_id = passkey.instance.rp_id
    if not (hostname == rp_id or hostname.endswith(f".{rp_id}")):
        raise ValueError(f"Host must be the same as or a subdomain of {rp_id}")
    metadata = infodict(ws, "auth")

    # Get credential IDs if restricting to a user's credentials
    existing_ctx = None
    credential_ids = None
    if auth:
        existing_ctx = db.data().session_ctx(auth, host)
        if existing_ctx:
            credential_ids = db.get_user_credential_ids(existing_ctx.user.uuid) or None

    cred, new_sign_count = await authenticate_chat(ws, credential_ids)

    # A client may answer with a credential outside the offered list, so
    # enforce the restriction here, before a session is created for it.
    if existing_ctx and cred.user_uuid != existing_ctx.user.uuid:
        raise ValueError("This passkey belongs to a different account")

    # Create session and update user/credential
    token = db.login(
        user_uuid=cred.user_uuid,
        credential_uuid=cred.uuid,
        sign_count=new_sign_count,
        host=normalized_host,
        ip=metadata["ip"],
        user_agent=metadata["user_agent"],
        expiry=expires(),
    )

    # Fetch and return the full session context
    ctx = db.data().session_ctx(token, normalized_host)
    if not ctx:
        raise ValueError("Failed to create session context")
    return ctx