flyfun-common 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flyfun_common/__init__.py +1 -0
- flyfun_common/admin.py +92 -0
- flyfun_common/auth/__init__.py +26 -0
- flyfun_common/auth/config.py +120 -0
- flyfun_common/auth/jwt_utils.py +28 -0
- flyfun_common/auth/router.py +372 -0
- flyfun_common/costs.py +79 -0
- flyfun_common/credentials.py +46 -0
- flyfun_common/db/__init__.py +44 -0
- flyfun_common/db/deps.py +127 -0
- flyfun_common/db/engine.py +100 -0
- flyfun_common/db/models.py +83 -0
- flyfun_common/encryption.py +43 -0
- flyfun_common-0.1.0.dist-info/METADATA +80 -0
- flyfun_common-0.1.0.dist-info/RECORD +17 -0
- flyfun_common-0.1.0.dist-info/WHEEL +4 -0
- flyfun_common-0.1.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""Shared user management and auth for flyfun services."""
|
flyfun_common/admin.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
1
|
+
"""Admin helper utilities: token generation, HMAC verification, user management."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import hashlib
|
|
6
|
+
import hmac as hmac_mod
|
|
7
|
+
import secrets
|
|
8
|
+
import time
|
|
9
|
+
import uuid
|
|
10
|
+
from base64 import urlsafe_b64encode
|
|
11
|
+
|
|
12
|
+
from fastapi import HTTPException
|
|
13
|
+
from sqlalchemy.orm import Session
|
|
14
|
+
|
|
15
|
+
from flyfun_common.db.models import ApiTokenRow, UserPreferencesRow, UserRow
|
|
16
|
+
|
|
17
|
+
TOKEN_PREFIX = "ff_"
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def generate_api_token(prefix: str = TOKEN_PREFIX) -> str:
    """Return a fresh random API token that starts with ``prefix``.

    The random part is 32 bytes from the ``secrets`` CSPRNG encoded as
    URL-safe base64 with the ``=`` padding removed (43 characters), so a
    token using the default three-character prefix is 46 characters long.
    """
    random_bytes = secrets.token_bytes(32)
    encoded = urlsafe_b64encode(random_bytes).decode()
    return prefix + encoded.rstrip("=")
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def hash_token(token: str) -> str:
    """Return the hex-encoded SHA-256 digest of ``token``.

    The token is UTF-8 encoded before hashing. Only this digest is meant to
    be stored; the plaintext token is not recoverable from it.
    """
    digest = hashlib.sha256()
    digest.update(token.encode())
    return digest.hexdigest()
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def verify_approval_hmac(
    user_id: str, ts: str, sig: str, secret: str, expiry: int
) -> None:
    """Verify an HMAC-signed approval link.

    The expected signature is the hex HMAC-SHA256 of ``approve:{user_id}:{ts}``
    keyed with ``secret``. The signature is checked before the timestamp is
    parsed, so ``ts`` is only interpreted once it is known to be authentic.

    Args:
        user_id: ID of the user the link approves.
        ts: Unix timestamp (seconds) embedded in the link, as a string.
        sig: Hex signature taken from the link.
        secret: Shared signing secret.
        expiry: Maximum accepted link age in seconds.

    Raises:
        HTTPException: 403 when the signature does not match, 400 when the
            timestamp is not an integer or lies in the future, 410 when the
            link is older than ``expiry``.
    """
    expected = hmac_mod.new(
        secret.encode(), f"approve:{user_id}:{ts}".encode(), hashlib.sha256
    ).hexdigest()

    # Compare bytes, not str: compare_digest() raises TypeError for str
    # arguments containing non-ASCII characters, which would turn a forged
    # link into an unhandled error instead of a clean 403.
    if not hmac_mod.compare_digest(sig.encode(), expected.encode()):
        raise HTTPException(status_code=403, detail="Invalid approval link")

    try:
        link_time = int(ts)
    except ValueError:
        raise HTTPException(status_code=400, detail="Invalid timestamp") from None

    age = time.time() - link_time
    if age > expiry:
        raise HTTPException(status_code=410, detail="Approval link expired")
    if age < 0:
        # A timestamp from the future is never valid for an issued link.
        raise HTTPException(status_code=400, detail="Invalid timestamp")
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def create_agent_user(
    db: Session, name: str, prefix: str = TOKEN_PREFIX
) -> tuple[UserRow, str]:
    """Create a pre-approved bot/agent user together with its first API token.

    The user gets an ``agent-<12 hex chars>`` id, the ``api_token`` provider,
    an empty email and a default preferences row. Only the SHA-256 hash of
    the token is stored.

    Returns:
        ``(user, plaintext_token)``. The plaintext is not persisted, so it
        cannot be retrieved later.
    """
    agent_id = "agent-" + uuid.uuid4().hex[:12]
    agent = UserRow(
        id=agent_id,
        provider="api_token",
        provider_sub=uuid.uuid4().hex,
        email="",
        display_name=name,
        approved=True,
    )
    db.add(agent)
    # Presumably flushed here so the user row is written before the rows
    # below that carry its id.
    db.flush()
    db.add(UserPreferencesRow(user_id=agent_id))

    token = generate_api_token(prefix)
    token_row = ApiTokenRow(
        user_id=agent_id,
        token_hash=hash_token(token),
        name=name,
    )
    db.add(token_row)
    db.flush()
    return agent, token
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def approve_user(db: Session, user_id: str) -> UserRow:
    """Mark the user with ``user_id`` as approved and return the row.

    The change is flushed to the session but not committed here.

    Raises:
        HTTPException: 404 when no user has that id.
    """
    matches = db.query(UserRow).filter(UserRow.id == user_id)
    account = matches.first()
    if not account:
        raise HTTPException(status_code=404, detail="User not found")

    account.approved = True
    db.flush()
    return account
|
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
"""Auth utilities: JWT, OAuth, config."""
|
|
2
|
+
|
|
3
|
+
from flyfun_common.auth.config import (
|
|
4
|
+
COOKIE_NAME,
|
|
5
|
+
COOKIE_DOMAIN,
|
|
6
|
+
SUPPORTED_PROVIDERS,
|
|
7
|
+
is_dev_mode,
|
|
8
|
+
get_jwt_secret,
|
|
9
|
+
get_registered_providers,
|
|
10
|
+
create_oauth,
|
|
11
|
+
)
|
|
12
|
+
from flyfun_common.auth.jwt_utils import create_token, decode_token
|
|
13
|
+
from flyfun_common.auth.router import create_auth_router
|
|
14
|
+
|
|
15
|
+
__all__ = [
|
|
16
|
+
"COOKIE_NAME",
|
|
17
|
+
"COOKIE_DOMAIN",
|
|
18
|
+
"SUPPORTED_PROVIDERS",
|
|
19
|
+
"is_dev_mode",
|
|
20
|
+
"get_jwt_secret",
|
|
21
|
+
"get_registered_providers",
|
|
22
|
+
"create_oauth",
|
|
23
|
+
"create_token",
|
|
24
|
+
"decode_token",
|
|
25
|
+
"create_auth_router",
|
|
26
|
+
]
|
|
@@ -0,0 +1,120 @@
|
|
|
1
|
+
"""Shared auth configuration: cookie, JWT secret, OAuth providers."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import os
|
|
6
|
+
|
|
7
|
+
from authlib.integrations.starlette_client import OAuth
|
|
8
|
+
|
|
9
|
+
# Unified cookie name — same across all flyfun services
COOKIE_NAME = "flyfun_auth"

# Cookie domain placeholder. NOTE: nothing in this module assigns it, so it
# stays None; the effective domain (.flyfun.aero by default in production,
# for cross-subdomain SSO) is returned by get_cookie_domain().
COOKIE_DOMAIN: str | None = None  # not computed here; see get_cookie_domain()

# Fallback signing secret returned by get_jwt_secret() in dev mode only;
# get_jwt_secret() rejects this value when ENVIRONMENT is "production".
_DEV_JWT_SECRET = "dev-insecure-jwt-secret-do-not-use-in-production"

# Providers that can be registered
SUPPORTED_PROVIDERS = ("google", "apple")
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def is_dev_mode() -> bool:
    """Return True unless ``ENVIRONMENT`` is exactly ``"production"``.

    An unset ``ENVIRONMENT`` is treated as ``"development"``.
    """
    environment = os.environ.get("ENVIRONMENT", "development")
    if environment == "production":
        return False
    return True
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def get_cookie_domain() -> str | None:
    """Return the domain to set on the auth cookie.

    In dev mode this is None. Otherwise it is the ``COOKIE_DOMAIN``
    environment variable, defaulting to ``.flyfun.aero`` (enables SSO
    across subdomains).
    """
    return None if is_dev_mode() else os.environ.get("COOKIE_DOMAIN", ".flyfun.aero")
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def get_jwt_secret() -> str:
    """Return the secret used to sign and verify flyfun JWTs.

    ``JWT_SECRET`` from the environment wins when set. In dev mode a missing
    value falls back to the built-in insecure dev secret.

    Raises:
        ValueError: outside dev mode, when ``JWT_SECRET`` is unset/empty or
            equals the built-in dev secret.
    """
    configured = os.environ.get("JWT_SECRET")

    if not configured:
        if is_dev_mode():
            return _DEV_JWT_SECRET
        raise ValueError("JWT_SECRET environment variable must be set in production")

    if configured == _DEV_JWT_SECRET and not is_dev_mode():
        raise ValueError(
            "Production must use a unique JWT_SECRET, not the dev default"
        )
    return configured
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def _apple_client_secret() -> str:
    """Build the JWT that Sign in with Apple expects as ``client_secret``.

    The token is signed with ES256 using ``APPLE_PRIVATE_KEY`` and carries:
    - iss: ``APPLE_TEAM_ID``
    - sub: ``APPLE_CLIENT_ID`` (the Service ID)
    - aud: https://appleid.apple.com
    - iat/exp: now and now + 180 days
    The key id ``APPLE_KEY_ID`` goes into the ``kid`` header.

    Raises:
        ValueError: when any of the four ``APPLE_*`` variables is missing
            or empty.
    """
    import time

    import jwt  # PyJWT

    env = os.environ
    team_id = env.get("APPLE_TEAM_ID", "")
    key_id = env.get("APPLE_KEY_ID", "")
    client_id = env.get("APPLE_CLIENT_ID", "")
    private_key = env.get("APPLE_PRIVATE_KEY", "")

    required = (team_id, key_id, client_id, private_key)
    if any(not value for value in required):
        raise ValueError(
            "Apple Sign In requires APPLE_TEAM_ID, APPLE_KEY_ID, "
            "APPLE_CLIENT_ID, and APPLE_PRIVATE_KEY environment variables"
        )

    # The key is stored with literal "\n" sequences; turn them back into
    # real newlines so the PEM parses.
    pem = private_key.replace("\\n", "\n")

    issued_at = int(time.time())
    claims = {
        "iss": team_id,
        "sub": client_id,
        "aud": "https://appleid.apple.com",
        "iat": issued_at,
        "exp": issued_at + 86400 * 180,  # 180 days, within Apple's six-month cap
    }
    return jwt.encode(claims, pem, algorithm="ES256", headers={"kid": key_id})
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def create_oauth() -> OAuth:
    """Build the authlib OAuth registry from environment configuration.

    A provider is registered only when its ``*_CLIENT_ID`` variable is set:
    Google via ``GOOGLE_CLIENT_ID``/``GOOGLE_CLIENT_SECRET`` and Apple via
    ``APPLE_CLIENT_ID`` plus the variables needed by
    ``_apple_client_secret()``, whose ValueError propagates if they are
    missing.
    """
    registry = OAuth()

    # Google
    google_client_id = os.environ.get("GOOGLE_CLIENT_ID")
    if google_client_id:
        registry.register(
            name="google",
            client_id=google_client_id,
            client_secret=os.environ.get("GOOGLE_CLIENT_SECRET", ""),
            server_metadata_url="https://accounts.google.com/.well-known/openid-configuration",
            client_kwargs={"scope": "openid email profile"},
        )

    # Apple — Sign in with Apple (OIDC); the callback arrives as a form POST
    apple_client_id = os.environ.get("APPLE_CLIENT_ID")
    if apple_client_id:
        apple_options = {
            "scope": "openid email name",
            "response_mode": "form_post",
        }
        registry.register(
            name="apple",
            client_id=apple_client_id,
            client_secret=_apple_client_secret(),
            server_metadata_url="https://appleid.apple.com/.well-known/openid-configuration",
            client_kwargs=apple_options,
        )

    return registry
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
def get_registered_providers(oauth: OAuth) -> list[str]:
    """Return the supported provider names that exist on ``oauth``.

    The result keeps the order of ``SUPPORTED_PROVIDERS``.
    """
    registered: list[str] = []
    for provider in SUPPORTED_PROVIDERS:
        if hasattr(oauth, provider):
            registered.append(provider)
    return registered
|
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
"""JWT token creation and validation."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from datetime import datetime, timedelta, timezone
|
|
6
|
+
|
|
7
|
+
import jwt
|
|
8
|
+
|
|
9
|
+
JWT_ALGORITHM = "HS256"
|
|
10
|
+
JWT_EXPIRY_DAYS = 7
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def create_token(user_id: str, email: str, name: str, secret: str) -> str:
    """Return a signed JWT carrying the user's id, email and name.

    The token is issued at the current UTC time, expires
    ``JWT_EXPIRY_DAYS`` days later, and is signed with ``secret`` using
    ``JWT_ALGORITHM``.
    """
    issued_at = datetime.now(timezone.utc)
    expires_at = issued_at + timedelta(days=JWT_EXPIRY_DAYS)
    claims = {
        "sub": user_id,
        "email": email,
        "name": name,
        "iat": issued_at,
        "exp": expires_at,
    }
    return jwt.encode(claims, secret, algorithm=JWT_ALGORITHM)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def decode_token(token: str, secret: str) -> dict:
    """Verify ``token`` against ``secret`` and return its claims.

    Only ``JWT_ALGORITHM`` is accepted. Any exception raised by
    ``jwt.decode`` propagates to the caller.
    """
    accepted_algorithms = [JWT_ALGORITHM]
    claims = jwt.decode(token, secret, algorithms=accepted_algorithms)
    return claims
|
|
@@ -0,0 +1,372 @@
|
|
|
1
|
+
"""Shared OAuth auth router: login, callback, logout, /auth/me.
|
|
2
|
+
|
|
3
|
+
Each app calls create_auth_router() and mounts it on their FastAPI app.
|
|
4
|
+
The callback creates/updates users in the shared DB and sets the
|
|
5
|
+
cross-subdomain JWT cookie.
|
|
6
|
+
|
|
7
|
+
Supports multiple OAuth providers (Google, Apple, etc.) via generic
|
|
8
|
+
/{provider} routes. Also supports native iOS Sign in with Apple via
|
|
9
|
+
POST /auth/apple/token (identity token validation).
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
from __future__ import annotations
|
|
13
|
+
|
|
14
|
+
import json
|
|
15
|
+
import logging
|
|
16
|
+
import os
|
|
17
|
+
import re
|
|
18
|
+
import uuid
|
|
19
|
+
from datetime import datetime, timezone
|
|
20
|
+
from urllib.parse import quote
|
|
21
|
+
|
|
22
|
+
import jwt as pyjwt
|
|
23
|
+
from fastapi import APIRouter, Depends, HTTPException, Request
|
|
24
|
+
from fastapi.responses import RedirectResponse
|
|
25
|
+
from pydantic import BaseModel
|
|
26
|
+
from sqlalchemy.orm import Session
|
|
27
|
+
|
|
28
|
+
from flyfun_common.auth.config import (
|
|
29
|
+
COOKIE_NAME,
|
|
30
|
+
SUPPORTED_PROVIDERS,
|
|
31
|
+
create_oauth,
|
|
32
|
+
get_cookie_domain,
|
|
33
|
+
get_jwt_secret,
|
|
34
|
+
is_dev_mode,
|
|
35
|
+
)
|
|
36
|
+
from flyfun_common.auth.jwt_utils import create_token
|
|
37
|
+
from flyfun_common.db.deps import current_user_id, get_db
|
|
38
|
+
from flyfun_common.db.models import UserRow
|
|
39
|
+
|
|
40
|
+
logger = logging.getLogger(__name__)
|
|
41
|
+
|
|
42
|
+
# Apple's JWKS endpoint for verifying identity tokens
|
|
43
|
+
_APPLE_JWKS_URL = "https://appleid.apple.com/auth/keys"
|
|
44
|
+
_apple_jwks_client: pyjwt.PyJWKClient | None = None
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def _get_apple_jwks_client() -> pyjwt.PyJWKClient:
    """Return the module-wide JWKS client for Apple's public keys.

    The client is created on first use (with key caching enabled) and the
    same instance is returned afterwards.
    """
    global _apple_jwks_client
    client = _apple_jwks_client
    if client is None:
        client = pyjwt.PyJWKClient(_APPLE_JWKS_URL, cache_keys=True)
        _apple_jwks_client = client
    return client
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def _extract_userinfo(provider: str, token: dict) -> tuple[str, str, str]:
    """Extract (sub, email, display_name) from an OAuth token response.

    Each provider returns user data differently:
    - Google (and any other non-Apple provider): ``token["userinfo"]`` must
      be present and contain ``sub``; ``email`` and ``name`` default to "".
    - Apple: ``sub``/``email`` come from ``token["userinfo"]`` (both default
      to ""); the name comes from the optional ``token["_apple_user"]``
      payload and falls back to the email.

    The ``_apple_user`` payload is parsed from a client-supplied form field,
    so it is treated as untrusted: anything that does not have the expected
    ``{"name": {"firstName": str, "lastName": str}}`` shape is ignored
    instead of raising.

    Raises:
        ValueError: non-Apple provider with no ``userinfo`` in ``token``.
        KeyError: non-Apple ``userinfo`` without a ``sub`` entry.
    """
    if provider == "apple":
        # Apple puts claims in the id_token (parsed by authlib into userinfo)
        userinfo = token.get("userinfo") or {}
        sub = userinfo.get("sub", "")
        email = userinfo.get("email", "")

        # Apple only sends the user's name on first authorization, as a JSON
        # blob in the POST body 'user' parameter; the callback stores the
        # decoded value under "_apple_user". It may be any JSON value.
        user_data = token.get("_apple_user")
        if not isinstance(user_data, dict):
            user_data = {}
        name_parts = user_data.get("name")
        if not isinstance(name_parts, dict):
            name_parts = {}

        # Non-string values (e.g. JSON null) would otherwise be rendered
        # literally ("None") by the f-string below.
        first, last = (
            value if isinstance(value, str) else ""
            for value in (name_parts.get("firstName"), name_parts.get("lastName"))
        )
        name = f"{first} {last}".strip() or email
        return sub, email, name

    # Default: Google and other standard OIDC providers
    userinfo = token.get("userinfo")
    if not userinfo:
        raise ValueError(f"No userinfo in token from {provider}")
    return userinfo["sub"], userinfo.get("email", ""), userinfo.get("name", "")
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def _find_or_create_user(
    db: Session,
    provider: str,
    provider_sub: str,
    email: str,
    name: str,
    on_new_user: callable | None = None,
    request: Request | None = None,
) -> UserRow:
    """Return the user for ``(provider, provider_sub)``, creating it if needed.

    A new user is created approved, flushed, and logged. When both
    ``on_new_user`` and ``request`` are given, the callback is invoked as
    ``on_new_user(user, request, db)``; any exception it raises is logged
    as a warning and otherwise ignored.

    For every call the last-login time is set to now (UTC), the stored email
    is replaced when a different non-empty one arrives, and the display name
    is replaced when a different non-empty name arrives that is not just
    the email.
    """
    lookup = db.query(UserRow).filter_by(
        provider=provider, provider_sub=provider_sub
    )
    account = lookup.first()

    if account is None:
        account = UserRow(
            id=str(uuid.uuid4()),
            provider=provider,
            provider_sub=provider_sub,
            email=email,
            display_name=name,
            approved=True,
        )
        db.add(account)
        db.flush()
        logger.info("New user created via %s: %s", provider, account.id)

        if on_new_user and request:
            # Best-effort hook: a failing callback must not block the login.
            try:
                on_new_user(account, request, db)
            except Exception:
                logger.warning(
                    "on_new_user callback failed for %s", account.id, exc_info=True
                )

    account.last_login_at = datetime.now(timezone.utc)

    email_changed = bool(email) and account.email != email
    if email_changed:
        account.email = email

    # Only update name if we got a real name (not just the email echoed back)
    has_real_name = bool(name) and name != email
    if has_real_name and account.display_name != name:
        account.display_name = name

    db.flush()
    return account
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
class AppleTokenRequest(BaseModel):
    """Request body for native iOS Sign in with Apple."""

    # Identity token (JWT) obtained by the app; the /auth/apple/token handler
    # verifies its signature, audience and issuer before trusting its claims.
    identity_token: str
    # Optional name parts. The handler joins whichever are non-empty into the
    # display name and falls back to the token's email claim when both are
    # missing.
    first_name: str | None = None
    last_name: str | None = None
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
def create_auth_router(
    on_new_user: callable | None = None,
) -> APIRouter:
    """Create an auth router.

    The router is mounted under ``/auth`` and exposes:
    - GET  /providers            — names of the configured OAuth providers
    - GET  /login/{provider}     — start the OAuth redirect flow
    - GET/POST /callback/{provider} — finish the flow, set the session cookie
    - POST /apple/token          — native iOS Sign in with Apple
    - POST /logout               — clear the session cookie
    - GET  /me                   — current user's profile

    The OAuth registry is built once, when this function is called.

    Args:
        on_new_user: Optional callback(user: UserRow, request: Request, db: Session)
            called after a new user is created (e.g. send welcome email).
    """
    router = APIRouter(prefix="/auth", tags=["auth"])
    oauth = create_oauth()

    def _get_oauth_client(provider: str):
        """Get a registered OAuth client, or raise 404."""
        client = getattr(oauth, provider, None)
        if client is None:
            raise HTTPException(
                status_code=404,
                detail=f"Auth provider '{provider}' is not configured",
            )
        return client

    @router.get("/providers")
    async def list_providers():
        """Return the list of configured OAuth providers."""
        from flyfun_common.auth.config import get_registered_providers

        return {"providers": get_registered_providers(oauth)}

    @router.get("/login/{provider}")
    async def login(
        provider: str,
        request: Request,
        platform: str | None = None,
        scheme: str | None = None,
    ):
        """Redirect the browser to the provider's authorization page.

        ``platform`` and ``scheme`` are remembered in the session so the
        callback can redirect native apps to their custom URL scheme.
        """
        if provider not in SUPPORTED_PROVIDERS:
            raise HTTPException(status_code=404, detail=f"Unknown provider: {provider}")
        client = _get_oauth_client(provider)
        redirect_uri = request.url_for("callback", provider=provider)
        if not is_dev_mode():
            # Outside dev mode the callback URL is always advertised as https.
            redirect_uri = str(redirect_uri).replace("http://", "https://")
        if platform:
            request.session["oauth_platform"] = platform
        if scheme:
            # The scheme later becomes a redirect target, so only allow
            # flyfun-prefixed custom schemes.
            if not re.fullmatch(r"flyfun[a-z0-9\-]*", scheme):
                raise HTTPException(
                    status_code=400,
                    detail="Invalid URL scheme",
                )
            request.session["oauth_scheme"] = scheme
        return await client.authorize_redirect(request, redirect_uri)

    @router.get("/callback/{provider}")
    @router.post("/callback/{provider}")  # Apple uses form_post (POST)
    async def callback(
        provider: str, request: Request, db: Session = Depends(get_db)
    ):
        """Complete the OAuth flow and sign the user in.

        Web clients are redirected to ``/`` with the session cookie set;
        iOS clients are redirected to ``<scheme>://auth/callback?token=...``.
        """
        if provider not in SUPPORTED_PROVIDERS:
            raise HTTPException(status_code=404, detail=f"Unknown provider: {provider}")
        client = _get_oauth_client(provider)

        try:
            token = await client.authorize_access_token(request)
        except Exception as exc:
            logger.warning("OAuth callback failed for %s: %s", provider, exc)
            raise HTTPException(
                status_code=400, detail="OAuth authentication failed"
            )

        # Apple: extract the 'user' JSON from form body (only sent on first auth)
        if provider == "apple":
            form = await request.form()
            user_json = form.get("user")
            if user_json:
                try:
                    token["_apple_user"] = json.loads(user_json)
                except (json.JSONDecodeError, TypeError):
                    # Best effort: the name is optional, so a malformed
                    # blob just means no name is recorded.
                    pass

        try:
            provider_sub, email, name = _extract_userinfo(provider, token)
        except (ValueError, KeyError) as exc:
            logger.warning("Failed to extract userinfo from %s: %s", provider, exc)
            raise HTTPException(
                status_code=400, detail=f"No user info from {provider}"
            )

        if not provider_sub:
            raise HTTPException(
                status_code=400, detail=f"No subject identifier from {provider}"
            )

        user = _find_or_create_user(
            db, provider, provider_sub, email, name,
            on_new_user=on_new_user, request=request,
        )

        if not user.approved:
            return RedirectResponse(url="/login.html?status=pending", status_code=302)

        jwt_token = create_token(
            user.id, user.email, user.display_name, get_jwt_secret()
        )

        # iOS/native app: redirect to app-specific custom URL scheme
        platform = request.session.pop("oauth_platform", None)
        if platform == "ios":
            scheme = request.session.pop("oauth_scheme", "flyfun")
            redirect_url = f"{scheme}://auth/callback?token={quote(jwt_token)}"
            return RedirectResponse(url=redirect_url, status_code=302)

        response = RedirectResponse(url="/", status_code=302)
        _set_session_cookie(response, jwt_token)
        return response

    # --- Native iOS Sign in with Apple ---

    @router.post("/apple/token")
    async def apple_token(
        body: AppleTokenRequest,
        request: Request,
        db: Session = Depends(get_db),
    ):
        """Validate an Apple identity token from a native iOS app.

        The iOS app uses ASAuthorizationAppleIDProvider to get an identity
        token (JWT signed by Apple), then sends it here. We verify the
        signature against Apple's public keys and extract the user info.

        Returns a flyfun JWT token for the app to use in subsequent requests.

        Raises:
            HTTPException: 503 when no accepted audience is configured, 401
                when the token is expired, invalid, or its signing key
                cannot be resolved, 400 when it has no subject, 403 when the
                account is not approved.
        """
        # Build the list of accepted audiences.
        # APPLE_APP_IDS: comma-separated bundle IDs for all iOS apps
        #   e.g. "aero.flyfun.weather,aero.flyfun.customs"
        # Falls back to APPLE_APP_ID (single app) or APPLE_CLIENT_ID (web).
        app_ids_raw = os.environ.get("APPLE_APP_IDS", "")
        if app_ids_raw:
            expected_audiences = [a.strip() for a in app_ids_raw.split(",") if a.strip()]
        else:
            single = os.environ.get(
                "APPLE_APP_ID",
                os.environ.get("APPLE_CLIENT_ID", ""),
            )
            expected_audiences = [single] if single else []

        if not expected_audiences:
            raise HTTPException(
                status_code=503,
                detail="Apple Sign In is not configured on this server",
            )

        try:
            jwks_client = _get_apple_jwks_client()
            signing_key = jwks_client.get_signing_key_from_jwt(body.identity_token)

            # PyJWT accepts a list of audiences — token is valid if its
            # aud matches ANY of them.
            claims = pyjwt.decode(
                body.identity_token,
                signing_key.key,
                algorithms=["RS256"],
                audience=expected_audiences,
                issuer="https://appleid.apple.com",
            )
        except pyjwt.ExpiredSignatureError:
            raise HTTPException(status_code=401, detail="Identity token has expired")
        except pyjwt.InvalidTokenError as exc:
            logger.warning("Apple identity token validation failed: %s", exc)
            raise HTTPException(status_code=401, detail="Invalid identity token")
        except pyjwt.PyJWKClientError as exc:
            # Not an InvalidTokenError: raised when no Apple key matches the
            # token's ``kid`` or the JWKS document cannot be fetched. Without
            # this handler a client-supplied token with a bogus ``kid``
            # surfaced as an unhandled server error.
            logger.warning("Apple signing key lookup failed: %s", exc)
            raise HTTPException(status_code=401, detail="Invalid identity token")

        sub = claims.get("sub", "")
        email = claims.get("email", "")
        if not sub:
            raise HTTPException(status_code=400, detail="No subject in identity token")

        # Build display name from optional first_name/last_name
        # (only available on first iOS authorization)
        name_parts = [p for p in [body.first_name, body.last_name] if p]
        name = " ".join(name_parts) or email

        user = _find_or_create_user(
            db, "apple", sub, email, name,
            on_new_user=on_new_user, request=request,
        )

        if not user.approved:
            raise HTTPException(status_code=403, detail="Account is not approved")

        jwt_token = create_token(
            user.id, user.email, user.display_name, get_jwt_secret()
        )

        return {"token": jwt_token, "user_id": user.id}

    @router.post("/logout")
    async def logout():
        """Clear the auth cookie and redirect to the login page."""
        response = RedirectResponse(url="/login.html", status_code=302)
        response.delete_cookie(COOKIE_NAME, path="/", domain=get_cookie_domain())
        return response

    @router.get("/me")
    async def get_me(
        user_id: str = Depends(current_user_id), db: Session = Depends(get_db)
    ):
        """Return id, email, name and approval state of the current user."""
        user = db.get(UserRow, user_id)
        if not user:
            raise HTTPException(status_code=401, detail="User not found")
        return {
            "id": user.id,
            "email": user.email,
            "name": user.display_name,
            "approved": user.approved,
        }

    return router
|
|
359
|
+
|
|
360
|
+
|
|
361
|
+
def _set_session_cookie(response: RedirectResponse, token: str) -> None:
    """Attach the auth JWT to ``response`` as the session cookie.

    The cookie is HTTP-only, SameSite=Lax, valid for seven days on path
    ``/``, scoped to ``get_cookie_domain()``, and marked Secure whenever
    the service is not in dev mode.
    """
    cookie_settings = {
        "key": COOKIE_NAME,
        "value": token,
        "httponly": True,
        "samesite": "lax",
        "secure": not is_dev_mode(),
        "path": "/",
        "domain": get_cookie_domain(),
        "max_age": 7 * 24 * 3600,
    }
    response.set_cookie(**cookie_settings)
|
flyfun_common/costs.py
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
1
|
+
"""Cost ledger utilities: record, query, and check budget."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import json
|
|
6
|
+
from datetime import datetime
|
|
7
|
+
|
|
8
|
+
from sqlalchemy import func
|
|
9
|
+
from sqlalchemy.orm import Session
|
|
10
|
+
|
|
11
|
+
from flyfun_common.db.models import CostLedgerRow, UserRow
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def record_cost(
    db: Session,
    user_id: str,
    service: str,
    action: str,
    cost: float,
    metadata: dict | None = None,
) -> CostLedgerRow:
    """Add one cost entry to the ledger and return the flushed row.

    ``metadata`` is stored as a JSON string; a missing or empty dict is
    stored as None.
    """
    serialized = json.dumps(metadata) if metadata else None
    entry = CostLedgerRow(
        user_id=user_id,
        service=service,
        action=action,
        cost=cost,
        metadata_json=serialized,
    )
    db.add(entry)
    db.flush()
    return entry
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def get_total_cost(db: Session, user_id: str, service: str | None = None) -> float:
    """Return the sum of all ledger costs for ``user_id``.

    When ``service`` is given (non-empty), only that service's entries are
    counted. A user with no entries yields 0.0.
    """
    total_expr = func.coalesce(func.sum(CostLedgerRow.cost), 0.0)
    conditions = [CostLedgerRow.user_id == user_id]
    if service:
        conditions.append(CostLedgerRow.service == service)
    return float(db.query(total_expr).filter(*conditions).scalar())
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def get_cost_since(
    db: Session, user_id: str, since: datetime, service: str | None = None
) -> float:
    """Return the sum of ``user_id``'s costs created at or after ``since``.

    When ``service`` is given (non-empty), only that service's entries are
    counted. No matching entries yields 0.0.
    """
    total_expr = func.coalesce(func.sum(CostLedgerRow.cost), 0.0)
    conditions = [
        CostLedgerRow.user_id == user_id,
        CostLedgerRow.created_at >= since,
    ]
    if service:
        conditions.append(CostLedgerRow.service == service)
    return float(db.query(total_expr).filter(*conditions).scalar())
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def check_budget(db: Session, user_id: str) -> tuple[float, float]:
    """Return ``(total_spent, spending_limit)`` for a user.

    The limit is the user's ``spending_limit``; 500.0 is used when the user
    row does not exist.
    """
    spent = get_total_cost(db, user_id)
    account = db.get(UserRow, user_id)
    if account:
        return spent, account.spending_limit
    return spent, 500.0
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def get_cost_breakdown(
    db: Session,
    user_id: str,
    service: str | None = None,
    since: datetime | None = None,
    limit: int = 50,
) -> list[CostLedgerRow]:
    """Return up to ``limit`` ledger entries for a user, newest first.

    ``service`` and ``since`` narrow the result when given; ``since`` is an
    inclusive lower bound on ``created_at``.
    """
    conditions = [CostLedgerRow.user_id == user_id]
    if service:
        conditions.append(CostLedgerRow.service == service)
    if since:
        conditions.append(CostLedgerRow.created_at >= since)

    newest_first = CostLedgerRow.created_at.desc()
    entries = db.query(CostLedgerRow).filter(*conditions)
    return entries.order_by(newest_first).limit(limit).all()
|
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
"""Encrypted credential storage helpers using UserPreferencesRow."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import json
|
|
6
|
+
import logging
|
|
7
|
+
|
|
8
|
+
from sqlalchemy.orm import Session
|
|
9
|
+
|
|
10
|
+
from flyfun_common.db.models import UserPreferencesRow
|
|
11
|
+
from flyfun_common.encryption import decrypt, encrypt
|
|
12
|
+
|
|
13
|
+
logger = logging.getLogger(__name__)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def save_encrypted_creds(db: Session, user_id: str, creds_dict: dict) -> None:
    """Store ``creds_dict`` encrypted on the user's preferences row.

    The dict is serialized to JSON, encrypted, and written to
    ``encrypted_creds_json``. A preferences row is created when the user
    has none yet. The change is flushed, not committed.
    """
    prefs = db.get(UserPreferencesRow, user_id)
    if prefs is None:
        prefs = UserPreferencesRow(user_id=user_id)
        db.add(prefs)

    ciphertext = encrypt(json.dumps(creds_dict))
    prefs.encrypted_creds_json = ciphertext
    db.flush()
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def load_encrypted_creds(db: Session, user_id: str) -> dict | None:
    """Return the user's decrypted credentials, or None.

    None is returned when the user has no preferences row, when nothing is
    stored, or when decryption/JSON parsing fails (a warning is logged in
    that last case).
    """
    prefs = db.get(UserPreferencesRow, user_id)
    stored = prefs.encrypted_creds_json if prefs else None
    if not stored:
        return None

    try:
        plaintext = decrypt(stored)
        return json.loads(plaintext)
    except Exception:
        logger.warning("Failed to decrypt credentials for user %s", user_id)
        return None
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def clear_encrypted_creds(db: Session, user_id: str) -> None:
    """Blank out the stored credentials of a user.

    The stored value is replaced with an empty string and flushed. Nothing
    happens when the user has no preferences row.
    """
    prefs = db.get(UserPreferencesRow, user_id)
    if not prefs:
        return
    prefs.encrypted_creds_json = ""
    db.flush()
|
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
"""Database models and engine for shared user tables."""
|
|
2
|
+
|
|
3
|
+
from flyfun_common.db.models import (
|
|
4
|
+
Base,
|
|
5
|
+
UserRow,
|
|
6
|
+
ApiTokenRow,
|
|
7
|
+
UserPreferencesRow,
|
|
8
|
+
CostLedgerRow,
|
|
9
|
+
)
|
|
10
|
+
from flyfun_common.db.engine import (
|
|
11
|
+
get_engine,
|
|
12
|
+
reset_engine,
|
|
13
|
+
init_shared_db,
|
|
14
|
+
ensure_dev_user,
|
|
15
|
+
SessionLocal,
|
|
16
|
+
DEV_USER_ID,
|
|
17
|
+
)
|
|
18
|
+
|
|
19
|
+
# Lazy imports to avoid circular dependency (deps → auth.config → auth → router → deps)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def __getattr__(name: str):
    """Resolve ``get_db`` and ``current_user_id`` lazily on first access.

    Both live in ``flyfun_common.db.deps``, which is imported only when one
    of them is requested so that importing this package does not trigger a
    circular import. Any other name raises AttributeError.
    """
    lazy_names = ("get_db", "current_user_id")
    if name not in lazy_names:
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

    from flyfun_common.db import deps

    return getattr(deps, name)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
__all__ = [
|
|
31
|
+
"Base",
|
|
32
|
+
"UserRow",
|
|
33
|
+
"ApiTokenRow",
|
|
34
|
+
"UserPreferencesRow",
|
|
35
|
+
"CostLedgerRow",
|
|
36
|
+
"get_engine",
|
|
37
|
+
"reset_engine",
|
|
38
|
+
"init_shared_db",
|
|
39
|
+
"ensure_dev_user",
|
|
40
|
+
"SessionLocal",
|
|
41
|
+
"DEV_USER_ID",
|
|
42
|
+
"get_db",
|
|
43
|
+
"current_user_id",
|
|
44
|
+
]
|
flyfun_common/db/deps.py
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
1
|
+
"""FastAPI dependencies: DB session and auth.
|
|
2
|
+
|
|
3
|
+
Supports three auth methods (in priority order):
|
|
4
|
+
1. Dev mode bypass → DEV_USER_ID
|
|
5
|
+
2. JWT cookie (set by OAuth login)
|
|
6
|
+
3. Bearer token: JWT or hashed API token (ff_ prefix)
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from __future__ import annotations
|
|
10
|
+
|
|
11
|
+
import hashlib
|
|
12
|
+
from collections.abc import Generator
|
|
13
|
+
from datetime import datetime, timezone
|
|
14
|
+
|
|
15
|
+
import jwt
|
|
16
|
+
from fastapi import Depends, HTTPException, Request
|
|
17
|
+
from sqlalchemy.orm import Session
|
|
18
|
+
|
|
19
|
+
from flyfun_common.auth.config import COOKIE_NAME, get_jwt_secret
|
|
20
|
+
from flyfun_common.auth.jwt_utils import decode_token
|
|
21
|
+
from flyfun_common.db.engine import DEV_USER_ID, SessionLocal, is_dev_mode
|
|
22
|
+
from flyfun_common.db.models import ApiTokenRow, UserRow
|
|
23
|
+
|
|
24
|
+
TOKEN_PREFIX = "ff_"
|
|
25
|
+
_LEGACY_TOKEN_PREFIX = "wb_" # accept old tokens during migration
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def get_db() -> Generator[Session, None, None]:
    """Provide one SQLAlchemy session per use, with commit/rollback handling.

    A session is created from ``SessionLocal`` and yielded to the caller.
    When control returns without an exception the session is committed; if
    an exception is raised (by the caller or by the commit itself) the session
    is rolled back and the exception re-raised.  The session is always closed.
    """
    db_session = SessionLocal()
    try:
        yield db_session
        # Only reached when the consumer finished without raising.
        db_session.commit()
    except Exception:
        db_session.rollback()
        raise
    finally:
        db_session.close()
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def _is_api_token(token: str) -> bool:
    """Return True when ``token`` carries an API-token prefix rather than being a JWT.

    Both the current prefix (``TOKEN_PREFIX``) and the legacy one
    (``_LEGACY_TOKEN_PREFIX``) are accepted.
    """
    recognised_prefixes = (TOKEN_PREFIX, _LEGACY_TOKEN_PREFIX)
    return token.startswith(recognised_prefixes)
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def _authenticate_bearer_token(token: str, db: Session) -> str:
    """Check an API token against the stored hashes and return its owner's user ID.

    The token is never stored in clear: its SHA-256 hex digest is compared
    with ``ApiTokenRow.token_hash``.  On success ``last_used_at`` is stamped
    with the current UTC time and the session flushed.

    Raises:
        HTTPException: 401 when the token has no recognised prefix, is
            unknown, has been revoked, or has passed its expiry time.
    """
    if not _is_api_token(token):
        raise HTTPException(status_code=401, detail="Invalid token format")

    digest = hashlib.sha256(token.encode()).hexdigest()
    token_row = db.query(ApiTokenRow).filter(ApiTokenRow.token_hash == digest).first()

    if not token_row:
        raise HTTPException(status_code=401, detail="Invalid token")
    if token_row.revoked:
        raise HTTPException(status_code=401, detail="Token revoked")

    expiry = token_row.expires_at
    if expiry:
        # A naive timestamp is treated as UTC so it can be compared with an
        # aware "now".
        if expiry.tzinfo is None:
            expiry = expiry.replace(tzinfo=timezone.utc)
        if expiry <= datetime.now(timezone.utc):
            raise HTTPException(status_code=401, detail="Token expired")

    token_row.last_used_at = datetime.now(timezone.utc)
    db.flush()
    return token_row.user_id
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def _user_id_from_jwt(
    token: str,
    secret: str,
    *,
    expired_detail: str,
    invalid_detail: str,
) -> str:
    """Decode a JWT and return its ``sub`` claim, mapping failures to HTTP 401."""
    try:
        return decode_token(token, secret)["sub"]
    except jwt.ExpiredSignatureError as exc:
        # Must be caught before InvalidTokenError, which is its base class.
        raise HTTPException(status_code=401, detail=expired_detail) from exc
    except (jwt.InvalidTokenError, KeyError) as exc:
        # KeyError covers a validly signed token that has no "sub" claim.
        raise HTTPException(status_code=401, detail=invalid_detail) from exc


def _decode_user_id(request: Request, db: Session) -> str:
    """Extract the user ID from the JWT cookie or the ``Authorization`` header.

    Priority: dev mode → cookie → Bearer (JWT or API token).

    The ``Bearer`` scheme name is matched case-insensitively and surrounding
    whitespace around the credentials is ignored, so ``bearer <token>`` is
    accepted as well as ``Bearer <token>``.

    Args:
        request: Incoming request; its cookies and headers are inspected.
        db: Session used to look up hashed API tokens.

    Returns:
        ``DEV_USER_ID`` in dev mode, otherwise the ``sub`` claim of the JWT
        or the ``user_id`` owning the API token.

    Raises:
        HTTPException: 401 when no credential is present or the credential
            is expired or invalid.
    """
    if is_dev_mode():
        return DEV_USER_ID

    secret = get_jwt_secret()

    # A session cookie, when present, is authoritative: a bad cookie is an
    # error and the Authorization header is not consulted as a fallback.
    cookie = request.cookies.get(COOKIE_NAME)
    if cookie:
        return _user_id_from_jwt(
            cookie,
            secret,
            expired_detail="Session expired",
            invalid_detail="Invalid session",
        )

    auth_header = request.headers.get("authorization", "")
    scheme, separator, credentials = auth_header.partition(" ")
    if separator and scheme.lower() == "bearer":
        bearer_token = credentials.strip()
        if _is_api_token(bearer_token):
            return _authenticate_bearer_token(bearer_token, db)
        return _user_id_from_jwt(
            bearer_token,
            secret,
            expired_detail="Token expired",
            invalid_detail="Invalid token",
        )

    raise HTTPException(status_code=401, detail="Not authenticated")
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
def current_user_id(
    request: Request,
    db: Session = Depends(get_db),
) -> str:
    """Resolve the authenticated user's ID for the current request.

    Outside dev mode the ID must belong to an existing, approved ``UserRow``.

    Raises:
        HTTPException: 401 when authentication fails or the user row does
            not exist; 403 when the user is not approved.
    """
    resolved_id = _decode_user_id(request, db)

    if not is_dev_mode():
        account = db.query(UserRow).filter(UserRow.id == resolved_id).first()
        if account is None:
            raise HTTPException(status_code=401, detail="User not found")
        if not account.approved:
            raise HTTPException(status_code=403, detail="Account suspended")

    # In dev mode the ID is returned without any database lookup.
    return resolved_id
|
|
@@ -0,0 +1,100 @@
|
|
|
1
|
+
"""Database engine: singleton pattern, SQLite (dev) or MySQL (prod)."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import logging
|
|
6
|
+
import os
|
|
7
|
+
|
|
8
|
+
from sqlalchemy import create_engine, event
|
|
9
|
+
from sqlalchemy.engine import Engine
|
|
10
|
+
from sqlalchemy.orm import Session, sessionmaker
|
|
11
|
+
|
|
12
|
+
from flyfun_common.db.models import Base, UserPreferencesRow, UserRow
|
|
13
|
+
|
|
14
|
+
logger = logging.getLogger(__name__)
|
|
15
|
+
|
|
16
|
+
_engine: Engine | None = None
|
|
17
|
+
SessionLocal: sessionmaker[Session] = sessionmaker()
|
|
18
|
+
|
|
19
|
+
DEV_USER_ID = "dev-user-001"
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def is_dev_mode() -> bool:
    """Return True unless the ``ENVIRONMENT`` variable is exactly ``"production"``.

    An unset variable counts as ``"development"``, so dev mode is the default.
    """
    environment = os.environ.get("ENVIRONMENT", "development")
    return environment != "production"
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def get_engine(db_url: str | None = None) -> Engine:
    """Return the process-wide SQLAlchemy engine, creating it on first call.

    Once an engine exists it is returned as-is and ``db_url`` is ignored.

    URL resolution when ``db_url`` is not given:
      * dev mode: ``sqlite:///{DATA_DIR}/flyfun.db`` (``DATA_DIR`` defaults to
        ``data`` and is created if missing);
      * production: the ``DATABASE_URL`` environment variable.

    Raises:
        ValueError: in production when ``DATABASE_URL`` is unset or empty.
    """
    global _engine
    if _engine is not None:
        return _engine

    if db_url is None:
        if is_dev_mode():
            data_dir = os.environ.get("DATA_DIR", "data")
            os.makedirs(data_dir, exist_ok=True)
            db_url = f"sqlite:///{data_dir}/flyfun.db"
        else:
            db_url = os.environ.get("DATABASE_URL")
            if not db_url:
                raise ValueError("DATABASE_URL must be set in production")

    uses_sqlite = db_url.startswith("sqlite")
    # SQLite-only driver options; other backends get no extra arguments.
    connect_args = {"check_same_thread": False, "timeout": 30} if uses_sqlite else {}

    _engine = create_engine(db_url, connect_args=connect_args)
    SessionLocal.configure(bind=_engine)

    if uses_sqlite:

        @event.listens_for(_engine, "connect")
        def _set_sqlite_pragma(dbapi_conn, _connection_record):
            # Applied to every new DBAPI connection.
            pragma_cursor = dbapi_conn.cursor()
            pragma_cursor.execute("PRAGMA journal_mode=WAL")
            pragma_cursor.execute("PRAGMA foreign_keys=ON")
            pragma_cursor.close()

    # Log only the part after the credentials (or after "///" when the URL
    # has no "@"), so a password is not written to the log.
    if "@" in db_url:
        safe_url = db_url.rpartition("@")[2]
    else:
        safe_url = db_url.rpartition("///")[2]
    logger.info("Shared DB engine created: %s", safe_url)
    return _engine
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def reset_engine() -> None:
    """Drop the singleton engine so the next ``get_engine`` call builds a new one.

    Intended for tests.  An existing engine is disposed first; afterwards
    ``SessionLocal`` is left without a bind.
    """
    global _engine
    current = _engine
    if current is not None:
        current.dispose()
    _engine = None
    SessionLocal.configure(bind=None)
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def init_shared_db(engine: Engine | None = None) -> None:
    """Create every table registered on ``Base.metadata``.

    Dev only; prod uses Alembic.  When no engine is passed the singleton from
    ``get_engine()`` is used.
    """
    target_engine = engine or get_engine()
    Base.metadata.create_all(target_engine)
    logger.info("Shared tables created")
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
def ensure_dev_user(session: Session) -> None:
    """Make sure the local-development user exists.

    When no ``UserRow`` with ``DEV_USER_ID`` is found, the user and an
    accompanying ``UserPreferencesRow`` are inserted and the session is
    committed.  When the user already exists nothing is written.
    """
    if session.get(UserRow, DEV_USER_ID) is not None:
        return

    dev_user = UserRow(
        id=DEV_USER_ID,
        provider="local",
        provider_sub="dev",
        email="dev@localhost",
        display_name="Dev User",
        approved=True,
    )
    session.add(dev_user)
    # Flush the user before adding the preferences row that references it.
    session.flush()
    session.add(UserPreferencesRow(user_id=DEV_USER_ID))
    session.commit()
    logger.info("Dev user created: %s", DEV_USER_ID)
|
|
@@ -0,0 +1,83 @@
|
|
|
1
|
+
"""Shared SQLAlchemy models: users, API tokens, preferences, cost ledger.
|
|
2
|
+
|
|
3
|
+
App-specific models (flights, briefings, usage, etc.) stay in each app.
|
|
4
|
+
Apps can add relationships to UserRow via SQLAlchemy backref or explicit
|
|
5
|
+
relationship() on their own models pointing to "users.id".
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
from datetime import datetime, timezone
|
|
11
|
+
|
|
12
|
+
from sqlalchemy import Boolean, DateTime, Float, ForeignKey, Integer, String, Text
|
|
13
|
+
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class Base(DeclarativeBase):
    """Declarative base class that all shared models in this module inherit from."""
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class UserRow(Base):
    """One row per user account (table ``users``)."""

    __tablename__ = "users"

    # Identity
    id: Mapped[str] = mapped_column(String(64), primary_key=True)
    provider: Mapped[str] = mapped_column(String(32), default="local")
    provider_sub: Mapped[str] = mapped_column(String(256), default="")
    email: Mapped[str] = mapped_column(String(256), default="")
    display_name: Mapped[str] = mapped_column(String(256), default="")

    # Account state
    approved: Mapped[bool] = mapped_column(Boolean, default=True)
    spending_limit: Mapped[float] = mapped_column(
        Float,
        default=500.0,
        server_default="500.0",
    )

    # Timestamps (timezone-aware; created_at defaults to the current UTC time)
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        default=lambda: datetime.now(timezone.utc),
    )
    last_login_at: Mapped[datetime | None] = mapped_column(
        DateTime(timezone=True),
        nullable=True,
        default=None,
    )
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class ApiTokenRow(Base):
    """One row per issued API token (table ``api_tokens``)."""

    __tablename__ = "api_tokens"

    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
    # Owning user's ID; indexed, but declared without a ForeignKey.
    user_id: Mapped[str] = mapped_column(String(64), index=True)
    # 64-character hash of the token; unique and indexed.
    token_hash: Mapped[str] = mapped_column(String(64), unique=True, index=True)
    name: Mapped[str] = mapped_column(String(256), default="")

    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        default=lambda: datetime.now(timezone.utc),
    )
    expires_at: Mapped[datetime | None] = mapped_column(
        DateTime(timezone=True),
        nullable=True,
        default=None,
    )
    last_used_at: Mapped[datetime | None] = mapped_column(
        DateTime(timezone=True),
        nullable=True,
        default=None,
    )

    revoked: Mapped[bool] = mapped_column(Boolean, default=False)
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
class UserPreferencesRow(Base):
    """Per-user preferences, keyed by user ID (table ``user_preferences``)."""

    __tablename__ = "user_preferences"

    # Primary key and foreign key to users.id; deleted with the user.
    user_id: Mapped[str] = mapped_column(
        String(64),
        ForeignKey("users.id", ondelete="CASCADE"),
        primary_key=True,
    )
    setup_completed: Mapped[bool] = mapped_column(Boolean, default=False)
    # Empty string means no credentials are stored.
    encrypted_creds_json: Mapped[str] = mapped_column(Text, default="")
    app_prefs_json: Mapped[str] = mapped_column(Text, default="{}")
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
class CostLedgerRow(Base):
    """One cost entry charged to a user (table ``cost_ledger``)."""

    __tablename__ = "cost_ledger"

    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
    # Foreign key to users.id; entries are deleted with the user.
    user_id: Mapped[str] = mapped_column(
        String(64),
        ForeignKey("users.id", ondelete="CASCADE"),
        index=True,
    )
    service: Mapped[str] = mapped_column(String(64))
    action: Mapped[str] = mapped_column(String(64))
    cost: Mapped[float] = mapped_column(Float)
    metadata_json: Mapped[str | None] = mapped_column(Text, nullable=True)
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        default=lambda: datetime.now(timezone.utc),
    )
|
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
"""Fernet encryption for storing sensitive credentials at rest."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import base64
|
|
6
|
+
import hashlib
|
|
7
|
+
import os
|
|
8
|
+
|
|
9
|
+
from cryptography.fernet import Fernet
|
|
10
|
+
|
|
11
|
+
from flyfun_common.auth.config import get_jwt_secret, is_dev_mode
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def _get_fernet_key() -> bytes:
    """Return the key used for Fernet encryption of stored credentials.

    Resolution order:
      1. ``CREDENTIAL_ENCRYPTION_KEY`` from the environment, when non-empty.
      2. In dev mode only: the URL-safe base64 encoding of the SHA-256 digest
         of the JWT secret.

    Raises:
        ValueError: outside dev mode when no explicit key is configured.
    """
    configured_key = os.environ.get("CREDENTIAL_ENCRYPTION_KEY")
    if configured_key:
        return configured_key.encode()

    if not is_dev_mode():
        raise ValueError(
            "CREDENTIAL_ENCRYPTION_KEY must be set in production. "
            'Generate with: python -c "from cryptography.fernet import Fernet; print(Fernet.generate_key().decode())"'
        )

    # Dev fallback: derive a key from the JWT secret.
    secret_digest = hashlib.sha256(get_jwt_secret().encode()).digest()
    return base64.urlsafe_b64encode(secret_digest)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def encrypt(plaintext: str) -> str:
    """Encrypt ``plaintext`` with the configured Fernet key and return the token as text."""
    cipher = Fernet(_get_fernet_key())
    token_bytes = cipher.encrypt(plaintext.encode())
    return token_bytes.decode()
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def decrypt(ciphertext: str) -> str:
    """Decrypt a Fernet token given as text and return the original plaintext."""
    cipher = Fernet(_get_fernet_key())
    plain_bytes = cipher.decrypt(ciphertext.encode())
    return plain_bytes.decode()
|
|
@@ -0,0 +1,80 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: flyfun-common
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Shared user management and auth for flyfun services
|
|
5
|
+
Project-URL: Homepage, https://flyfun.aero
|
|
6
|
+
Project-URL: Repository, https://github.com/roznet/flyfun-common
|
|
7
|
+
Author-email: Brice Rosenzweig <brice@ro-z.net>
|
|
8
|
+
License-Expression: MIT
|
|
9
|
+
License-File: LICENSE
|
|
10
|
+
Keywords: auth,fastapi,flyfun,oauth
|
|
11
|
+
Classifier: Development Status :: 4 - Beta
|
|
12
|
+
Classifier: Framework :: FastAPI
|
|
13
|
+
Classifier: Intended Audience :: Developers
|
|
14
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
15
|
+
Classifier: Programming Language :: Python :: 3
|
|
16
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
17
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
18
|
+
Classifier: Programming Language :: Python :: 3.13
|
|
19
|
+
Requires-Python: >=3.11
|
|
20
|
+
Requires-Dist: authlib>=1.3
|
|
21
|
+
Requires-Dist: cryptography>=42.0
|
|
22
|
+
Requires-Dist: fastapi>=0.109
|
|
23
|
+
Requires-Dist: httpx
|
|
24
|
+
Requires-Dist: pyjwt>=2.8
|
|
25
|
+
Requires-Dist: sqlalchemy>=2.0
|
|
26
|
+
Provides-Extra: dev
|
|
27
|
+
Requires-Dist: httpx; extra == 'dev'
|
|
28
|
+
Requires-Dist: pytest; extra == 'dev'
|
|
29
|
+
Requires-Dist: pytest-asyncio; extra == 'dev'
|
|
30
|
+
Description-Content-Type: text/markdown
|
|
31
|
+
|
|
32
|
+
# flyfun-common
|
|
33
|
+
|
|
34
|
+
Shared user management and authentication library for [flyfun](https://flyfun.aero) services.
|
|
35
|
+
|
|
36
|
+
Provides OAuth login (Google, Apple), JWT session management, user database models, and API token administration — all as reusable FastAPI components.
|
|
37
|
+
|
|
38
|
+
## Installation
|
|
39
|
+
|
|
40
|
+
```bash
|
|
41
|
+
pip install flyfun-common
|
|
42
|
+
```
|
|
43
|
+
|
|
44
|
+
## Usage
|
|
45
|
+
|
|
46
|
+
```python
|
|
47
|
+
from fastapi import FastAPI
|
|
48
|
+
from flyfun_common.auth import create_auth_router
|
|
49
|
+
from flyfun_common.db import init_shared_db
|
|
50
|
+
|
|
51
|
+
app = FastAPI()
|
|
52
|
+
|
|
53
|
+
# Initialize database
|
|
54
|
+
init_shared_db()
|
|
55
|
+
|
|
56
|
+
# Mount the auth router
|
|
57
|
+
app.include_router(create_auth_router())
|
|
58
|
+
```
|
|
59
|
+
|
|
60
|
+
## Configuration
|
|
61
|
+
|
|
62
|
+
All configuration is via environment variables:
|
|
63
|
+
|
|
64
|
+
| Variable | Required | Description |
|
|
65
|
+
|----------|----------|-------------|
|
|
66
|
+
| `JWT_SECRET` | Production | Secret key for signing JWT tokens |
|
|
67
|
+
| `DATABASE_URL` | Production | SQLAlchemy database URL; required when `ENVIRONMENT=production` (development uses a local SQLite file under `DATA_DIR`) |
|
|
68
|
+
| `ENVIRONMENT` | No | `production` or `development` (default) |
|
|
69
|
+
| `COOKIE_DOMAIN` | No | Cookie domain for cross-subdomain SSO |
|
|
70
|
+
| `GOOGLE_CLIENT_ID` | No | Google OAuth client ID |
|
|
71
|
+
| `GOOGLE_CLIENT_SECRET` | No | Google OAuth client secret |
|
|
72
|
+
| `APPLE_CLIENT_ID` | No | Apple Sign In service ID |
|
|
73
|
+
| `APPLE_TEAM_ID` | No | Apple Developer Team ID |
|
|
74
|
+
| `APPLE_KEY_ID` | No | Apple Sign In key ID |
|
|
75
|
+
| `APPLE_PRIVATE_KEY` | No | Apple Sign In private key (PEM) |
|
|
76
|
+
| `CREDENTIAL_ENCRYPTION_KEY` | Production | Fernet key for encrypting stored credentials |
|
|
77
|
+
|
|
78
|
+
## License
|
|
79
|
+
|
|
80
|
+
MIT
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
flyfun_common/__init__.py,sha256=sne0db9Y4FxbtM0BComZy1Of2b1CvvXQKpjt0uK5igw,59
|
|
2
|
+
flyfun_common/admin.py,sha256=FVOF1CYXBSvqHOLYV0InOFF02vbHHESWKJCpljKFv2g,2701
|
|
3
|
+
flyfun_common/costs.py,sha256=ajQG6kSaxYnwn9G_5BTQeY8RNrOzOGXmxpd42_KmKfk,2294
|
|
4
|
+
flyfun_common/credentials.py,sha256=O4dR5d41EJPd8qS0TLjCopK-LBtSA5_75MioNyh-18s,1423
|
|
5
|
+
flyfun_common/encryption.py,sha256=mRMNGBs0qQ9Vl294GqxNZQaIFjx6xT2wf4pQBmguywg,1293
|
|
6
|
+
flyfun_common/auth/__init__.py,sha256=Nyw-7xHs2rURHQa0inOoY9qT23R2s_O4aMyIpECUJso,597
|
|
7
|
+
flyfun_common/auth/config.py,sha256=Sfg_P9WJnYe7Bg-ig8gXE50U3_ie3gsSBxhvKloq7Kg,3886
|
|
8
|
+
flyfun_common/auth/jwt_utils.py,sha256=6bMEnckONLBkOFfVRo4lv1ldlaSfDVI-g7rcEnPTP2k,750
|
|
9
|
+
flyfun_common/auth/router.py,sha256=6VwFhPB9oC2qQo3ZGRbzBF0mk2Ae4W5Tk7IYRCQXYvE,13228
|
|
10
|
+
flyfun_common/db/__init__.py,sha256=UdNzCAzH1bsdcYp8zO9qq_rntksyt8G3LpKvnJMw8m0,925
|
|
11
|
+
flyfun_common/db/deps.py,sha256=8zsPvFXQOoGstyGii6IExTKbl-NqcFqp3oGVH1ZHWxw,4188
|
|
12
|
+
flyfun_common/db/engine.py,sha256=jZDFiLive9G1oiIl66qPMEGV4I1EbEI1CGXCvRDKBug,3027
|
|
13
|
+
flyfun_common/db/models.py,sha256=vhRFoBV0_Ritqllxhu4ho_Zv6TD_U3ahW1yUCrqDcUU,3227
|
|
14
|
+
flyfun_common-0.1.0.dist-info/METADATA,sha256=hA6OuzoHtwFIBvgCl8WK-IkZDZpphAWVTwscMEjHIVc,2536
|
|
15
|
+
flyfun_common-0.1.0.dist-info/WHEEL,sha256=QccIxa26bgl1E6uMy58deGWi-0aeIkkangHcxk2kWfw,87
|
|
16
|
+
flyfun_common-0.1.0.dist-info/licenses/LICENSE,sha256=7P9PLNMrF8Yu0-v6OUHgbUghvNPBiw0oqOlY5IezjO8,1078
|
|
17
|
+
flyfun_common-0.1.0.dist-info/RECORD,,
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2024-2026 Brice Rosenzweig
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|